package org.baderlab.pdzsvm.validation;

import org.baderlab.pdzsvm.utils.PDZSVMUtils;
import org.baderlab.pdzsvm.data.Data;
import org.baderlab.pdzsvm.data.Datum;

import java.util.*;

import org.baderlab.pdzsvm.encoding.Chen16FeatureEncoding;
import org.baderlab.pdzsvm.encoding.Features;
import org.baderlab.brain.ProteinProfile;
import org.biojava.bio.seq.Sequence;
import org.baderlab.pdzsvm.evaluation.Prediction;
import org.baderlab.pdzsvm.evaluation.Evaluation;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Builds per-fold diagnostic reports during cross validation.
 *
 * <p>On construction the training domain sequences of a fold (encoded with
 * {@link Chen16FeatureEncoding}) and its distinct training peptide sequences are
 * collected. The report methods compare each test prediction against its nearest
 * training sequence, as measured by {@code PDZSVMUtils.identity}, and return
 * tab-delimited text lines.
 */
public class ValidationFoldOutput {

    /** Training domain sequences, encoded with {@link Chen16FeatureEncoding}. */
    private final List<String> trainDomainSeqList;
    /** Distinct training peptide sequences, in first-seen order. */
    private final List<String> trainPeptideSeqList;

    /**
     * Collects training sequences from positive and negative protein profiles and
     * prints the number of training domains and peptides found.
     *
     * @param posTrainProfileList positive training profiles ({@link ProteinProfile} elements);
     *                            {@code null} entries are skipped
     * @param negTrainProfileList negative training profiles ({@link ProteinProfile} elements);
     *                            {@code null} entries are skipped
     */
    public ValidationFoldOutput(List posTrainProfileList, List negTrainProfileList)
    {
        trainDomainSeqList = getTrainDomains(posTrainProfileList, negTrainProfileList);
        System.out.println("\tNUM training domains: " + trainDomainSeqList.size());
        trainPeptideSeqList = getTrainPeptides(posTrainProfileList, negTrainProfileList);
        System.out.println("\tNUM training peptides: " + trainPeptideSeqList.size());
    }

    /**
     * Collects training sequences from a {@link Data} set.
     *
     * @param trainData the fold's training data
     */
    public ValidationFoldOutput(Data trainData)
    {
        trainDomainSeqList = getTrainDomains(trainData);
        trainPeptideSeqList = getTrainPeptides(trainData);
    }

    /**
     * Returns one encoded domain sequence per distinct domain number in {@code trainData}
     * and prints the resulting count.
     */
    private List<String> getTrainDomains(Data trainData)
    {
        Chen16FeatureEncoding enc = new Chen16FeatureEncoding();
        List<String> domainSeqList = new ArrayList<String>();
        HashMap domainNumToRawMap = trainData.getDomainNumToRawMap();

        // Dedupe on domain number over the whole list (not just the previous datum), so a
        // domain is added, and encoded, only once even if the data are not grouped by domain.
        Set<Integer> seenDomainNums = new HashSet<Integer>();
        List dataList = trainData.getDataList();
        for (int i = 0; i < dataList.size(); i++)
        {
            Datum dt = (Datum) dataList.get(i);
            int domainNum = dt.domainNum;
            if (!seenDomainNums.add(domainNum))
                continue;
            String organism = dt.organism;
            Features domainFeatures = (Features) domainNumToRawMap.get(domainNum);
            domainSeqList.add(enc.getFeatures(domainFeatures.toUndelimitedString(), organism));
        }
        System.out.println("\t" + domainSeqList.size());
        return domainSeqList;
    }

    /** Returns the distinct undelimited peptide strings keyed in {@code trainData}'s peptide map. */
    private List<String> getTrainPeptides(Data trainData)
    {
        HashMap peptideRawToNumMap = trainData.getPeptideRawToNumMap();
        // LinkedHashSet removes duplicates in O(1) while keeping first-seen order.
        Set<String> peptideSeqs = new LinkedHashSet<String>();
        for (Object key : peptideRawToNumMap.keySet())
        {
            Features peptideFeatures = (Features) key;
            peptideSeqs.add(peptideFeatures.toUndelimitedString());
        }
        return new ArrayList<String>(peptideSeqs);
    }

