package org.baderlab.pdzsvm.predictor.additive;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.File;
import java.util.*;
import java.util.List;

import org.baderlab.brain.ProteinProfile;
import org.biojava.bio.seq.Sequence;
import weka.core.Instances;
import org.baderlab.pdzsvm.data.*;
import org.baderlab.pdzsvm.data.manager.DataFileManager;
import org.baderlab.pdzsvm.evaluation.Prediction;
import org.baderlab.pdzsvm.evaluation.Evaluation;

import org.baderlab.pdzsvm.encoding.Chen16FeatureEncoding;
import org.baderlab.pdzsvm.predictor.Predictor;
import org.baderlab.pdzsvm.validation.ValidationParameters;
import org.baderlab.pdzsvm.utils.Constants;
import org.baderlab.pdzsvm.utils.PDZSVMUtils;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Chen's additive mouse predictor.
 * This predictor does not do any training on data, it sums the appropriate
 * parameter values according to the tutorial published in the original paper.
 * Chen, J. et al. (2008) Predicting PDZ domain-peptide interactions from
 *   primary sequences, Nat. Biotechnol., 26, 1041-1045.
 */
public class AdditivePredictor extends Predictor
{
    /**
     * Decision threshold used by predict(): a summed additive score strictly
     * greater than TAU is labelled as a predicted interaction (1.0), any
     * other score as a predicted non-interaction (0.0).
     */
    private double TAU =  0.3978;
    // Alternative threshold value, currently disabled.
    //private double TAU = 1.0978;
    /**
     * The 20 residue letters. A residue's index in this string selects its
     * row (domain residue) or column (peptide residue) in an
     * AdditiveModelValues matrix.
     */
    final String alphabet = "GAVLIMPFWSTNQYCKRHDE";
    /**
     * Maps a "domainPos-peptidePos" key string to the AdditiveModelValues
     * object holding the parameters for that position pair. Filled by the
     * constructor from the model file.
     */
    private HashMap domainPeptidePosPairToAdditiveModelMap;
    /** Predictions produced by the most recent call to predict(); empty before the first call. */
    private List predictionList;
    /** Encoder used by predict() to turn a full domain sequence into the feature string that is scored. */
    private Chen16FeatureEncoding enc;

    /**
     * Container for the 20 x 20 table of additive model parameters that
     * belongs to one (domain position, peptide position) pair. The first
     * index is the domain residue and the second the peptide residue, both
     * given as indices into the enclosing class's alphabet string.
     */
    private class AdditiveModelValues
    {
        double[][] additiveModelMatrix = new double[20][20];

        /**
         * Creates a table in which every entry is 0.0.
         */
        public AdditiveModelValues()
        {
            // Explicitly reset each row so that entries never set by the
            // model file read as 0.0.
            for (int row = 0; row < additiveModelMatrix.length; row++)
            {
                Arrays.fill(additiveModelMatrix[row], 0.0);
            }
        }
    }

