package org.baderlab.pdzsvm.analysis;

import org.baderlab.pdzsvm.data.DataLoader;
import org.baderlab.pdzsvm.data.manager.DataFileManager;

import java.util.List;
import java.util.ArrayList;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.File;

import org.baderlab.pdzsvm.utils.Constants;
import org.baderlab.pdzsvm.predictor.Predictor;
import org.baderlab.pdzsvm.predictor.svm.ContactMapSVMPredictor;
import org.baderlab.pdzsvm.evaluation.Evaluation;
import org.baderlab.pdzsvm.utils.PDZSVMUtils;
import weka.core.Utils;
import weka.core.Instances;
import libsvm.svm_parameter;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Compares contact-map SVM predictors trained on the different training data
 * sets discussed in the paper, evaluating each of them against one selected
 * test set.  Produces R files for Fig. 3a.
 *
 * Usage: <code>CompareTrainingData &lt;test&gt;</code> where <code>test</code>
 * is 5 ("MOUSE ORPHAN PM"), 6 ("WORM PM") or 7 ("FLY PM").
 */
public class CompareTrainingData {
    /** Human readable name of the loaded test set; set by {@link #loadTests(int)}. */
    private String testName;
    /** Number of the test set to evaluate against (see {@link #loadTests(int)}). */
    private final int testNum;
    /** Output sub directory (with leading slash) for the loaded test set; set by {@link #loadTests(int)}. */
    private String dirName;
    /** When true, the predictions of every predictor are also written out as an R file. */
    private boolean print = true;
    /** Positive test profiles obtained from the loader in {@link #loadTests(int)}. */
    private List posTestProfileList;
    /** Negative test profiles obtained from the loader in {@link #loadTests(int)}. */
    private List negTestProfileList;
    /** Loader used for the test data only; training data uses a fresh loader per run. */
    private final DataLoader dl;
    /** Directory (relative to the output root) that holds all output of this analysis. */
    private final String parentDir = "/TrainingData";

    /**
     * Creates a comparison for the given test set.  Nothing is loaded until
     * the comparison is actually run.
     *
     * @param test test set number, see {@link #loadTests(int)}
     */
    public CompareTrainingData(int test)
    {
        dl = new DataLoader();
        testNum = test;
    }

    /**
     * Clears any previously loaded test data and loads the requested test set,
     * recording its display name, its output directory and its positive and
     * negative test profiles.
     *
     * @param test 5 for the mouse orphan test, 6 for the worm test, 7 for the fly test
     * @throws IllegalArgumentException if <code>test</code> is not one of the
     *         supported numbers; previously this case silently left the test
     *         name and output directory unset
     */
    public void loadTests(int test)
    {
        dl.clearTest();

        switch (test)
        {
            case 5:
                dl.loadMouseTest(Constants.MOUSE_ORPHAN);
                testName = "MOUSE ORPHAN PM";
                dirName = "/PMMouseOrphanG";
                break;
            case 6:
                dl.loadWormTest(Constants.PROTEIN_MICROARRAY);
                testName = "WORM PM";
                dirName = "/PMWormG";
                break;
            case 7:
                dl.loadFlyTest();
                testName = "FLY PM";
                dirName = "/PMFlyG";
                break;
            default:
                // Fail fast instead of training five SVMs and then writing
                // the results to a path that contains "null".
                throw new IllegalArgumentException(
                        "Unknown test number: " + test + " (expected 5, 6 or 7)");
        }

        posTestProfileList = dl.getPosTestProfileList();
        negTestProfileList = dl.getNegTestProfileList();

        System.out.println("\tTEST IS: " + testName);
    }

    /**
     * Command line entry point.
     *
     * @param args <code>args[0]</code> is the test set number (5, 6 or 7)
     */
    public static void main(String[] args)
    {
        if (args.length < 1)
        {
            System.err.println("Usage: CompareTrainingData <test number: 5, 6 or 7>");
            return;
        }
        CompareTrainingData cp = new CompareTrainingData(Integer.parseInt(args[0]));
        cp.compare();
    }