    /**
     * Returns the encoded domain sequences of the training profiles.
     *
     * <p>Every positive domain is kept (duplicates included). A negative domain is added
     * only if its encoded sequence has not been added yet and its name does not match any
     * positive domain's name.
     */
    private List<String> getTrainDomains(List posTrainProfileList, List negTrainProfileList)
    {
        Chen16FeatureEncoding enc = new Chen16FeatureEncoding();
        List<String> domainSeqList = new ArrayList<String>();
        // Sets mirror the list contents for constant-time membership checks.
        Set<String> addedSeqs = new HashSet<String>();
        Set<String> posDomainNames = new HashSet<String>();

        for (int i = 0; i < posTrainProfileList.size(); i++)
        {
            ProteinProfile posProfile = (ProteinProfile) posTrainProfileList.get(i);
            if (posProfile == null)
                continue;
            String domainSeq = encodeProfileDomain(enc, posProfile);
            posDomainNames.add(posProfile.getName());
            domainSeqList.add(domainSeq);
            addedSeqs.add(domainSeq);
        }
        for (int i = 0; i < negTrainProfileList.size(); i++)
        {
            ProteinProfile negProfile = (ProteinProfile) negTrainProfileList.get(i);
            if (negProfile == null)
                continue;
            String domainSeq = encodeProfileDomain(enc, negProfile);
            String domainName = negProfile.getName();
            if (!addedSeqs.contains(domainSeq) && !posDomainNames.contains(domainName))
            {
                domainSeqList.add(domainSeq);
                addedSeqs.add(domainSeq);
            }
        }
        return domainSeqList;
    }

    /** Returns the distinct peptide sequences of all profiles, positives first, in first-seen order. */
    private List<String> getTrainPeptides(List posTrainProfileList, List negTrainProfileList)
    {
        Set<String> peptideSeqs = new LinkedHashSet<String>();
        addPeptideSeqs(posTrainProfileList, peptideSeqs);
        addPeptideSeqs(negTrainProfileList, peptideSeqs);
        return new ArrayList<String>(peptideSeqs);
    }

    /** Adds the sequence strings of every non-null profile's sequence collection to {@code peptideSeqs}. */
    private static void addPeptideSeqs(List profileList, Set<String> peptideSeqs)
    {
        for (int i = 0; i < profileList.size(); i++)
        {
            ProteinProfile profile = (ProteinProfile) profileList.get(i);
            if (profile == null)
                continue;
            Collection seqCollection = profile.getSequenceMap();
            for (Object seq : seqCollection)
                peptideSeqs.add(((Sequence) seq).seqString());
        }
    }

    /** Encodes a profile's domain sequence with {@code enc}, using the profile's organism in short form. */
    private static String encodeProfileDomain(Chen16FeatureEncoding enc, ProteinProfile profile)
    {
        String organism = PDZSVMUtils.organismLongToShortForm(profile.getOrganism());
        return enc.getFeatures(profile.getDomainSequence(), organism);
    }

    /** Encodes a prediction's full domain sequence with {@code enc}, using its short organism name. */
    private static String encodePredictionDomain(Chen16FeatureEncoding enc, Prediction pred)
    {
        return enc.getFeatures(pred.domainSeqFull, pred.organismShort);
    }

    /** Returns -1 when the prediction differs from the actual label, otherwise 1. */
    private static int correctness(Prediction pred)
    {
        return pred.getActual() != pred.getPrediction() ? -1 : 1;
    }