    /**
     * Prints every model parameter to standard output, one tab-separated
     * line per matrix entry: domain position, peptide position, domain
     * residue, peptide residue, value. Position pairs are visited in sorted
     * key order, and an entry equal to zero is printed as "0".
     */
    public void printMatrix()
    {
        List sortedKeys = new ArrayList(domainPeptidePosPairToAdditiveModelMap.keySet());
        Collections.sort(sortedKeys);

        Iterator keyIt = sortedKeys.iterator();
        while (keyIt.hasNext())
        {
            String pairKey = (String)keyIt.next();
            AdditiveModelValues model = (AdditiveModelValues)domainPeptidePosPairToAdditiveModelMap.get(pairKey);
            double[][] matrix = model.additiveModelMatrix;

            // Keys store the domain position minus one and the negated
            // peptide position; undo both shifts for display.
            String[] keyParts = pairKey.split("-");
            int domainPos = Integer.parseInt(keyParts[0])+1;
            int peptidePos = -1*Integer.parseInt(keyParts[1]);
            String linePrefix = domainPos + "\t" + peptidePos + "\t";

            for (int row = 0; row < matrix.length; row++)
            {
                char domainChar = alphabet.charAt(row);
                for (int col = 0; col < matrix[row].length; col++)
                {
                    char peptideChar = alphabet.charAt(col);
                    double theta = matrix[row][col];
                    String thetaText = (theta == 0.0) ? "0" : Double.toString(theta);
                    System.out.println(linePrefix + domainChar + "\t" + peptideChar + "\t" + thetaText);
                }
            }
        }
    }
    /**
     * Creates the predictor and loads the additive model parameters from
     * DATA_ROOT_DIR/Data/Chen/Model/ChenPredictorModel-Binary.txt.
     *
     * Every non-blank line of the file is split on whitespace into five
     * columns: domain position, peptide position, domain residue, peptide
     * residue and parameter value. The domain position is stored minus one
     * and the peptide position is stored negated; together they form the
     * "domainPos-peptidePos" key under which the value is filed.
     *
     * A failure while reading or parsing is reported on standard output and
     * does not propagate; parameters read before the failure are kept. The
     * reader is closed whether or not loading succeeds.
     *
     * @param posTrainProfileList positive training profiles, passed to the superclass
     * @param negTrainProfileList negative training profiles, passed to the superclass
     */
    public AdditivePredictor(List posTrainProfileList, List negTrainProfileList)
    {
        super(posTrainProfileList, negTrainProfileList);
        predictionList = new ArrayList();
        domainPeptidePosPairToAdditiveModelMap= new HashMap();
        String paramFile = DataFileManager.DATA_ROOT_DIR+"/Data/Chen/Model/ChenPredictorModel-Binary.txt";
        try
        {
            BufferedReader br = new BufferedReader(new FileReader(new File(paramFile)));
            try
            {
                String line;
                while((line=br.readLine())!=null)
                {
                    // Leading whitespace would produce an empty first token
                    // and a blank line has no tokens at all; neither should
                    // abort the load.
                    line = line.trim();
                    if (line.length() == 0)
                    {
                        continue;
                    }

                    String[] splitString = line.split("\\s+");
                    int domainPos = Integer.parseInt(splitString[0])-1;
                    int peptidePos = -1*Integer.parseInt(splitString[1]);
                    String domainRes = splitString[2];
                    String peptideRes = splitString[3];
                    int domainResIx = alphabet.indexOf(domainRes);
                    int peptideResIx = alphabet.indexOf(peptideRes);
                    double value = Double.parseDouble(splitString[4]);

                    String str = Integer.toString(domainPos)  +"-" +Integer.toString(peptidePos);

                    AdditiveModelValues modelValues = (AdditiveModelValues)domainPeptidePosPairToAdditiveModelMap.get(str);
                    if (modelValues == null)
                    {
                        // First parameter seen for this position pair.
                        modelValues = new AdditiveModelValues();
                        domainPeptidePosPairToAdditiveModelMap.put(str, modelValues);
                    }

                    modelValues.additiveModelMatrix[domainResIx][peptideResIx] = value;
                }
            }
            finally
            {
                // Release the file handle even when a line fails to parse.
                br.close();
            }
        }
        catch(Exception e)
        {
            System.out.println("Exception: " + e);
            e.printStackTrace();
        }
        enc = new Chen16FeatureEncoding();
        System.out.println("\tUsing CHEN 16 contact map...");
        predictorName = "Additive";
    }
    /**
     * Training step. This predictor has nothing to fit, so the method only
     * prints a progress message.
     */
    public void train()
    {
        final String progressMessage = "\tTraining...";
        System.out.println(progressMessage);
    }
    /**
     * Returns the list of predictions held by this predictor.
     *
     * @return the internal prediction list (not a copy); it is replaced by
     *         a new list on every call to predict()
     */
    public List getPredictions()
    {
        return this.predictionList;
    }
    /**
     * Command-line driver. Loads training data with
     * DataLoader.loadMouseChenTrain() and test data with
     * DataLoader.loadFlyTest(), runs the additive predictor on the test
     * set, prints an evaluation summary and plots the resulting curve.
     *
     * @param args not used
     */
    public static void main(String[] args)
    {
        DataLoader dl = new DataLoader();
        dl.loadMouseChenTrain();

        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();

        // Test set selection. Alternatives that can be swapped in:
        //   dl.loadWormTest(Constants.PROTEIN_MICROARRAY); testName = "PM WORM G";
        //   dl.loadMouseTest("ORPHAN");                    testName = "MOUSE ORPHAN PM";
        dl.loadFlyTest();
        String testName = "FLY PM";
        List posTestProfileList = dl.getPosTestProfileList();
        List negTestProfileList = dl.getNegTestProfileList();

        AdditivePredictor a = new AdditivePredictor(posTrainProfileList, negTrainProfileList);

        List predictionList = a.predict(posTestProfileList, negTestProfileList);
        String predictorName = a.getPredictorName();
        Evaluation eval = new Evaluation(predictionList);

        // plotCurves takes parallel lists with one entry per predictor;
        // only this predictor is plotted here.
        List rocAUCList = new ArrayList();
        rocAUCList.add(eval.getROCAUC());
        List prAUCList = new ArrayList();
        prAUCList.add(eval.getPRAUC());
        List aucLabelList = new ArrayList();
        aucLabelList.add(predictorName);
        Instances inst = eval.getCurve(1);
        List instList = new ArrayList();
        instList.add(inst);

        System.out.println();
        System.out.println("=== Summary " +testName+ " ("+predictorName+") ===");
        System.out.println(eval.toString());

        plotCurves(instList,rocAUCList, prAUCList, aucLabelList,predictorName + " (" +testName+")");
    }

