# Copyright (c) 2013-2020, SIB - Swiss Institute of Bioinformatics and
#                          Biozentrum - University of Basel
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from promod3 import loop
from ost import conop
from ost.mol import mm
import numpy as np

# dihedral_info contains a list for every amino acid.
# Every list element is a tuple:
# element 1: name of the heavy atom
# element 2: tuple with three anchor atoms
# element 3: dihedral idx (0: chi1, 1: chi2, 2: chi3, 3: chi4, 4: no chi
#            angle involved, i.e. a dihedral of 0.0)
# element 4: base_dihedral (float value that has to be added to the dihedral
#            defined above to get the final dihedral angle)

# Side chain heavy atom construction rules keyed by amino acid; the tuple
# layout is (heavy atom, (anchor1, anchor2, anchor3), dihedral idx,
# base_dihedral) as described in the comment above. Key order is significant:
# the generated C++ code is emitted in this order.
dihedral_info = {
    conop.ARG: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("CD", ("CA", "CB", "CG"), 1, 0.0),
        ("NE", ("CB", "CG", "CD"), 2, 0.0),
        ("CZ", ("CG", "CD", "NE"), 3, 0.0),
        ("NH1", ("CD", "NE", "CZ"), 4, np.pi),
        ("NH2", ("CD", "NE", "CZ"), 4, 0.0),
    ],
    conop.ASN: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("OD1", ("CA", "CB", "CG"), 1, 0.0),
        ("ND2", ("CA", "CB", "CG"), 1, np.pi),
    ],
    conop.ASP: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("OD1", ("CA", "CB", "CG"), 1, 0.0),
        ("OD2", ("CA", "CB", "CG"), 1, np.pi),
    ],
    conop.GLN: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("CD", ("CA", "CB", "CG"), 1, 0.0),
        ("OE1", ("CB", "CG", "CD"), 2, 0.0),
        ("NE2", ("CB", "CG", "CD"), 2, np.pi),
    ],
    conop.GLU: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("CD", ("CA", "CB", "CG"), 1, 0.0),
        ("OE1", ("CB", "CG", "CD"), 2, 0.0),
        ("OE2", ("CB", "CG", "CD"), 2, np.pi),
    ],
    conop.LYS: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("CD", ("CA", "CB", "CG"), 1, 0.0),
        ("CE", ("CB", "CG", "CD"), 2, 0.0),
        ("NZ", ("CG", "CD", "CE"), 3, 0.0),
    ],
    conop.SER: [
        ("OG", ("N", "CA", "CB"), 0, 0.0),
    ],
    conop.CYS: [
        ("SG", ("N", "CA", "CB"), 0, 0.0),
    ],
    conop.MET: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("SD", ("CA", "CB", "CG"), 1, 0.0),
        ("CE", ("CB", "CG", "SD"), 2, 0.0),
    ],
    conop.TRP: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("CD1", ("CA", "CB", "CG"), 1, 0.0),
        ("CD2", ("CA", "CB", "CG"), 1, np.pi),
        ("CE2", ("CD1", "CG", "CD2"), 4, 0.0),
        ("NE1", ("CG", "CD2", "CE2"), 4, 0.0),
        ("CE3", ("CD1", "CG", "CD2"), 4, np.pi),
        ("CZ3", ("CE2", "CD2", "CE3"), 4, 0.0),
        ("CH2", ("CD2", "CE3", "CZ3"), 4, 0.0),
        ("CZ2", ("CE3", "CZ3", "CH2"), 4, 0.0),
    ],
    conop.TYR: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("CD1", ("CA", "CB", "CG"), 1, 0.0),
        ("CD2", ("CA", "CB", "CG"), 1, np.pi),
        ("CE1", ("CD2", "CG", "CD1"), 4, 0.0),
        ("CE2", ("CD1", "CG", "CD2"), 4, 0.0),
        ("CZ", ("CG", "CD1", "CE1"), 4, 0.0),
        ("OH", ("CD2", "CE2", "CZ"), 4, np.pi),
    ],
    conop.THR: [
        ("OG1", ("N", "CA", "CB"), 0, 0.0),
        ("CG2", ("OG1", "CA", "CB"), 4, -2.1665),
    ],
    conop.VAL: [
        ("CG1", ("N", "CA", "CB"), 0, 0.0),
        ("CG2", ("CG1", "CA", "CB"), 4, 2.1640),
    ],
    conop.ILE: [
        ("CG1", ("N", "CA", "CB"), 0, 0.0),
        ("CG2", ("CG1", "CA", "CB"), 4, -2.2696),
        ("CD1", ("CA", "CB", "CG1"), 1, 0.0),
    ],
    conop.LEU: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        # base dihedral written as a float like all other rules, so the
        # generated code reads Real(0.0) rather than Real(0)
        ("CD1", ("CA", "CB", "CG"), 1, 0.0),
        ("CD2", ("CD1", "CB", "CG"), 4, 2.0944),
    ],
    conop.PRO: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("CD", ("CA", "CB", "CG"), 1, 0.0),
    ],
    conop.HIS: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("ND1", ("CA", "CB", "CG"), 1, 0.0),
        ("CD2", ("CA", "CB", "CG"), 1, np.pi),
        ("CE1", ("CD2", "CG", "ND1"), 4, 0.0),
        ("NE2", ("ND1", "CG", "CD2"), 4, 0.0),
    ],
    conop.PHE: [
        ("CG", ("N", "CA", "CB"), 0, 0.0),
        ("CD1", ("CA", "CB", "CG"), 1, 0.0),
        ("CD2", ("CA", "CB", "CG"), 1, np.pi),
        ("CE1", ("CD2", "CG", "CD1"), 4, 0.0),
        ("CE2", ("CD1", "CG", "CD2"), 4, 0.0),
        ("CZ", ("CG", "CD1", "CE1"), 4, 0.0),
    ],
}



