package org.baderlab.pdzsvm.validation;

import org.baderlab.pdzsvm.data.DataLoader;
import org.baderlab.pdzsvm.data.manager.DataFileManager;
import org.baderlab.pdzsvm.utils.PDZSVMUtils;
import org.baderlab.pdzsvm.utils.Constants;

import java.util.*;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.File;

import libsvm.svm_parameter;
import org.baderlab.pdzsvm.predictor.svm.ContactMapSVMPredictor;
import org.baderlab.pdzsvm.predictor.Predictor;
import org.baderlab.pdzsvm.predictor.nn.NNPredictor;
import org.baderlab.pdzsvm.evaluation.Evaluation;
import weka.core.Instances;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * CrossValidation calls predictor methods to perform different methods of
 * cross validation. This code generates R code for Figs. 1, 2a, 2b.
 */
public class CrossValidation
{
    private List instList;
    private List aucLabelList;
    private List rocAUCList;
    private List prAUCList;

    private String dirName ;
    private String parentDir = "/CrossValidation";
    private boolean print = true;

    public CrossValidation()
    {
    }
    public void validate(int predictorType, int validationType)
    {
        DataLoader dl = new DataLoader();
        dl.loadMouseChenTrain();
        dl.loadHumanTrain(Constants.PWM);

        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();

        svm_parameter svmParams = new svm_parameter();
        svmParams.setDefaults();
        double C = 2; double g = 4;
        System.out.println("\tSVM CM, [g,C] = ["+g+","+C+"])");
        svmParams.C = Math.exp(C);
        svmParams.gamma = Math.exp(-Math.log(2)-g);

        svmParams.data_encoding = svm_parameter.CONTACTMAP2020;

        Predictor p;
        if (predictorType == 0)
            p =new ContactMapSVMPredictor(posTrainProfileList,negTrainProfileList,svmParams);
        else
            p = new NNPredictor(posTrainProfileList, negTrainProfileList);

        ValidationParameters validParams = new ValidationParameters();
        if (validationType==ValidationParameters.K_FOLD)
        {
            validParams.type = ValidationParameters.K_FOLD;    dirName = "/KFold";   validParams.k = 10;   validParams.numTimes = 10;
        }
        else if (validationType==ValidationParameters.DOMAIN)
        {
            validParams.type = ValidationParameters.DOMAIN;  dirName = "/Domain";
            validParams.d = 12; validParams.k = 10;  validParams.numTimes = 10;
        }
        else if (validationType==ValidationParameters.LOOV_DOMAIN)
        {
            validParams.type = ValidationParameters.DOMAIN;  dirName = "/Domain";
            validParams.d = 0; validParams.p = 0; validParams.k = 1; validParams.numTimes = 1;
        }
        else if (validationType==ValidationParameters.PEPTIDE)
        {
            validParams.type = ValidationParameters.PEPTIDE; dirName = "/Peptide";
            validParams.p = 8; validParams.k = 10;  validParams.numTimes = 10;
        }
        else if (validationType == ValidationParameters.LOOV_PEPTIDE)
        {
            validParams.type = ValidationParameters.PEPTIDE; dirName = "/Peptide";
            validParams.d = 0; validParams.p = 0; validParams.k = 1; validParams.numTimes = 1;
        }
        else if (validationType == ValidationParameters.DOMAIN_PEPTIDE)
        {
            validParams.type = ValidationParameters.DOMAIN_PEPTIDE; dirName = "/DomainPeptide";
            validParams.d = 12;validParams.p = 8;  validParams.k = 10;  validParams.numTimes = 10;
        }
        else
        {
            System.out.println("\tUnkonwn validation type...exiting.");
            return;
        }
        instList = new ArrayList();
        aucLabelList = new ArrayList();
        rocAUCList = new ArrayList();
        prAUCList= new ArrayList();

        runValidation(p, validParams);

        String title = ValidationParameters.CV_STRING[validationType];
        Predictor.plotCurves(instList, rocAUCList, prAUCList, aucLabelList, title);

    }

    private void print(String output, ValidationParameters params, String fileName)
    {
        try
        {
            String outFileName = params.outputDir + "/" + fileName;

            System.out.println("\tWriting to " + outFileName+ "...");

            BufferedWriter bw = new BufferedWriter(new FileWriter(new File(outFileName)));
            bw.write(output);
            bw.close();
        }
        catch(Exception e)
        {
            System.out.println("Exception: " + e);
        }
    }

