#!/usr/bin/env python3

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import sys

# parameters
l = 1 # bond length
a = 1 # monomer size
theta = 68*(np.pi/180) # bond angle [rad]: angle between consecutive bond vectors (FRC)
Nb = 100 # max Degree of polymerization
polymer_model = 'FJC' # 'FRC'
Nc = 1000 # number of polymer chains to create and save
N_trial = 10 # number of tries for Rosenbluth method
seed = None #1234

# initialize random number generator.
# The generator is seeded from the SeedSequence so that `seed` actually takes
# effect. With seed = None the SeedSequence draws fresh entropy; that entropy
# is printed below, so any run can be reproduced by setting `seed` to it.
ss = np.random.SeedSequence(seed)
rng = np.random.default_rng(ss)

# some useful parameters
a2 = a*a # square of monomer size
infty = np.finfo('double').max # "infinity" (for double precision)

print("*********************************")
print('Running Rosenbluth MC Simulation')
print('*********************************')
print()

print('Parameters: ')
print()
print('  * bond length, l = ', l)
print('  * monomer size, a = ', a)
print('  * bond angle, theta = ', theta)
print('     - cos(theta) = ', np.cos(theta))
print('  * degree of polymerization, Nb = ', Nb)
print('  * polymer model = ', polymer_model)
print('  * random seed = {}'.format(ss.entropy))
print('  * N_trial = {}'.format(N_trial))

# randomly oriented bond vector of length l (FJC)
def r_FJC(bond_length=None, generator=None):
# {{{
  """Return a bond vector whose direction is uniformly distributed on the sphere.

  bond_length : length of the returned vector (defaults to the global `l`).
  generator   : numpy Generator to draw from (defaults to the global `rng`,
                so the script's seed controls the result).

  Returns a length-3 numpy array.
  """
  lb = l if bond_length is None else bond_length
  gen = rng if generator is None else generator

  u, v = gen.random(2)
  # cos(polar angle) uniform in [-1, 1) and azimuth uniform in [0, 2*pi)
  # give a uniform distribution of directions on the sphere.
  cos_t = 2*u - 1
  sin_t = np.sqrt(1 - cos_t*cos_t)
  phi = 2*np.pi*v

  return lb*np.array([cos_t, sin_t*np.cos(phi), sin_t*np.sin(phi)])

# }}}

def r_FRC(b_old, bond_length=None, bond_angle=None, generator=None):
# {{{
  """Return a freely-rotating-chain bond vector that follows `b_old`.

  The new bond has length `bond_length` and makes the fixed angle
  `bond_angle` with `b_old`, i.e. dot(b_new, b_old) equals
  |b_new| |b_old| cos(bond_angle). Only its torsion angle about `b_old` is
  random, uniform in [0, 2*pi).

  b_old       : previous bond vector (length-3 array). Only its direction is used.
  bond_length : defaults to the global `l`.
  bond_angle  : angle in radians, defaults to the global `theta`.
  generator   : numpy Generator, defaults to the global `rng`.

  Returns a length-3 numpy array.
  """
  lb = l if bond_length is None else bond_length
  th = theta if bond_angle is None else bond_angle
  gen = rng if generator is None else generator

  # New bond in a local frame whose x-axis is the old bond direction.
  phi = 2*np.pi*gen.random()
  b_new = lb*np.array([np.cos(th),
                       np.sin(th)*np.cos(phi),
                       np.sin(th)*np.sin(phi)])

  b_len = np.linalg.norm(b_old)
  if b_len == 0.:
    # No direction to align with, so return the bond in the lab frame.
    return b_new

  # Spherical angles of the old bond. They are computed from its *actual*
  # length rather than assuming |b_old| == l, because accumulated
  # floating-point error can push b_old[0]/l slightly past 1. Taking the
  # torsion angle via arctan2 keeps the rotation exactly orthonormal even
  # when b_old is (anti)parallel to x, so no special cases are needed.
  cos_theta_old = b_old[0]/b_len
  sin_theta_old = np.hypot(b_old[1], b_old[2])/b_len
  phi_old = np.arctan2(b_old[2], b_old[1])
  cos_phi_old = np.cos(phi_old)
  sin_phi_old = np.sin(phi_old)

  # Rot_Mat_phi . Rot_Mat_theta maps the local x-axis onto b_old/|b_old|.
  Rot_Mat_theta = np.array([[ cos_theta_old, -sin_theta_old,  0.],
                            [ sin_theta_old,  cos_theta_old,  0.],
                            [            0.,             0.,  1.]])

  Rot_Mat_phi = np.array([[ 1.,          0.,           0.],
                          [ 0., cos_phi_old, -sin_phi_old],
                          [ 0., sin_phi_old,  cos_phi_old]])

  return np.dot(Rot_Mat_phi, np.dot(Rot_Mat_theta, b_new))
# }}}

