#!/usr/bin/env python3

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

# parameters
l = 1 # bond length
theta = 68*(np.pi/180) # bond angle, converted from degrees to radians
Nb = 1000 # max degree of polymerization (rows of R, i.e. monomer positions per chain)
polymer_model = 'FRC' # 'FRC' (freely rotating chain) or 'FJC' (freely jointed chain)
Nc = 500 # number of polymer chains to create and save
seed = None # None draws fresh OS entropy; set an int (e.g. 1234) to reproduce a run

# initialize random number generator
# The generator is built FROM the seed sequence so that the entropy printed
# in the parameter summary is the value that actually determines the random
# stream; re-running with seed set to that number reproduces the run.
ss = np.random.SeedSequence(seed)
rng = np.random.default_rng(ss)
# Seed the legacy global generator from the same sequence as well, so that
# anything drawing through np.random.* is reproducible from the same seed.
np.random.seed(int(ss.generate_state(1)[0]))

print("*********************************")
print('Running Rosenbluth MC Simulation')
print('*********************************')
print()

print('Parameters: ')
print()
print('  * bond length, l = ', l)
print('  * bond angle, theta = ', theta)
print('     - cos(theta) = ', np.cos(theta))
print('  * degree of polymerization, Nb = ', Nb)
print('  * polymer model = ', polymer_model)
print('  * random seed = {}'.format(ss.entropy))

# Monomer positions of the chain currently being built, one (x, y, z) row
# per monomer. np.zeros already places the first monomer at the origin.
R = np.zeros((Nb, 3))

# randomly oriented bond vector of length l (FJC)
def r_FJC():
# {{{
  """Return a bond vector of length l pointing in a uniformly random direction.

  The direction is sampled uniformly on the sphere: the cosine of the polar
  angle is uniform on [-1, 1) and the azimuthal angle is uniform on
  [0, 2*pi). Random numbers come from the module-level generator `rng`.

  Returns
  -------
  numpy.ndarray of shape (3,)
      The (x, y, z) components of the new bond.
  """

  u, v = rng.random(2)

  # Local names deliberately differ from the module-level bond angle
  # `theta`, which this model does not use.
  polar = np.arccos(2*u-1)  # cos(polar) is uniform, not polar itself
  azimuth = 2*np.pi*v

  x = l*np.cos(polar)
  y = l*np.sin(polar)*np.cos(azimuth)
  z = l*np.sin(polar)*np.sin(azimuth)

  return np.array([x, y, z])

# }}}

def r_FRC(b_old):
# {{{
  """Return a new bond of length l at the fixed bond angle theta to b_old.

  The bond is first built in a local frame whose x-axis is the previous
  bond (polar angle fixed at the module-level `theta`, torsion angle drawn
  uniformly from [0, 2*pi) using the module-level generator `rng`). It is
  then rotated so that the local x-axis coincides with the direction of
  b_old in the global frame.

  Parameters
  ----------
  b_old : array-like of 3 floats
      Previous bond vector in the global frame; assumed to have length l.

  Returns
  -------
  numpy.ndarray of shape (3,)
      The new bond vector in the global frame.
  """

  phi = 2*np.pi*rng.random()  # random torsion angle

  # New bond in the local frame (x-axis along the old bond)
  b_new = np.array([l*np.cos(theta),
                    l*np.sin(theta)*np.cos(phi),
                    l*np.sin(theta)*np.sin(phi)])

  # Spherical angles of the old bond, measured from the global x-axis
  cos_theta_old = b_old[0]/l

  if np.abs(cos_theta_old) < (1.-1e-12):

    # Computed only inside this branch: here |cos| < 1 is guaranteed, so the
    # square root never sees a negative argument caused by rounding.
    # Only the positive root is needed because the polar angle ranges over
    # [0, pi]; the azimuth carries the full 2*pi domain.
    sin_theta_old = np.sqrt(1-cos_theta_old*cos_theta_old)

    cos_phi_old = b_old[1]/sin_theta_old/l
    sin_phi_old = b_old[2]/sin_theta_old/l

    # Rotation about z by the old polar angle ...
    Rot_Mat_theta = np.array([[ cos_theta_old, -sin_theta_old,  0.],
                              [ sin_theta_old,  cos_theta_old,  0.],
                              [            0.,             0.,  1.]])

    # ... followed by a rotation about x by the old azimuth. Together they
    # carry the local x-axis onto the direction of b_old.
    Rot_Mat_phi = np.array([[ 1.,          0.,           0.],
                            [ 0., cos_phi_old, -sin_phi_old],
                            [ 0., sin_phi_old,  cos_phi_old]])

    return np.dot(np.dot(Rot_Mat_phi, Rot_Mat_theta), b_new)

  if cos_theta_old < 0.:
    # Old bond points along -x (cos_theta_old == -1 up to rounding):
    # mirror the x component so the bond angle is measured from -x.
    return np.array([-b_new[0], b_new[1], b_new[2]])

  # Old bond points along +x: local and global frames already coincide.
  return b_new