    private void runValidation(Predictor p, ValidationParameters validParams)
    {
        validParams.outputDir = DataFileManager.OUTPUT_ROOT_DIR + parentDir + dirName;
        validParams.predictorName = p.getPredictorName();


        String validationType = ValidationParameters.CV_STRING[validParams.type];
        List predictionList =  new ArrayList();
        List foldROCAUCList = new ArrayList();
        List foldPRAUCList = new ArrayList();
        List foldPredictionList = new ArrayList();
        for (int ii=0;ii< validParams.numTimes;ii++)
        {
            HashMap cvResultsMap;
            System.out.println("\t=== Run # " + (ii+1) + " ===");
            if (validParams.type==ValidationParameters.K_FOLD)
                cvResultsMap = p.kFoldCrossValidation(validParams);
            else
                cvResultsMap = p.leaveOutCrossValidation(validParams);

            Set keys = cvResultsMap.keySet();
            List keyList = new ArrayList(keys);
            Collections.sort(keyList);
            for (int i=0; i < keyList.size();i++)
            {
                Integer foldNum = (Integer)keyList.get(i);
                List cvPredictionList = (List)cvResultsMap.get(foldNum);
                predictionList.addAll(cvPredictionList);
                Evaluation eval = new Evaluation(cvPredictionList);
                double rocAUC = eval.getROCAUC();
                double prAUC = eval.getPRAUC();

                //System.out.println(auc);
                foldROCAUCList.add(rocAUC);
                foldPRAUCList.add(prAUC);

                foldPredictionList.add(cvPredictionList);
            }
        }
        Evaluation eval = new Evaluation(predictionList);
        double rocAUC = eval.getROCAUC();
        double prAUC = eval.getPRAUC();

        System.out.println("\tOverall rAUC: " + rocAUC);
        System.out.println("\tOverall prAUC: " + prAUC);

        Instances inst = eval.getCurve(1);
        instList.add(inst);
        aucLabelList.add(p.getPredictorName() +" " + validationType);
        rocAUCList.add(rocAUC);
        prAUCList.add(prAUC);


        System.out.print("rAUC = c(" +foldROCAUCList.get(0));
        for (int i=1; i < foldROCAUCList.size();i++)
        {
            double foldAUC = (Double)foldROCAUCList.get(i);
            System.out.print(","+foldAUC);
        }
        System.out.println(")");

        System.out.print("prAUC = c(" +foldPRAUCList.get(0));
        for (int i=1; i < foldPRAUCList.size();i++)
        {
            double foldAUC = (Double)foldPRAUCList.get(i);
            System.out.print(","+foldAUC);
        }
        System.out.println(")");

        double[] cirAUC = confidenceInterval(foldROCAUCList);
        System.out.println("\tROC AUC: 95% C.I.: " + cirAUC[0] + "~" + cirAUC[1]);

        double[] ciprAUC = confidenceInterval(foldPRAUCList);
        System.out.println("\tPR  AUC: 95% C.I.: " + ciprAUC[0] + "~" + ciprAUC[1]);

        String cvString = ValidationParameters.CV_STRING[validParams.type];

        if (print)
        {
            String fileName = p.getPredictorName().replace(' ','_') + "_"+cvString;
            if (validParams.d==0 && validParams.type == ValidationParameters.DOMAIN)
                fileName = fileName + "_LODO";
            if (validParams.p==0&& validParams.type == ValidationParameters.PEPTIDE)
                fileName = fileName + "_LOPO";
            validParams.predictorName = p.getPredictorName();
            StringBuffer rString = PDZSVMUtils.toRString(foldPredictionList);
            print(rString.toString(),validParams, "cv_" + fileName + "_Load.r");
        }
    }
    public double[] confidenceInterval(List statList)
    {
        double confidenceLevel =0.95;
        double c =1.96;
        double n = statList.size();
        double mean = 0.0;
        for (int i=0; i < statList.size();i++)
        {
            double auc = (Double)statList.get(i);
            mean = mean + auc;
        }
        mean = mean/n;

        double sum = 0.0;
        for (int i=0; i < statList.size();i++)
        {
            double auc = (Double)statList.get(i);
            double diff = auc-mean;
            sum = sum + Math.pow(diff,2);
        }
        double var = (1/(n-1)) * sum;
        double s = Math.sqrt(var);

        double ciLow = mean-c*s/Math.sqrt(n);
        double ciHigh = mean+c*s/Math.sqrt(n);

        return new double[]{ciLow, ciHigh};

    }


    public static void main(String[] args)
    {
        CrossValidation cv = new CrossValidation();
        int predictorType = Integer.parseInt(args[0]);
        if (predictorType == 0)
            System.out.println("\tPredictor type: ContactMapSVMPredictor");
        else
            System.out.println("\tPredictor type: NNPredictor");
        int validationType = Integer.parseInt(args[1]);

        System.out.println("\tValidation type: " + ValidationParameters.CV_STRING[validationType]);
        cv.validate(predictorType, validationType);

    }
}