    /**
     * Writes <code>output</code> to <code>fileName</code> inside the output
     * directory of the currently loaded test set.  Writing is best effort: any
     * failure is reported on standard output and otherwise ignored.
     *
     * @param output   text to write
     * @param fileName name of the file (without directory) to create
     */
    private void print(String output, String fileName)
    {
        String path = DataFileManager.OUTPUT_ROOT_DIR + parentDir + dirName + "/" + fileName;
        System.out.println("\tWriting to " + path + "...");

        BufferedWriter bw = null;
        try
        {
            bw = new BufferedWriter(new FileWriter(new File(path)));
            bw.write(output);
        }
        catch (Exception e)
        {
            System.out.println("Exception: " + e);
        }
        finally
        {
            // Always release the file handle, also when write() failed.
            if (bw != null)
            {
                try
                {
                    bw.close();
                }
                catch (Exception e)
                {
                    System.out.println("Exception: " + e);
                }
            }
        }
    }

    /**
     * Runs the whole comparison: loads the test set, trains one SVM per
     * training data configuration, evaluates every set of predictions and
     * finally plots all curves together.
     */
    private void compare()
    {
        loadTests(testNum);

        // All models are trained and applied before any evaluation happens;
        // the order of the runs determines the order of the console output.
        List mousePredictions = new ArrayList(runSVM(Constants.MOUSE, Constants.NONE));
        List mouseHumanPredictions = new ArrayList(runSVM(Constants.MOUSE + Constants.SIDHU_HUMAN, Constants.NONE));
        List humanPredictions = new ArrayList(runSVM(Constants.SIDHU_HUMAN, Constants.NONE));
        List genomicPredictions = new ArrayList(runSVM(Constants.MOUSE + Constants.SIDHU_HUMAN, Constants.GENOMIC));
        List nonGenomicPredictions = new ArrayList(runSVM(Constants.MOUSE + Constants.SIDHU_HUMAN, Constants.NON_GENOMIC));

        // Parallel lists: entry i of each list belongs to the i-th predictor.
        List predList = new ArrayList();
        List rocAUCList = new ArrayList();
        List prAUCList = new ArrayList();
        List aucLabelList = new ArrayList();

        evaluateAndRecord("SVM CM MOUSE", mousePredictions,
                predList, rocAUCList, prAUCList, aucLabelList);
        evaluateAndRecord("SVM CM MOUSE HUMAN", mouseHumanPredictions,
                predList, rocAUCList, prAUCList, aucLabelList);
        evaluateAndRecord("SVM CM MOUSE HUMAN G", genomicPredictions,
                predList, rocAUCList, prAUCList, aucLabelList);
        evaluateAndRecord("SVM CM MOUSE HUMAN NG", nonGenomicPredictions,
                predList, rocAUCList, prAUCList, aucLabelList);
        evaluateAndRecord("SVM CM HUMAN", humanPredictions,
                predList, rocAUCList, prAUCList, aucLabelList);

        Predictor.plotCurves(predList, rocAUCList, prAUCList, aucLabelList, testName);
    }

    /**
     * Evaluates the predictions of one predictor, appends its label, ROC AUC,
     * PR AUC and curve to the given result lists and, if printing is enabled,
     * writes the predictions to
     * <code>test_&lt;predictor name with underscores&gt;_Load.r</code>.
     *
     * @param predictorName display name of the predictor
     * @param predictions   predictions produced by the predictor
     * @param predList      receives the curve of this predictor
     * @param rocAUCList    receives the ROC AUC of this predictor
     * @param prAUCList     receives the PR AUC of this predictor
     * @param aucLabelList  receives the display name of this predictor
     */
    private void evaluateAndRecord(String predictorName, List predictions,
                                   List predList, List rocAUCList,
                                   List prAUCList, List aucLabelList)
    {
        aucLabelList.add(predictorName);
        Evaluation eval = evaluate(predictions, predictorName);
        rocAUCList.add(toThreeDecimals(eval.getROCAUC()));
        prAUCList.add(toThreeDecimals(eval.getPRAUC()));
        // NOTE(review): 1 is presumably the index of the positive class - confirm against Evaluation.
        Instances curve = eval.getCurve(1);
        predList.add(curve);

        if (print)
        {
            // toRString is handed a list of prediction lists, here holding a single one.
            List wrapped = new ArrayList();
            wrapped.add(predictions);
            StringBuffer rString = PDZSVMUtils.toRString(wrapped);
            print(rString.toString(), "test_" + predictorName.replace(' ', '_') + "_Load.r");
        }
    }

