/*
* energy_nonPolar.cpp
*
*  Created on: Sep, 2017
*  Migrated from site module on Feb, 2018    
*  Author: Argo
*/

#include "energy.h"
#include "../misc/misc_overlapGaussian.h"
#include "../delphi/delphi_constants.h"

#include <set>
#include <map>
#include <complex>
#include <string>
#include <cmath>

// Highest overlap order (number of atoms in one overlap region) used in the
// inclusion-exclusion sums; radii_derivative tables are filled for 1..MAXORDER.
#define MAXORDER 10
// Column width passed to setw() for the console report labels.
#define MAXWIDTH 45
// Overlap regions whose Gaussian volume is not above this threshold (volumes
// are reported in cu. Ang) are discarded and not expanded further.
#define MINVOL 1.0e-3 //0.001
// Switching constant of the (currently commented-out) Gallicchio-style SA filter
// in energy_cavity; not referenced by active code in this file.
#define SWITCH_A 25.
// Steepness of the sigmoid surface-area filter applied in energy_cavity.
#define SIGMOID_G 5
// Not referenced in this file; presumably Lennard-Jones distance cutoffs
// (units per the inline comments) used elsewhere — TODO confirm.
#define LJCUTOFF_MAX 7   // Ang
#define LJCUTOFF_MIN 1.0   // Ang
// Uncomment to enable the debug printing/statistics guarded by debug_energy.
//#define debug_energy

// An overlapping atom pair packed as (real = first atom index, imag = second atom index).
typedef std::complex<int> iComplex;
// Ordered set of atom indices (neighbour lists); commonNeighbors relies on the sorted order.
typedef std::set<int> iSet;
// Sequence of atom indices describing one overlap tuple.
typedef std::vector<int> iTuples;

#ifdef debug_energy
// Number of generate_tuples_DF invocations, printed in debug traces.
static int iCall = 0;
#endif 

void CDelphiEnergy::initialize_Lambda_matrix() 
{
    // Allocates the (iAtomNum+1) x (iAtomNum+1) table of pairwise Lambda terms so
    // that 1-based atom ids can index it directly (index 0 is allocated as well).
    // The trailing "()" value-initializes every element, so the whole table starts
    // at exactly 0 and no separate zeroing pass is required.
    // NOTE(review): rows are raw new[]; nothing here releases rows from a previous
    // call — confirm ownership/cleanup is handled elsewhere.
    const int rowLength = iAtomNum + 1;
    int row = 0;
    while ( row <= iAtomNum )
    {
        Lambda[row] = new delphi_real[rowLength]();
        ++row;
    }
} //END OF FUNCTION

void CDelphiEnergy::initialize_pdbAtoms_forGaussVol () 
{
    // Resets the Gaussian-volume bookkeeping of every atom (1-based ids in pdbAtoms)
    // from the parsed PDB records (0-based in prgapAtomPdb): radius, element letter,
    // gamma, hydrogen flag, and zeroed area/volume/derivative accumulators.
    for ( int id = 1; id <= iAtomNum; ++id ) 
    {
        auto& src  = prgapAtomPdb[id - 1];
        auto& atom = pdbAtoms[id];

        // Geometry / parameters taken from the PDB record.
        atom.Rvdw  = src.getRadius();
        atom.gamma = src.getVdwGamma();

        // Element = 2nd character of the atom-info string (the atom name starts in
        // the 2nd column); an "H" marks a hydrogen.
        atom.ele        = src.getAtInf().substr(1,1);
        atom.isHydrogen = (atom.ele == "H");

        // Accumulators filled later by the overlap enumeration / cavity energy.
        atom.selfVolume    = 0.;
        atom.surfArea      = 0.;
        atom.rel_surfArea  = 0.;
        atom.SA_multiplier = 1;

        // One zeroed radii-derivative slot per overlap order 1..MAXORDER.
        for ( int order = 1; order <= MAXORDER; ++order )
            atom.radii_derivative[order] = 0.;
    }

} //END OF FUNCTION