    /**
     * Reports, for each prediction, how far its encoded domain is from the nearest training domain.
     *
     * @param predictions {@link Prediction} elements
     * @return one line per prediction: name, encoded domain, mismatches to the nearest
     *         training domain, decision value, correctness (1/-1), tab-separated and
     *         terminated by {@code "\t\n"}
     */
    public String domainMismatch(List predictions)
    {
        Chen16FeatureEncoding enc = new Chen16FeatureEncoding();
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < predictions.size(); i++)
        {
            Prediction pred = (Prediction) predictions.get(i);
            String name = pred.name;
            String domainSeq = encodePredictionDomain(enc, pred);
            String nnTrainDomainSeq = getNNSeq(domainSeq, trainDomainSeqList);
            int numMismatch = getNumMismatches(domainSeq, nnTrainDomainSeq);
            double decValue = pred.getDecValue();
            out.append(name + "\t" + domainSeq + "\t" + numMismatch + "\t" + decValue + "\t"
                    + correctness(pred) + "\t\n");
        }
        return out.toString();
    }

    /**
     * Reports, for each prediction, how far its peptide is from the nearest training peptide.
     *
     * @param predictions {@link Prediction} elements
     * @return one line per prediction: name, peptide, mismatches to the nearest training
     *         peptide, correctness (1/-1), tab-separated
     */
    public String peptideMismatch(List predictions)
    {
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < predictions.size(); i++)
        {
            Prediction pred = (Prediction) predictions.get(i);
            String name = pred.name;
            String peptideSeq = pred.peptideSeq;
            String nnTrainPeptideSeq = getNNSeq(peptideSeq, trainPeptideSeqList);
            int numMismatch = getNumMismatches(peptideSeq, nnTrainPeptideSeq);
            out.append(name + "\t" + peptideSeq + "\t" + numMismatch + "\t" + correctness(pred) + "\n");
        }
        return out.toString();
    }

    /**
     * Summarizes predictions grouped by peptide against this fold's training peptides.
     *
     * @see #byPeptide(List, List)
     */
    public String byPeptide(List predictions)
    {
        return byPeptide(predictions, trainPeptideSeqList);
    }

    /**
     * Groups predictions by peptide sequence and summarizes every group that contains both
     * positive and negative examples.
     *
     * @param predictions         {@link Prediction} elements
     * @param trainPeptideSeqList training peptide sequences ({@code String} elements)
     * @return one tab-separated line per qualifying group:
     *         name, peptide, numPos, numNeg, numMismatch, nnSim, rocAUC, prAUC
     */
    public static String byPeptide(List predictions, List trainPeptideSeqList)
    {
        Map<String, List<Prediction>> groups = new LinkedHashMap<String, List<Prediction>>();
        for (int i = 0; i < predictions.size(); i++)
        {
            Prediction pred = (Prediction) predictions.get(i);
            addToGroup(groups, pred.peptideSeq, pred);
        }
        return summarizeGroups(groups, trainPeptideSeqList);
    }

    /**
     * Groups predictions by encoded domain sequence and summarizes every group that contains
     * both positive and negative examples, relative to this fold's training domains.
     *
     * @param predictions {@link Prediction} elements
     * @return one tab-separated line per qualifying group:
     *         name, encoded domain, numPos, numNeg, numMismatch, nnSim, rocAUC, prAUC
     */
    public String byDomain(List predictions)
    {
        Chen16FeatureEncoding enc = new Chen16FeatureEncoding();
        Map<String, List<Prediction>> groups = new LinkedHashMap<String, List<Prediction>>();
        for (int i = 0; i < predictions.size(); i++)
        {
            Prediction pred = (Prediction) predictions.get(i);
            addToGroup(groups, encodePredictionDomain(enc, pred), pred);
        }
        return summarizeGroups(groups, trainDomainSeqList);
    }

    /** Appends {@code pred} to the list stored under {@code key}, creating the list on first use. */
    private static void addToGroup(Map<String, List<Prediction>> groups, String key, Prediction pred)
    {
        List<Prediction> group = groups.get(key);
        if (group == null)
        {
            group = new ArrayList<Prediction>();
            groups.put(key, group);
        }
        group.add(pred);
    }

    /**
     * Formats one line per group holding both positive (actual == 1) and other predictions:
     * name (of the group's last prediction), key, numPos, numNeg, mismatches and identity to
     * the nearest sequence in {@code refSeqList}, ROC AUC and PR AUC. Groups are emitted in
     * insertion order; single-class groups are skipped.
     */
    private static String summarizeGroups(Map<String, List<Prediction>> groups, List refSeqList)
    {
        StringBuilder out = new StringBuilder();
        for (Map.Entry<String, List<Prediction>> entry : groups.entrySet())
        {
            String seq = entry.getKey();
            List<Prediction> group = entry.getValue();
            String name = "";
            int numPos = 0;
            int numNeg = 0;
            for (Prediction pred : group)
            {
                name = pred.name;
                if (pred.getActual() == 1.0)
                    numPos++;
                else
                    numNeg++;
            }
            // Only groups with both classes are evaluated; skip the nearest-neighbour work otherwise.
            if (numPos == 0 || numNeg == 0)
                continue;

            int numMismatch = getNumMismatches(seq, getNNSeq(seq, refSeqList));
            double nnSim = getNNSim(seq, refSeqList);
            Evaluation eval = new Evaluation(group);
            double rocAUC = eval.getROCAUC();
            double prAUC = eval.getPRAUC();
            out.append(name + "\t" + seq + "\t" + numPos + "\t" + numNeg + "\t" + numMismatch
                    + "\t" + nnSim + "\t" + rocAUC + "\t" + prAUC + "\n");
        }
        return out.toString();
    }

    /**
     * Returns the reference sequence with the highest {@code PDZSVMUtils.identity} to {@code seq}.
     * Ties go to the later sequence in {@code refSeqList}.
     *
     * @return the nearest sequence, or {@code ""} if {@code refSeqList} is empty
     */
    public static String getNNSeq(String seq, List refSeqList)
    {
        String maxSeq = "";
        // NEGATIVE_INFINITY (not Double.MIN_VALUE, the smallest *positive* double) so that a
        // reference with similarity 0.0 can still be selected.
        double maxSim = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < refSeqList.size(); i++)
        {
            String refSeq = (String) refSeqList.get(i);
            double sim = PDZSVMUtils.identity(seq, refSeq);
            if (sim >= maxSim)
            {
                maxSeq = refSeq;
                maxSim = sim;
            }
        }
        return maxSeq;
    }

    /**
     * Counts positions at which two equal-length sequences differ.
     *
     * @return the number of mismatching positions, or -1 if either sequence is {@code null}
     *         or the lengths differ
     */
    public static int getNumMismatches(String seq1, String seq2)
    {
        if (seq1 == null || seq2 == null || seq1.length() != seq2.length())
            return -1;
        int numMismatch = 0;
        for (int i = 0; i < seq1.length(); i++)
        {
            if (seq1.charAt(i) != seq2.charAt(i))
                numMismatch++;
        }
        return numMismatch;
    }

    /**
     * Returns the identity of {@code seq} to every reference sequence, each value preceded by a
     * tab, and prints how many similarities were computed.
     */
    public static String getAllSimString(String seq, List refSeqList)
    {
        StringBuilder allSim = new StringBuilder();
        for (int i = 0; i < refSeqList.size(); i++)
        {
            double sim = PDZSVMUtils.identity(seq, (String) refSeqList.get(i));
            allSim.append('\t').append(sim);
        }
        System.out.println("\tNum sim computed: " + refSeqList.size());
        return allSim.toString();
    }

    /**
     * Returns the highest {@code PDZSVMUtils.identity} between {@code seq} and any reference
     * sequence, or 0.0 when no similarity could be computed (e.g. empty {@code refSeqList}).
     */
    public static double getNNSim(String seq, List refSeqList)
    {
        // NEGATIVE_INFINITY (not Double.MIN_VALUE) so a best similarity of 0.0 is reported as 0.0.
        double maxSim = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < refSeqList.size(); i++)
        {
            double sim = PDZSVMUtils.identity(seq, (String) refSeqList.get(i));
            if (sim >= maxSim)
                maxSim = sim;
        }
        return maxSim == Double.NEGATIVE_INFINITY ? 0.0 : maxSim;
    }
}