    /**
     * Scores every peptide of the given test profiles with the additive
     * model and returns one Prediction per peptide.
     *
     * Profiles are first paired by name so that each domain is processed
     * once: every negative profile is paired with the positive profile of
     * the same name (or with nothing), and every positive profile without a
     * same-named negative profile is added on its own. For each pair the
     * domain sequence of the positive profile (or of the negative one when
     * there is no positive) is encoded and combined with each peptide.
     *
     * A peptide from a positive profile gets actual class 1 and is predicted
     * 1.0 when its score is greater than TAU; a peptide from a negative
     * profile gets actual class 0 and is predicted 0.0 when its score is
     * less than or equal to TAU.
     *
     * @param posTestProfileList positive test profiles (ProteinProfile objects)
     * @param negTestProfileList negative test profiles; may be null or empty,
     *                           in which case only positives are scored
     * @return the list of Prediction objects; the same list is afterwards
     *         available through getPredictions()
     */
    public List predict(List posTestProfileList, List negTestProfileList)
    {
        // Clear prediction list!
        predictionList = new ArrayList();

        // NOTE(review): testData is not read again in this method; the
        // addRawData calls are kept in case they have side effects.
        Data testData = new Data();
        testData.addRawData(posTestProfileList, Constants.CLASS_YES);
        HashMap testPosProfileHashMap = PDZSVMUtils.profileListToHashMap(posTestProfileList);
        HashMap testNegProfileHashMap = new HashMap();

        // Parallel lists: entry i of each belongs to the same domain name,
        // and at most one of the two entries is null.
        List balPosProfileList = new ArrayList();
        List balNegProfileList = new ArrayList();

        if (negTestProfileList !=  null && !negTestProfileList.isEmpty())
        {
            testData.addRawData(negTestProfileList, Constants.CLASS_NO);
            testNegProfileHashMap = PDZSVMUtils.profileListToHashMap(negTestProfileList);

            // The pairing loop lives inside the null guard so that a null
            // negative list means "no negatives" instead of failing.
            for (int i =0; i < negTestProfileList.size();i++)
            {
                ProteinProfile negProfile = (ProteinProfile)negTestProfileList.get(i);
                ProteinProfile posProfile = (ProteinProfile)testPosProfileHashMap.get(negProfile.getName());
                balNegProfileList.add(negProfile);
                // posProfile is null when no positive profile has this name.
                balPosProfileList.add(posProfile);
            }
        }

        // Positive profiles that were not paired above.
        for (int i =0; i < posTestProfileList.size();i++)
        {
            ProteinProfile posProfile = (ProteinProfile)posTestProfileList.get(i);
            if (testNegProfileHashMap.get(posProfile.getName()) == null)
            {
                balNegProfileList.add(null);
                balPosProfileList.add(posProfile);
            }
        }

        for (int i =0;i < balPosProfileList.size();i++)
        {
            ProteinProfile testPosProfile = (ProteinProfile)balPosProfileList.get(i);
            ProteinProfile testNegProfile = (ProteinProfile)balNegProfileList.get(i);

            // Domain information is taken from the positive profile when
            // there is one, otherwise from the negative profile.
            ProteinProfile profile = (testPosProfile != null) ? testPosProfile : testNegProfile;

            String domainSeqFull = profile.getDomainSequence();
            String organismLong = profile.getOrganism();
            String organism = PDZSVMUtils.organismLongToShortForm(organismLong);
            String domainSeq16 = enc.getFeatures(domainSeqFull, organism);
            String name = profile.getName();
            String methodLong = profile.getExperimentalMethod();
            String method = PDZSVMUtils.methodLongToShortForm(methodLong);

            if (testPosProfile!=null)
            {
                Collection testSeqCollection = testPosProfile.getSequenceMap();
                Iterator it = testSeqCollection.iterator();
                while(it.hasNext())
                {
                    Sequence peptideSeq = (Sequence)it.next();
                    String peptide = peptideSeq.seqString();
                    // Positive peptides are cut to their last five residues
                    // before scoring; negative peptides below are passed
                    // whole. computeAdditiveValue indexes from the end of
                    // the peptide, so this only matters for peptides
                    // shorter than five residues.
                    peptide = peptide.substring(peptide.length()-5,peptide.length());
                    double value = computeAdditiveValue(domainSeq16,peptide);

                    double predicted = (value > TAU) ? 1.0 : 0.0;
                    Prediction pred = new Prediction(predicted,1,value, name, domainSeqFull, peptideSeq.seqString(), organism, method);
                    predictionList.add(pred);
                }
            }
            if (testNegProfile!=null)
            {
                Collection testSeqCollection = testNegProfile.getSequenceMap();
                Iterator it = testSeqCollection.iterator();
                while(it.hasNext())
                {
                    Sequence peptideSeq = (Sequence)it.next();
                    String peptide = peptideSeq.seqString();
                    double value = computeAdditiveValue(domainSeq16,peptide);

                    double predicted = (value <= TAU) ? 0.0 : 1.0;
                    Prediction pred = new Prediction(predicted,0,value, name, domainSeqFull, peptideSeq.seqString(), organism, method);
                    predictionList.add(pred);
                }
            }
        }

        return predictionList;
    }