def expBU(i, R, r, min_dist2=None):
# {{{
  """Return the Boltzmann factor exp(-Beta * U) for placing bead `i` at `r`.

  This is a hard-sphere overlap check. If `r` is strictly closer than
  sqrt(min_dist2) to any of beads 0..i-2, then U = infinity and the factor
  is 0. Otherwise U = 0 and the factor is 1. Bead i-1 is the bonded nearest
  neighbour and is skipped.

  i         : index of the bead being placed.
  R         : bead positions, one row per bead (rows 0..i-1 already placed).
  r         : trial position of bead i.
  min_dist2 : squared overlap distance; defaults to the global `a2`.

  Returns the float 1. or 0.
  """
  cutoff2 = a2 if min_dist2 is None else min_dist2

  if i < 2:
    return 1.  # no non-bonded beads to overlap with yet

  Rij = R[:i-1] - r
  Rij2 = np.sum(Rij*Rij, axis=1)
  return 0. if np.any(Rij2 < cutoff2) else 1.
# }}}

print()
print('Creating and saving Nc = %d chains'%Nc)
print()

# Fail fast on an unknown model. Otherwise the trial loop below would
# silently keep reusing whatever bond vector was generated last.
if polymer_model not in ('FJC', 'FRC'):
  sys.exit('error: select proper polymer model')

# Boltzmann factors and positions of the N_trial candidate placements.
# Both arrays are fully overwritten for every bead of every chain.
w = np.zeros(N_trial)
R_trial = np.zeros((N_trial, 3))

for n in range(Nc):

  print('  * Creating chain: %d'%n)

  # chain positions; bead 0 is fixed at the origin
  R = np.zeros((Nb, 3))

  # ln of the Rosenbluth weight of the chain grown up to bead i
  ln_Wn = np.zeros(Nb)

  # Grow chain one monomer at a time
  for i in range(1, Nb):

    if i == 1:
      # The first bond cannot overlap anything, because expBU only checks
      # beads 0..i-2. It is placed directly with ln W = 0.
      R[i, :] = R[i-1, :] + r_FJC()
      ln_Wn[i] = 0.
      continue

    # Generate N_trial candidate positions and their Boltzmann factors.
    for j in range(N_trial):
      if polymer_model == 'FJC':
        r_new = r_FJC()
      else:  # 'FRC': the new bond keeps a fixed angle to the previous bond
        r_new = r_FRC(R[i-1, :] - R[i-2, :])
      R_trial[j] = R[i-1, :] + r_new
      w[j] = expBU(i, R, R_trial[j])

    Wtot = np.sum(w) # total weight
    if Wtot < 1e-12:
      print('    - No possible placement at bead %d. Killing chain.'%i)
      # Dead chain: the unplaced beads stay at the origin and get
      # ln W = -"infinity", which will force an underflow when the weights
      # are later added up. No more trials are generated for this chain.
      R[i:, :] = 0.
      ln_Wn[i:] = -infty
      break

    # Pick candidate j with probability w[j]/Wtot. Generator.choice never
    # returns a zero-probability (overlapping) candidate. The previous
    # cumulative-sum search fell back to index 0 whenever rounding left the
    # last cumulative probability just below the random draw.
    j_select = rng.choice(N_trial, p=w/Wtot)
    R[i, :] = R_trial[j_select]
    ln_Wn[i] = ln_Wn[i-1] + np.log(Wtot/N_trial)

  # save chain to file for future analysis
  print('  * Saving chain: {}'.format(n))
  np.savetxt('R_chain_%05d.dat'%n, R, header = 'Rx, Ry, Rz')
  np.savetxt('ln_W_%05d.dat'%n, ln_Wn, header = 'ln_W(n)')

######################
### Error Checking ###
######################
# {{{
#
#print()
#print('Error Checking: Bond lengths and angles')
#print()
#
## check bond lengths and angles
#print("-"*33)
#print('%10s %10s %10s'%('i', 'len', 'cos_theta'))
#print("-"*33)
#for i in range(1, Nb):
#
#  b_i = R[i] - R[i-1]
#  cos_theta = 0
#
#  if (i>1):
#    b_im1 = R[i-1] - R[i-2]
#    cos_theta = np.dot(b_i, b_im1)
#
#  print('%10d %10f %10f'%(i, np.linalg.norm(b_i), cos_theta))
#print("-"*33)
# }}}

############
### Plot ###
############
# {{{

print()
print('Plotting the last chain')
print()

# Center about center of mass of chain (R holds the last chain built above)
Rcm = np.mean(R, axis=0)
Rplt = R - Rcm

# for plotting, get (x, y, z) points of monomers
X = Rplt[:, 0]
Y = Rplt[:, 1]
Z = Rplt[:, 2]

fig = plt.figure()
ax = plt.axes(projection='3d')

ax.scatter3D(X, Y, Z)
ax.plot3D(X, Y, Z)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Symmetric axis limits from the largest |coordinate| of the *centered* chain.
# Taking the max of the raw R would ignore negative coordinates and the
# centering shift, so parts of the chain could fall outside the plotted box.
Lmax = np.ceil(np.amax(np.abs(Rplt)))
if Lmax == 0.:
  Lmax = 1.  # all beads coincide; avoid a zero-width axis range

ax.set_xlim([-Lmax, Lmax])
ax.set_ylim([-Lmax, Lmax])
ax.set_zlim([-Lmax, Lmax])

plt.show()
plt.close('all')
# }}}

