/*
* energy_lj.cpp
*
*  Created on: Apr, 2019
*  Author: Argo
*/

#include "energy.h"

#include <string>
#include <cmath>

#define LJCUTOFF_MAX 7   // Ang
#define LJCUTOFF_MIN 1.0   // Ang
//#define debug_energy

void CDelphiEnergy::energy_lj(delphi_real& fEnergy_LJ) 
{
    delphi_real fAtt;
    delphi_real fRep;
    delphi_real fDiff;
    delphi_real iSigma;
    delphi_real jSigma;
    delphi_real iEps;
    delphi_real jEps;
    delphi_real ijEps;
    delphi_real ijSigma;
    delphi_real dist2;
    delphi_real dist6;
    delphi_real dist12;
    SGrid<delphi_real> iCoord;
    SGrid<delphi_real> jCoord;
    SGrid<delphi_real> ijCoord;
    delphi_integer numLJPairs = 0;

    for ( int idx = 1; idx < iAtomNum; idx++ ) 
    {
        iCoord = prgapAtomPdb[idx - 1].getPose();
        iSigma = prgapAtomPdb[idx - 1].getSigmaLJ();
        iEps   = prgapAtomPdb[idx - 1].getEpsilonLJ();

        for (int jdx = idx + 1; jdx <= iAtomNum; jdx++ )
        {
            jCoord = prgapAtomPdb[jdx - 1].getPose();
            jSigma = prgapAtomPdb[jdx - 1].getSigmaLJ();
            jEps   = prgapAtomPdb[jdx - 1].getEpsilonLJ();
            
            // distance squared (taking absolute because it will help create filter)
            // and not affect the distance.
            ijCoord = optABS(iCoord - jCoord);

            //  filter pairs that father than 7 Ang (arbitrarily chosen cutoff)
            if ( ijCoord.nX > LJCUTOFF_MAX || ijCoord.nX < LJCUTOFF_MIN  || 
                 ijCoord.nY > LJCUTOFF_MAX || ijCoord.nY < LJCUTOFF_MIN  || 
                 ijCoord.nZ > LJCUTOFF_MAX || ijCoord.nZ < LJCUTOFF_MIN  )
                continue;

            dist2 = optDot(ijCoord, ijCoord);
            
            // distance 6th and 12th powers
            dist6 = dist2 * dist2 * dist2;
            dist12 = dist6 * dist6;
        
            // combination sigma/epsilon
            ijEps = geometricMean(iEps, jEps);
            ijSigma = arithmaticMean(iSigma, jSigma); 

            // attractive term (6th power)
            fAtt = pow(ijSigma,6)/dist6;

            // repulsive term
            fRep = pow(ijSigma,12)/dist12;

            // difference of the terms multiplied by 4*eps_ij
            fDiff = 4 * ijEps * (fRep - fAtt);
            // cout << "energy_nonpol> " << idx << ", " << jdx << " : " << fDiff << endl;
    
            // add the pair's contribution to the total energy
            fEnergy_LJ += fDiff; 

            numLJPairs++;


        } // j : i+1 -> iAtomNum

    } // i: 1 -> iAtomNum-1

    // epsilon values are in kT with T = 298K. So an additional operation
    // ensures that enrgy in kT is compatible with the input temperature
    fEnergy_LJ *= fTemper/298.0;
    
    /* ---------------  COMMENTED OUT ---------------------//
    delphi_real fLJ = 0.;
    delphi_real dist2, dist6, dist12;
    delphi_real A_ij = 0., B_ij = 0.;

    set<BiGrid> bSet;
    for ( auto& e: edges_v ) 
    {
        BiGrid bg(e.real(), e.imag());
        bSet.insert(bg);
    }

    for ( auto& b: bSet ) 
    {
        int idx1 = b.aidx1;
        int idx2 = b.aidx2;
        
        // reset
        A_ij = 0.; B_ij = 0.;
        dist2 = 0.; dist6 = 0.; dist12 = 0.;

        if ( idx1 < 0 || idx2 < 0 || (idx1 == idx2 )) continue;

        SGrid<delphi_real> cen1 = prgapAtomPdb[idx1 - 1].getPose();
        SGrid<delphi_real> cen2 = prgapAtomPdb[idx2 - 1].getPose();
        
        dist2 = optDot(cen1 - cen2,cen1 - cen2);
        dist6 = pow(dist2,3);
        dist12 = pow(dist6,2);
        
        delphi_real Rmin_i    = prgapAtomPdb[idx1 - 1].getSigmaLJ();
        delphi_real Rmin_j    = prgapAtomPdb[idx2 - 1].getSigmaLJ();
        
        delphi_real epsilon_i    = prgapAtomPdb[idx1 - 1].getEpsilonLJ();
        delphi_real epsilon_j    = prgapAtomPdb[idx2 - 1].getEpsilonLJ();

        //pair A_ij and B_ij (for Amber parameters)
        A_ij =     geometricMean(epsilon_i, epsilon_j);
        A_ij *= pow(arithmaticMean(Rmin_i, Rmin_j), 12);

        B_ij = 2 * geometricMean(epsilon_i, epsilon_j);
        B_ij *= pow(arithmaticMean(Rmin_i, Rmin_j), 6);

        // cout << " " << (A_ij/dist12) << " " << -1*(B_ij/dist6) << endl;
        fLJ += (A_ij/dist12);
        fLJ -= (B_ij/dist6);
    }
    **/

}//END OF FUNCTION