void CDelphiEnergy::computeVolume_SA( vector<iComplex>& edges_v ) 
{
    // Builds the overlap graph and enumerates all Gaussian overlap regions.
    //
    // edges_v: overlapping atom pairs, each packed as iComplex{index1, index2}
    //          (1-based atom ids; pairs touching id 0 or a hydrogen are ignored).
    //
    // Steps:
    //  1. Root (dummy) node 0 gets every heavy atom as a child; each heavy atom
    //     gets an empty neighbour set and its single-atom OverlapRegion.
    //  2. For each heavy-atom pair: if the smaller sphere lies entirely inside the
    //     larger one, the smaller atom is dropped from the graph (its volume is
    //     already embodied by the larger atom); otherwise index2 is recorded as a
    //     neighbour of index1 and the pairwise Lambda term is filled once.
    //  3. A depth-first walk (generate_tuples_DF) over common neighbours generates
    //     the higher-order overlaps, accumulating total_V and per-atom radii
    //     derivatives.
    // Resets total_V, total_corrV and total_SA before accumulating.

    cout << string(20,'-') << " NON-POLAR ENERGY CALCULATIONS " << string(20,'-') << endl;
    cout << left << setw(MAXWIDTH) << " Total number of atoms" << " : " << prgapAtomPdb.size() << endl;

    // Root (dummy) node '0': its adjacency set will list every contributing atom.
    adjMap[0] = iSet();
    atom_as_overlap[0] = OverlapRegion();

    // Accumulators for the total volume, volume correction and surface area.
    total_V = 0.;
    total_corrV = 0.;
    total_SA = 0.;

    // Every heavy (non-hydrogen) atom becomes a child of the root and gets its own
    // single-atom Gaussian built from Rvdw, fRoffset and its coordinates.
    for ( int i = 1; i <= iAtomNum; i++ ) 
    {
        if ( pdbAtoms[i].isHydrogen ) continue;

        adjMap[0].insert(i);
        adjMap[i] = iSet();

        SGrid<delphi_real> coord_i = prgapAtomPdb[i-1].getPose();
        delphi_real Rvdw_i         = pdbAtoms[i].Rvdw;

        OverlapRegion OverlapRegion_i(KConstant, i, Rvdw_i, fRoffset, coord_i);
        atom_as_overlap[i] = OverlapRegion_i;
    }

    iSet overlappedAtoms;   // atoms fully contained inside a larger neighbour

    #ifdef debug_energy
    int edges_read = 0;     // heavy-atom pairs examined (debug statistics only)
    cout << left << setw(MAXWIDTH) << " Total number of complex edges" << " : " << edges_v.size() << endl;
    #endif

    for ( const iComplex& edge : edges_v ) 
    {
        const int index1 = edge.real();
        const int index2 = edge.imag();

        // Pairs involving the dummy root or a hydrogen contribute no volume.
        if ( index1 == 0 || index2 == 0 ) continue;
        if ( pdbAtoms[index1].isHydrogen || pdbAtoms[index2].isHydrogen ) continue;

        #ifdef debug_energy
        edges_read++;
        #endif

        // Centres and radii of the pair.
        SGrid<delphi_real> cen1 = prgapAtomPdb[index1 - 1].getPose();
        SGrid<delphi_real> cen2 = prgapAtomPdb[index2 - 1].getPose();
        delphi_real dist2 = optDot(cen1 - cen2,cen1 - cen2);
        delphi_real R1   = prgapAtomPdb[index1 - 1].getRadius();
        delphi_real R2   = prgapAtomPdb[index2 - 1].getRadius();

        if ( (sqrt(dist2) + min(R2,R1)) < max(R1,R2) ) 
        {
            // Total overlap: the smaller sphere is inside the larger one, so it adds
            // nothing at any overlap order. Remove it from the graph and log it.
            int smallerAtom = (R1 >= R2)?index2:index1;

            adjMap.erase(smallerAtom);
            adjMap[0].erase(smallerAtom);
            overlappedAtoms.insert(smallerAtom);
            continue;
        }

        if ( adjMap.count(index1) ) 
        {
            adjMap[index1].insert(index2);

            // Pairwise Lambda term, filled only the first time this pair is seen.
            if ( Lambda[index1][index2] == 0. ) 
                Lambda[index1][index2] = atom_as_overlap[index1].Alpha * atom_as_overlap[index2].Alpha * dist2; 
        }
    }

    cout << left << setw(MAXWIDTH) << " Number of atoms fully overlapped " << " : " << overlappedAtoms.size() << endl;
    cout << left << setw(MAXWIDTH) << " Number of non-hydrogen atoms " << " : " <<adjMap[0].size() << endl;
    #ifdef debug_energy
    cout << left << setw(MAXWIDTH) << " Number of edges read " << " : " << edges_read << endl;
    #endif

    // Every accepted overlap tuple and its OverlapRegion (volume etc.); the length
    // of a tuple reflects the generation (overlap order) of the node.
    vector < iTuples > Nodes;
    vector < OverlapRegion > vec_OverlapRegion; 

    // Depth-first construction of the overlap tree, one subtree per root child.
    for ( auto& atom_i: adjMap[0]) 
    {
        iTuples pair(1,0);
        pair.push_back(atom_i);

        OverlapRegion ov_i = atom_as_overlap[atom_i];
        pdbAtoms[atom_i].radii_derivative[1] += ov_i.Radii_derivative;

        iSet common_i = commonNeighbors(adjMap[0], adjMap[atom_i]);
        generate_tuples_DF(pair, ov_i, common_i, Nodes, vec_OverlapRegion);
    }

    #ifdef debug_energy
    // Further diagnostics available: printOverlapTree(), validateNeighbors(),
    // get_InterAtomic_Distances().

    // Summed (unsigned) volume of the accepted overlaps of each order.
    map<int, delphi_real> orderVol;
    for ( int ii = 1; ii <= MAXORDER; ii++ ) orderVol[ii] = 0.;

    for ( const OverlapRegion& ov : vec_OverlapRegion )
        orderVol[ov.Order] += ov.Volume;

    for ( int ii = 1; ii <= MAXORDER; ii++ ) 
    {
        cout << " Order " << ii << left << setw(MAXWIDTH) << " volume " << " : " << orderVol[ii] << " cu. Ang" << endl;
    }
    #endif

    cout << left << setw(MAXWIDTH) << " Number of Overlap objects found " << " : " << Nodes.size() << endl;

} // END OF FUNCTION