# }}}

print()
print('Creating and saving Nc = %d chains'%Nc)
print()

# Reject an unknown model once, before any chain is built. Otherwise every
# step would reuse the previous bond and a meaningless straight rod would be
# written to all Nc output files.
if polymer_model not in ('FJC', 'FRC'):
  raise ValueError("error: select proper polymer model "
                   "(got %r, expected 'FJC' or 'FRC')" % (polymer_model,))

for n in range(Nc):

  print('  * Creating chain: %d'%n)

  # Grow chain one monomer at a time; R[0] stays at the origin
  for i in range(1, Nb):

    if i == 1:
      # The first bond is always laid along +x, with the configured bond
      # length so that every bond in the chain has length l.
      r_new = np.array([l, 0., 0.]) #r_FJC()
    elif polymer_model == 'FJC':
      r_new = r_FJC()
    else:
      # FRC: the new bond's orientation depends on the previous bond
      b_old = R[i-1, :] - R[i-2, :]
      r_new = r_FRC(b_old)

    R[i, :] = R[i-1, :] + r_new

  # save chain to file for future analysis
  print('  * Saving chain ...')
  np.savetxt('R_chain_%05d.dat'%n, R)

######################
### Error Checking ###
######################
# {{{

print()
print('Error Checking: Bond lengths and angles')
print()

# check bond lengths and angles of the chain currently stored in R
print("-"*33)
print('%10s %10s %10s'%('i', 'len', 'cos_theta'))
print("-"*33)
for i in range(1, Nb):

  b_i = R[i] - R[i-1]
  len_i = np.linalg.norm(b_i)

  # The first bond has no predecessor, so no angle is defined; report 0.
  cos_theta = 0

  if (i>1):
    b_im1 = R[i-1] - R[i-2]
    # Normalise by both bond lengths so the column really is the cosine of
    # the angle between consecutive bonds, whatever the bond length is.
    cos_theta = np.dot(b_i, b_im1)/(len_i*np.linalg.norm(b_im1))

  print('%10d %10f %10f'%(i, len_i, cos_theta))
print("-"*33)
# }}}

############
### Plot ###
############
# {{{

print()
print('Plotting the last chain')
print()

# Center about center of mass of chain (modifies R in place)
Rcm = np.mean(R, axis=0)
R -= Rcm

# for plotting, get (x, y, z) points of monomers
X = R[:, 0]
Y = R[:, 1]
Z = R[:, 2]

fig = plt.figure()
ax = plt.axes(projection='3d')

# markers for the monomers, line for the bonds connecting them
ax.scatter3D(X, Y, Z)
ax.plot3D(X, Y, Z)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# get max extent for axis limits
# The limits are symmetric about zero, so use the largest absolute
# coordinate: after centering, the most negative coordinate can be further
# from the origin than the most positive one.
Lmax = np.ceil(np.amax(np.abs(R)))

ax.set_xlim([-Lmax, Lmax])
ax.set_ylim([-Lmax, Lmax])
ax.set_zlim([-Lmax, Lmax])

plt.show()
plt.close('all')
# }}}