    /**
     * Formats the value with three decimals via Weka and parses the result
     * back into a number, which is the precision handed on for plotting.
     *
     * @param value value to reduce in precision
     * @return the value as obtained from its three decimal string form
     */
    private static double toThreeDecimals(double value)
    {
        return Double.parseDouble(Utils.doubleToString(value, 3));
    }

    /**
     * Builds the evaluation for a list of predictions and prints its summary
     * to standard output.
     *
     * @param predictionList predictions to evaluate
     * @param predictorName  display name used in the summary header
     * @return the evaluation of the given predictions
     */
    private Evaluation evaluate(List predictionList, String predictorName)
    {
        Evaluation eval = new Evaluation(predictionList);

        System.out.println("=== Summary " + testName + " (" + predictorName + ") ===");
        System.out.println(eval.toString());
        return eval;
    }

    /**
     * Trains a contact map SVM on the requested training data and applies it
     * to the currently loaded test profiles.
     *
     * @param data which training set(s) to load: {@link Constants#MOUSE},
     *             {@link Constants#SIDHU_HUMAN} or the concatenation of both
     * @param type for the combined data only: {@link Constants#NONE} to use the
     *             Sidhu human data, otherwise the kind of human training data
     *             that is passed on to the loader; any value other than
     *             {@link Constants#GENOMIC} is logged and parameterised as
     *             non genomic
     * @return the predictions for the positive and negative test profiles
     * @throws IllegalArgumentException if <code>data</code> is not one of the
     *         supported values
     */
    private List runSVM( String data, String type)
    {
        // A separate loader per run, so training data never accumulates
        // across runs and the test loader of this object stays untouched.
        DataLoader trainLoader = new DataLoader();
        svm_parameter svmParams = new svm_parameter();
        svmParams.setDefaults();

        // Exponents of the SVM parameters; converted to C and gamma below.
        double C = 2;
        double g = 4;
        if (data.equals(Constants.MOUSE))
        {
            trainLoader.loadMouseChenTrain();
            C = 3.0; g = 5.0;
            System.out.println("\tTraining data type: Mouse Only");
            System.out.println("\tCHEN MOUSE, [g,C] = ["+g+","+C+"])");
        }
        else if (data.equals(Constants.SIDHU_HUMAN))
        {
            trainLoader.loadSidhuHumanTrain(Constants.PWM,Constants.NUM_RED_PEPTIDES);
            C = 3.0; g = 5.0;
            System.out.println("\tTraining data type: Sidhu Human Only");
            System.out.println("\tSIDHU HUMAN, [g,C] = ["+g+","+C+"])");
        }
        else if (data.equals(Constants.MOUSE + Constants.SIDHU_HUMAN))
        {
            if (type.equals(Constants.NONE))
            {
                System.out.println("\tTraining data type: Mouse + Sidhu Human Only");
                System.out.println("\tCHEN MOUSE + SIDHU HUMAN, [g,C] = ["+g+","+C+"])");
                trainLoader.loadMouseChenTrain();
                trainLoader.loadSidhuHumanTrain(Constants.PWM,Constants.NUM_RED_PEPTIDES);
            }
            else
            {
                if (type.equals(Constants.GENOMIC))
                {
                    System.out.println("\tCHEN MOUSE + HUMAN GENOMIC, [g,C] = ["+g+","+C+"])");
                }
                else
                {
                    // Assign before logging so the reported values are the
                    // ones the model is actually trained with.
                    C = 3.0; g = 5.0;
                    System.out.println("\tCHEN MOUSE + HUMAN NON GENOMIC, [g,C] = ["+g+","+C+"])");
                }
                trainLoader.loadMouseChenTrain();
                trainLoader.loadHumanTrain(Constants.PWM, type);
            }
        }
        else
        {
            // Without this guard a model would be trained on whatever an
            // untouched loader returns.
            throw new IllegalArgumentException("Unknown training data: " + data);
        }

        List posTrainProfileList = trainLoader.getPosTrainProfileList();
        List negTrainProfileList = trainLoader.getNegTrainProfileList();
        svmParams.C = Math.exp(C);
        svmParams.gamma = Math.exp(-Math.log(2)-g);
        svmParams.data_encoding = svm_parameter.CONTACTMAP2020;

        ContactMapSVMPredictor predictor = new ContactMapSVMPredictor(
                posTrainProfileList,
                negTrainProfileList,
                svmParams);
        predictor.train();
        return predictor.predict(posTestProfileList, negTestProfileList);
    }

}