void CDelphiEnergy::energy_cavity(delphi_real& fEnergy_CavitySA, delphi_real& fEnergy_CavityVol) 
{
    // Converts the per-atom radii derivatives gathered by the overlap enumeration
    // into filtered per-atom surface areas, then:
    //  - adds surfArea * gamma of every contributing atom onto fEnergy_CavitySA
    //    (the caller's value is accumulated into, not reset) and scales it by fTemper/298,
    //  - sets fEnergy_CavityVol = fPressureCoeff * (total_V - total_corrV) * fTemper/298,
    //  - accumulates total_SA and total_corrV and prints the volume/area summary.
    const delphi_real Rp = rgfProbeRadius[0];

    for ( const int id : adjMap[0] )
    {
        auto& atom = pdbAtoms[id];

        // Alternating sum over overlap orders: odd orders add, even orders subtract.
        for ( int order = 1; order <= MAXORDER; ++order )
        {
            const double sign = ( order % 2 != 0 ) ? 1.0 : -1.0;
            atom.surfArea += sign * atom.radii_derivative[order];
        }

        // Relative area term derived from the probe radius and the offset atom radius.
        atom.rel_surfArea = 1 - cos(2 * atan2(Rp, Rp + atom.Rvdw + fRoffset)) ;
        atom.rel_surfArea *= 0.5;

        // Alternatives kept for reference:
        //   no filter:                 SA_multiplier = 1
        //   modified Gallicchio filter: (SA >= 0) ? SA^2/(SWITCH_A + SA^2) : 0
        // Active: sigmoid filter with cutoff set at RSA_i(prbrad).
        atom.SA_multiplier = 1/(1 + exp(SIGMOID_G*(-atom.surfArea + (atom.surfArea * atom.rel_surfArea ))));
        atom.surfArea *= atom.SA_multiplier;
        total_SA += atom.surfArea;

        #ifdef debug_energy
        cout << " Atom " << id << " : " << right << setw(10) << atom.surfArea << ", " << right << setw(10) << atom.surfArea * atom.rel_surfArea  << endl;
        #endif

        // Cavity energy contribution of this atom.
        fEnergy_CavitySA += (atom.surfArea * atom.gamma);

        // Volume correction required by a non-zero radius offset.
        const delphi_real shellTerm = atom.surfArea * (atom.Rvdw + fRoffset)/3;
        const delphi_real corrV     = shellTerm * (1 - pow(atom.Rvdw/(atom.Rvdw + fRoffset),3));
        total_corrV += corrV;
    }

    cout << enerString << left << setw(MAXWIDTH) << " Total molecular volume " << " : " << total_V << " cu. Ang" << endl;
    cout << enerString << left << setw(MAXWIDTH) << " Total molecular volume corrected for SEV" << " : " << total_V - total_corrV << " cu. Ang" << endl;
    cout << enerString << left << setw(MAXWIDTH) << " Total molecular surface area " << " : " << total_SA << " sq. Ang" << endl; 

    // Solvent pressure coefficient times the corrected volume, temperature-scaled.
    fEnergy_CavityVol = fPressureCoeff * (total_V - total_corrV);
    fEnergy_CavityVol *= fTemper/298.0;

    // gamma values are in kT at T = 298 K; rescale so the energy in kT matches
    // the input temperature.
    fEnergy_CavitySA *= fTemper/298.0;
}//END OF FUNCTION