# We now know everything about the dihedrals, but we still need the ideal
# bond lengths and angles. We read them from the internal coordinate (IC)
# entries provided by the top_all36_prot.rtf file.

# Read all lines of the CHARMM topology file; the context manager closes the
# file handle as soon as the lines are in memory.
with open("top_all36_prot.rtf", 'r') as rtf_file:
    file_content = rtf_file.readlines()

# Residue names under which the amino acids of dihedral_info are looked up in
# the rtf file, plus one (initially empty) list of IC entries per name.
# Hack to find the proper data in the rtf file: histidine is looked up as HSE.
aa_names = ["HSE" if residue_name == "HIS" else residue_name
            for residue_name in map(conop.AminoAcidToResidueName, dihedral_info)]
ic_data = {residue_name: list() for residue_name in aa_names}


# Collect the IC (internal coordinate) entries of every residue listed in
# aa_names, each entry as its list of whitespace-separated tokens.
# A residue's section runs from its RESI line up to the next RESI or PRES
# line. Ending it at PRES keeps IC entries of a patch residue from being
# attributed to the residue that was defined just before it.
in_interesting_section = False
current_aa = None

for line in file_content:
    # '!' starts a comment in the rtf format, also after data on the same
    # line; drop it so comment words never end up as trailing IC tokens
    split_line = line.split('!', 1)[0].split()
    if not split_line:
        continue
    keyword = split_line[0]
    if keyword in ("RESI", "PRES"):
        in_interesting_section = (keyword == "RESI"
                                  and len(split_line) > 1
                                  and split_line[1] in aa_names)
        if in_interesting_section:
            current_aa = split_line[1]
        continue
    if in_interesting_section and keyword == "IC":
        ic_data[current_aa].append(split_line)

# Some IC entries still carry a '*' in front of atom names; strip it from
# every token (in place) so atom names can be compared directly.
for residue_name in aa_names:
    for entry in ic_data[residue_name]:
        entry[:] = [token.replace('*', '') for token in entry]
    

# Per amino acid: the ideal angles and bond lengths of its dihedral_info rules,
# appended by the loop that follows in the same order as dihedral_info[aa].
angles = {aa: [] for aa in dihedral_info}
bond_lengths = {aa: [] for aa in dihedral_info}


def _charmm_atom_names(aa_name, atom_names):
    """Return *atom_names* as a list, translated to the rtf file's naming.

    Hack to be consistent with CHARMM naming: the rtf file calls the
    isoleucine delta carbon "CD" whereas dihedral_info calls it "CD1".
    All other names are returned unchanged.
    """
    if aa_name == "ILE":
        return ["CD" if name == "CD1" else name for name in atom_names]
    return list(atom_names)