    /**
     * Sums the model parameters selected by a domain feature string and a
     * peptide.
     *
     * For every position pair in the model, the domain residue at the
     * pair's domain index and the peptide residue at the pair's peptide
     * index (counted from the end of the peptide) pick one matrix entry,
     * which is added to the running total. Pairs are visited in sorted key
     * order.
     *
     * @param domain  domain feature string, indexed from its start
     * @param peptide peptide sequence, indexed from its last residue
     * @return the summed value, or 0 as soon as either residue of any pair
     *         is not one of the letters in the alphabet (for example X or -)
     */
    private double computeAdditiveValue(String domain, String peptide)
    {
        // Reverse the peptide so that index 0 is its last residue.
        String reversedPeptide = new StringBuffer(peptide).reverse().toString();

        List sortedKeys = new ArrayList(domainPeptidePosPairToAdditiveModelMap.keySet());
        Collections.sort(sortedKeys);

        double total = 0.0;
        Iterator keyIt = sortedKeys.iterator();
        while (keyIt.hasNext())
        {
            String pairKey = (String)keyIt.next();
            String[] positions = pairKey.split("-");
            int domainPos = Integer.parseInt(positions[0]);
            int peptidePos = Integer.parseInt(positions[1]);

            char domainRes = domain.charAt(domainPos);
            char peptideRes = reversedPeptide.charAt(peptidePos);

            int domainIx = alphabet.indexOf(domainRes);
            int peptideIx = alphabet.indexOf(peptideRes);

            // A residue outside the alphabet makes the whole score zero.
            if (peptideIx == -1 || domainIx == -1)
            {
                return 0.0;
            }

            AdditiveModelValues model = (AdditiveModelValues)domainPeptidePosPairToAdditiveModelMap.get(pairKey);
            total = total + model.additiveModelMatrix[domainIx][peptideIx];
        }
        return total;
    }

    /**
     * k-fold cross validation stub: no validation is performed.
     *
     * @param validParams validation settings; ignored
     * @return a new, empty map
     */
    public HashMap kFoldCrossValidation(ValidationParameters validParams)
    {
        HashMap emptyResult = new HashMap();
        return emptyResult;
    }

    /**
     * Leave-out cross validation stub: no validation is performed.
     *
     * @param validParams validation settings; ignored
     * @return a new, empty map
     */
    public HashMap leaveOutCrossValidation(ValidationParameters validParams)
    {
        HashMap emptyResult = new HashMap();
        return emptyResult;
    }

}