void CDelphiEnergy::generate_tuples_DF (iTuples parent_tuple, 
                                        OverlapRegion& parent_ov, iSet& commonNb, 
                                        vector < iTuples > & nodes, 
                                        vector<OverlapRegion>& v_OR) 
{
    // Depth-first enumeration of overlap regions.
    //
    // parent_tuple : {0 (dummy root), a1, ..., ak} — the atoms of parent_ov
    // parent_ov    : Gaussian overlap of those atoms
    // commonNb     : atoms adjacent to every atom of the tuple (children candidates)
    // nodes, v_OR  : output lists of accepted tuples and their overlap regions
    //
    // An accepted region adds (-1)^(Order+1) * Volume to total_V (inclusion-exclusion)
    // and is then extended by each common neighbour.

    #ifdef debug_energy
    cout << " ENERGY - generate_tuples_DF call # : " << iCall++ << endl;
    #endif

    // Truncate on the overlap order itself. parent_tuple.size() is one larger than
    // the order because of the leading dummy root, and calculateOverlap() accumulates
    // radii derivatives up to Order == MAXORDER, so volumes must be kept up to the
    // same order. "!(Volume > MINVOL)" also rejects a NaN volume.
    if ( parent_ov.Order > MAXORDER || !(parent_ov.Volume > MINVOL) ) return;

    nodes.push_back(parent_tuple);
    v_OR.push_back(parent_ov);
    total_V += (pow(-1,parent_ov.Order + 1)*parent_ov.Volume);

    // Each child adds one single-atom region (Order 1, as booked under
    // radii_derivative[1] in computeVolume_SA), so children of a MAXORDER node
    // would be rejected anyway; skip computing them.
    if ( parent_ov.Order >= MAXORDER ) return;

    for ( auto& common_atom: commonNb ) {
        iTuples child_tuple(parent_tuple);
        child_tuple.push_back(common_atom);

        OverlapRegion child_ov = calculateOverlap(parent_ov, atom_as_overlap[common_atom]);
        iSet commonNb_new = commonNeighbors(commonNb, adjMap[common_atom]);
        generate_tuples_DF(child_tuple, child_ov, commonNb_new, nodes, v_OR);
    }    
}//end of function


iSet CDelphiEnergy::commonNeighbors (iSet& nlist1, iSet& nlist2) 
{
    // Returns the intersection of two neighbour sets. Neither argument is modified.
    // Both sets are ordered, so a single linear merge walk finds the common ids,
    // and because they are produced in ascending order each one is inserted with an
    // end() hint (amortized constant time instead of a full tree search).
    iSet common_nb;

    if ( nlist1.empty() || nlist2.empty() ) return common_nb;

    iSet::const_iterator it1 = nlist1.begin();
    iSet::const_iterator it2 = nlist2.begin();

    while ( it1 != nlist1.end() && it2 != nlist2.end() ) 
    {
        if ( *it1 < *it2 ) {
            ++it1;
        } else if ( *it2 < *it1 ) {
            ++it2;
        } else {
            common_nb.insert(common_nb.end(), *it1);
            ++it1;
            ++it2;
        }
    }

    return common_nb;
} //END OF FUNCTION