# For every rule in dihedral_info look up, in the IC entries of its residue,
# the ideal bond length anchor_atoms[2]-heavy_atom and the ideal angle
# anchor_atoms[1]-anchor_atoms[2]-heavy_atom.
# An entry "IC I J K L R(IJ) T(IJK) PHI T(JKL) R(KL)" is stored as its token
# list, so atoms K, L are item[3:5] with bond length item[-1], and atoms
# J, K, L are item[2:5] with angle item[-2] (in degrees, converted to
# radians here). Both lookups take the first matching entry.
for aa in dihedral_info:

    aa_name = conop.AminoAcidToResidueName(aa)
    # hack to find his: the rtf data is looked up under HSE
    if aa_name == "HIS":
        aa_name = "HSE"

    aa_ic_data = ic_data[aa_name]

    for heavy_atom, anchor_atoms, dihedral_idx, base_dihedral in dihedral_info[aa]:

        b = _charmm_atom_names(aa_name, [anchor_atoms[2], heavy_atom])
        bond_length = next((float(item[-1]) for item in aa_ic_data
                            if item[3:5] == b), None)
        if bond_length is None:
            raise RuntimeError("Could not find internal coordinates for bond %s,%s in amino acid %s!"%(anchor_atoms[2],heavy_atom,aa_name))

        a = _charmm_atom_names(aa_name, [anchor_atoms[1], anchor_atoms[2], heavy_atom])
        angle = next((float(item[-2]) / 180 * np.pi for item in aa_ic_data
                      if item[2:5] == a), None)
        if angle is None:
            raise RuntimeError("Could not find internal coordinates for angle %s,%s,%s in amino acid %s!"%(anchor_atoms[1],anchor_atoms[2],heavy_atom,aa_name))

        angles[aa].append(angle)
        bond_lengths[aa].append(bond_length)


# Let's generate the code to add the rules to the HeavyAtomRuleLookup:
# per amino acid a comment line, then one two-line AddRule(...) call per rule
# (residue, heavy atom index, three anchor atom indices | bond length, angle,
# dihedral idx, base dihedral), and finally an empty line.
for aa in dihedral_info:
    residue = conop.AminoAcidToResidueName(aa)
    print("  // " + residue)
    for rule_idx, rule in enumerate(dihedral_info[aa]):
        heavy_atom, anchor_atoms, dihedral_idx, base_dihedral = rule
        atom_indices = [loop.AminoAcidLookup.GetIndex(aa, atom_name)
                        for atom_name in (heavy_atom,) + tuple(anchor_atoms[:3])]
        head = ["ost::conop::" + residue] + [str(idx) for idx in atom_indices]
        tail = ["Real(%s)" % (bond_lengths[aa][rule_idx],),
                "Real(%s)" % (angles[aa][rule_idx],),
                str(dihedral_idx),
                "Real(%s)" % (base_dihedral,)]
        print("  AddRule(%s," % ", ".join(head))
        print("          %s);" % ", ".join(tail))
    print("")




# Once we have all info required to construct all heavy atoms in the side
# chain, we also want to know how the chi dihedrals are actually defined:
# per amino acid, one tuple of four atom names per chi angle, chi1 first.
# Key order is significant: the generated code is emitted in this order.

dihedral_definitions = {
    conop.ARG: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD"),
                ("CB", "CG", "CD", "NE"),
                ("CG", "CD", "NE", "CZ")],
    conop.ASN: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "OD1")],
    conop.ASP: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "OD1")],
    conop.GLN: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD"),
                ("CB", "CG", "CD", "OE1")],
    conop.GLU: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD"),
                ("CB", "CG", "CD", "OE1")],
    conop.LYS: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD"),
                ("CB", "CG", "CD", "CE"),
                ("CG", "CD", "CE", "NZ")],
    conop.SER: [("N", "CA", "CB", "OG")],
    conop.CYS: [("N", "CA", "CB", "SG")],
    conop.MET: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "SD"),
                ("CB", "CG", "SD", "CE")],
    conop.TRP: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD1")],
    conop.TYR: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD1")],
    conop.THR: [("N", "CA", "CB", "OG1")],
    conop.VAL: [("N", "CA", "CB", "CG1")],
    conop.ILE: [("N", "CA", "CB", "CG1"),
                ("CA", "CB", "CG1", "CD1")],
    conop.LEU: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD1")],
    conop.PRO: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD")],
    conop.HIS: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "ND1")],
    conop.PHE: [("N", "CA", "CB", "CG"),
                ("CA", "CB", "CG", "CD1")],
}


# Emit, per amino acid, a comment line followed by one AddChiDefinition(...)
# call per chi angle (residue and the lookup indices of its four atoms), and
# finally an empty line.
for aa in dihedral_definitions:
    residue = conop.AminoAcidToResidueName(aa)
    print("  // " + residue)
    for atom_names in dihedral_definitions[aa]:
        indices = [str(loop.AminoAcidLookup.GetIndex(aa, atom_name))
                   for atom_name in atom_names[:4]]
        arguments = ["ost::conop::" + residue] + indices
        print("  AddChiDefinition(%s);" % ", ".join(arguments))
    print("")