OverlapRegion CDelphiEnergy::calculateOverlap(OverlapRegion& parent_ov, OverlapRegion& ov_c ) 
{
    // Builds the Gaussian overlap of parent_ov with the single-atom region ov_c:
    // order/exponent add, prefactors multiply, the centre is the Alpha-weighted mean,
    // and SumLambda gains the pairwise terms between the new atom and every parent
    // atom. When the result is significant (Volume > MINVOL, Order <= MAXORDER), each
    // member atom's radii_derivative[Order] and the region's Radii_derivative are
    // incremented as a side effect. Neither argument is modified.
    const int newAtom = ov_c.Atoms.front();

    OverlapRegion merged;
    merged.Atoms = parent_ov.Atoms;
    merged.Atoms.push_back(newAtom);

    merged.Order = parent_ov.Order + ov_c.Order;
    merged.Alpha = parent_ov.Alpha + ov_c.Alpha;
    merged.P     = parent_ov.P * ov_c.P;

    // Coalescence centre: Alpha-weighted mean of the two centres.
    merged.r = (parent_ov.Alpha * parent_ov.r) + (ov_c.Alpha * ov_c.r);
    merged.r = merged.r/merged.Alpha;

    // Pairwise Lambda terms between the incoming atom and each parent atom.
    merged.SumLambda = parent_ov.SumLambda;
    for ( const int member : parent_ov.Atoms )
        merged.SumLambda += Lambda[member][newAtom];

    // Gaussian overlap volume.
    merged.Volume = merged.P * exp(-(merged.SumLambda)/merged.Alpha) * pow((fPi/merged.Alpha),1.5);

    // Negated form keeps the exact original acceptance rule (a NaN volume still passes).
    const bool negligible = (merged.Volume <= MINVOL) || (merged.Order > MAXORDER);
    if ( !negligible ) {
        for ( const int member : merged.Atoms ) {
            SGrid<delphi_real> r_iter = atom_as_overlap[member].r - merged.r;
            delphi_real        A_iter = 2 * KConstant * merged.Volume * ( (1.5/merged.Alpha) + optDot(r_iter, r_iter))/pow(pdbAtoms[member].Rvdw + fRoffset,3);
            pdbAtoms[member].radii_derivative[merged.Order] += A_iter;
            merged.Radii_derivative += A_iter;
        }
    }

    return merged;
}


void CDelphiEnergy::printTuple (iTuples& tup) 
{
    // Prints one '>' per tuple entry as a depth marker, then the entries,
    // space-separated, without a trailing newline.
    const string marker(tup.size(), '>');
    cout << marker << " ";
    for ( iTuples::const_iterator it = tup.begin(); it != tup.end(); ++it )
        cout << *it << " ";
}

void CDelphiEnergy::printOverlapTree() 
{
    // Debug dump of the adjacency map: one line " Atom i : { n1 n2 ... } " per id
    // 0..iAtomNum. Ids without an entry (hydrogens, fully engulfed atoms) print an
    // empty set. find() is used instead of operator[] so printing does not insert
    // empty entries into adjMap, which computeVolume_SA queries with count().
    for (int i = 0; i <= iAtomNum; i++) 
    {
        cout << " Atom " << i << " : { ";

        auto found = adjMap.find(i);
        if (found != adjMap.end())
        {
            for (const int neighbour : found->second)
                cout << neighbour << " ";
        }
        cout << "} " << endl;
    }
} // end of function

void CDelphiEnergy::validateNeighbors()
{

    iSet::iterator iter = adjMap[0].begin();
    int atom_i = -1, atom_j = -1;

    while (iter != adjMap[0].end())
    {
        atom_i = *iter;
        SGrid<delphi_real> cen_i = prgapAtomPdb[atom_i - 1].getPose();
        delphi_real R1   = prgapAtomPdb[atom_i - 1].getRadius();

        iSet::iterator jter = iter;
        jter++;

        while (jter != adjMap[0].end())
        {
        atom_j = *jter;

            cout << "NEIGH>" <<  setw(6) << right << atom_i << "  " << setw(6) << right << atom_j << " ";
            if (adjMap[atom_i].find(atom_j) != adjMap[atom_i].end())
                cout << "T" << " ";
            else
                cout << "F" << " ";
            
            //foreach pair, get the centers and radius
            SGrid<delphi_real> cen_j = prgapAtomPdb[atom_j - 1].getPose();
            delphi_real dist2 = optDot(cen_i - cen_j,cen_i - cen_j);
            delphi_real R2   = prgapAtomPdb[atom_j - 1].getRadius();

            if ( sqrt(dist2) <= (R1 + R2) )
                cout << "T" << " ";
            else
                cout << "F" << " ";

            cout << setw(10) << right << sqrt(dist2) << " " << setw(10) << right << (R1 + R2) << endl;

            jter++;
        }
        iter++;
    }// iter

} // end of function


void CDelphiEnergy::get_InterAtomic_Distances() 
{
    // Debug dump: for each unique overlapping pair in the member edges_v, prints
    // (tab separated) atom index, residue name, residue id and charge of both atoms,
    // followed by their centre-to-centre distance.

    cout << " Writing Interatomic distances for overlap pairs." << endl;

    // Put the edges into a set to de-duplicate them. This is relatively expensive,
    // so it is done only when this dump is requested.
    set<BiGrid> bSet;
    for ( auto& e: edges_v ) {
        BiGrid bg(e.real(), e.imag());
        bSet.insert(bg);
    }

    cout << left << setw(MAXWIDTH) << " Number of unique overlap atom pairs " << " : " << bSet.size() << endl;

    const size_t nAtoms = prgapAtomPdb.size();
    for ( auto& b: bSet ) {
        int idx1 = b.aidx1;
        int idx2 = b.aidx2;

        // Atom ids are 1-based; id 0 is the dummy root node (computeVolume_SA skips
        // pairs containing it). Accessing prgapAtomPdb[idx - 1] with idx <= 0 or
        // idx > nAtoms would be out of bounds, so such pairs are skipped as well as
        // self-pairs.
        if ( idx1 <= 0 || idx2 <= 0 || (idx1 == idx2) ) continue;
        if ( static_cast<size_t>(idx1) > nAtoms || static_cast<size_t>(idx2) > nAtoms ) continue;

        SGrid<delphi_real> cen1 = prgapAtomPdb[idx1 - 1].getPose();
        SGrid<delphi_real> cen2 = prgapAtomPdb[idx2 - 1].getPose();
        delphi_real dist2 = optDot(cen1 - cen2,cen1 - cen2);
        string atomInf1   = prgapAtomPdb[idx1 - 1].getAtInf();
        string atomInf2   = prgapAtomPdb[idx2 - 1].getAtInf();

        // Fixed-column fields of the atom-info string: residue name at 6..8,
        // residue id at 11..14.
        string resname1   = removeSpace(atomInf1.substr(6,3));
        string resid1     = removeSpace(atomInf1.substr(11,4));

        string resname2   = removeSpace(atomInf2.substr(6,3));
        string resid2     = removeSpace(atomInf2.substr(11,4));

        delphi_real q1    = prgapAtomPdb[idx1 - 1].getCharge();
        delphi_real q2    = prgapAtomPdb[idx2 - 1].getCharge();

        cout << idx1 << "\t" << resname1 << "\t" << resid1 << "\t";
        cout << q1 << "\t";
        cout << idx2 << "\t" << resname2 << "\t" << resid2 << "\t";
        cout << q2 << "\t";
        cout << sqrt(dist2) << endl;
    }

    return;
} // END OF FUNCTION


delphi_real CDelphiEnergy::geometricMean(delphi_real& f_i, delphi_real& f_j)
{
    // Geometric mean sqrt(f_i * f_j). The arguments are only read; there is no
    // guard against a negative product.
    const delphi_real product = f_i * f_j;
    return sqrt(product);
}// END OF FUNCTION



delphi_real CDelphiEnergy::arithmaticMean(delphi_real& f_i, delphi_real& f_j)
{
    // Arithmetic mean (f_i + f_j) / 2. The arguments are only read.
    const delphi_real sum = f_i + f_j;
    return 0.5 * sum;
}// END OF FUNCTION

string CDelphiEnergy::removeSpace(const string& strLine) 
{
    // Returns a copy of strLine with every ' ' character removed; tabs and other
    // whitespace characters are kept.
    string strNewLine;
    strNewLine.reserve(strLine.size());

    for (const char ch : strLine)
    {
        if (ch == ' ') continue;
        strNewLine.push_back(ch);
    }

    return strNewLine;
}// END OF FUNCTION



