package org.baderlab.pdzsvm.analysis;

import org.baderlab.pdzsvm.data.DataLoader;
import org.baderlab.pdzsvm.data.manager.DataFileManager;
import org.baderlab.pdzsvm.utils.PDZSVMUtils;

import java.util.*;
import java.util.List;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;

import org.baderlab.pdzsvm.predictor.svm.ContactMapSVMPredictor;
import libsvm.svm_parameter;
import org.baderlab.pdzsvm.evaluation.Evaluation;
import org.baderlab.pdzsvm.utils.Constants;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Compares SVMs trained using the different artificial negatives
 * discussed in the paper.  Produces the R files for Fig. 3c.
 *
 * For the selected test set, an SVM is trained once per run for each of the
 * RANDOM, SHUFFLED and RANDOM_SEL negative methods (the PWM method is run
 * only in the first run) and the predictions of every run, together with
 * their ROC and PR AUC values, are written out as R scripts.
 */
public class CompareArtificialNegatives
{
    /** Number of train/predict runs performed per negative method. */
    private static final int NUM_ITERATIONS = 100;

    /** Sub directory (below the output root) that receives all R files. */
    private static final String PARENT_DIR = "/ArtificialNegatives";

    /** Number of the test set to evaluate on (5, 6 or 7, see loadTests). */
    private final int testNum;

    /** Loader used for the test data only; training data uses its own loader. */
    private final DataLoader dl;

    /** When false no R files are written. */
    private final boolean print = true;

    /** Test specific output sub directory; set by loadTests. */
    private String dirName;

    private List posTestProfileList;
    private List negTestProfileList;

    /**
     * @param test number of the test set to use (5, 6 or 7, see loadTests)
     */
    public CompareArtificialNegatives(int test)
    {
        dl = new DataLoader();
        testNum = test;
    }

    /**
     * Entry point.
     *
     * @param args args[0] is the test set number (5, 6 or 7)
     */
    public static  void main(String[] args)
    {
        if (args.length < 1)
        {
            System.err.println("Usage: CompareArtificialNegatives <test number: 5 | 6 | 7>");
            System.exit(1);
        }
        CompareArtificialNegatives cp = new CompareArtificialNegatives(Integer.parseInt(args[0]));
        cp.compare();
    }

    /**
     * Loads the positive and negative test profiles of the given test set
     * and selects the matching output sub directory.
     *
     * @param test 5 = mouse orphan, 6 = worm, 7 = fly
     * @throws IllegalArgumentException if test is not one of 5, 6 or 7;
     *         previously such a value left the output directory unset and
     *         evaluated on whatever the cleared loader returned
     */
    public void loadTests(int test)
    {
        String testName;
        dl.clearTest();
        if (test==5)
        {
            dl.loadMouseTest(Constants.MOUSE_ORPHAN);
            testName = "MOUSE ORPHAN PM";
            dirName = "/PMMouseOrphanG";
        }
        else if (test==6)
        {
            dl.loadWormTest(Constants.PROTEIN_MICROARRAY);
            testName = "WORM PM";
            dirName = "/PMWormG";
        }
        else if (test==7)
        {
            dl.loadFlyTest();
            testName = "FLY PM";
            dirName = "/PMFlyG";
        }
        else
        {
            throw new IllegalArgumentException("Unsupported test number: " + test + " (expected 5, 6 or 7)");
        }

        posTestProfileList = dl.getPosTestProfileList();
        negTestProfileList = dl.getNegTestProfileList();
        System.out.println("\tTEST IS: " + testName);
    }

    /**
     * Writes the given text to a file in the output directory of the
     * current test, creating the directory if it does not exist yet.
     * Failures are reported on standard output and otherwise ignored.
     *
     * @param output   text to write
     * @param fileName name of the file inside the output directory
     */
    private void print(String output, String fileName)
    {
        String path = DataFileManager.OUTPUT_ROOT_DIR + PARENT_DIR + dirName + "/" + fileName;
        System.out.println("\tWriting to " + path + "...");
        try
        {
            File outFile = new File(path);
            File outDir = outFile.getParentFile();
            if (outDir != null && !outDir.exists() && !outDir.mkdirs())
            {
                System.out.println("\tCould not create directory " + outDir);
            }

            BufferedWriter bw = new BufferedWriter(new FileWriter(outFile));
            try
            {
                bw.write(output);
            }
            finally
            {
                // Close even when the write fails so the file handle is not leaked.
                bw.close();
            }
        }
        catch(Exception e)
        {
            System.out.println("Exception: " + e);
        }
    }

    /**
     * Runs all negative methods on the selected test set and writes one
     * R file per method.
     */
    private void compare()
    {
        System.out.println("\tNum iterations: " + NUM_ITERATIONS);

        // One entry (a prediction list) per run.
        List randomRuns = new ArrayList();
        List shuffledRuns = new ArrayList();
        List randomSelRuns = new ArrayList();
        List pwmRuns = new ArrayList();

        loadTests(testNum);
        for (int it=0; it < NUM_ITERATIONS; it++)
        {
            System.out.println("\t=== Run " + it +" ===");

            // PWM is evaluated in the first run only
            // (presumably because its negatives do not vary between runs - TODO confirm).
            if (it == 0)
            {
                pwmRuns.add(runSVM(Constants.PWM));
            }
            randomRuns.add(runSVM(Constants.RANDOM));
            shuffledRuns.add(runSVM(Constants.SHUFFLED));
            randomSelRuns.add(runSVM(Constants.RANDOM_SEL));
        }

        if (print)
        {
            writeRFile("Random", randomRuns);
            writeRFile("Shuffled", shuffledRuns);
            writeRFile("Random Selection", randomSelRuns);
            writeRFile("PWM", pwmRuns);
        }

        System.out.println("\tDone.");
    }

    /**
     * Writes the R file for one predictor; the file name is the predictor
     * name with spaces replaced by underscores followed by "_Load.r".
     */
    private void writeRFile(String predictorName, List foldPredictionList)
    {
        String rFilename = predictorName.replace(' ','_') + "_Load.r";
        makeRFile(foldPredictionList, rFilename);
    }

    /**
     * Computes the ROC and PR AUC of every run.
     *
     * @param foldPredictionList list holding one prediction list per run
     * @return two arrays of equal length: index 0 holds the ROC AUCs,
     *         index 1 the PR AUCs, both in run order
     */
    private double[][] getFoldAUCLists(List foldPredictionList)
    {
        int numFolds = foldPredictionList.size();
        double[] rocAUCs = new double[numFolds];
        double[] prAUCs = new double[numFolds];
        for (int i=0; i < numFolds; i++)
        {
            List predictionList = (List) foldPredictionList.get(i);
            Evaluation eval = new Evaluation(predictionList);
            rocAUCs[i] = eval.getROCAUC();
            prAUCs[i] = eval.getPRAUC();
            System.out.print("*"); // progress marker, one per run
        }
        System.out.println();

        return new double[][]{rocAUCs, prAUCs};
    }

    /**
     * Builds the R script for the given runs (the predictions followed by
     * the vectors "rAUC" and "prAUC") and writes it to fileName.
     *
     * @param foldPredictionList list holding one prediction list per run
     * @param fileName           name of the R file to write
     */
    private void makeRFile(List foldPredictionList, String fileName)
    {
        System.out.println("\tMaking R file ...");
        StringBuffer rString = PDZSVMUtils.toRString(foldPredictionList);
        rString.append("\n");

        System.out.println("\tGetting fold AUCs...");
        double[][] foldAUCs = getFoldAUCLists(foldPredictionList);
        appendRVector(rString, "rAUC", foldAUCs[0]);
        appendRVector(rString, "prAUC", foldAUCs[1]);

        print(rString.toString(),fileName);
    }

    /**
     * Appends "name = c(v0,v1,...)" plus a newline to the buffer; an empty
     * array yields "name = c()".
     */
    private static void appendRVector(StringBuffer rString, String name, double[] values)
    {
        rString.append(name).append(" = c(");
        for (int i=0; i < values.length; i++)
        {
            if (i > 0)
            {
                rString.append(",");
            }
            rString.append(values[i]);
        }
        rString.append(")\n");
    }

    /**
     * Trains a contact map SVM on the mouse (Chen) training data plus the
     * human training data for the given negative method and predicts the
     * currently loaded test profiles.
     *
     * @param negMethod one of Constants.PWM, Constants.RANDOM,
     *                  Constants.SHUFFLED or Constants.RANDOM_SEL
     * @return the predictions for the positive and negative test profiles
     * @throws IllegalArgumentException if negMethod is none of the above
     */
    private List runSVM(String negMethod)
    {
        // Separate loader so each run loads its own training data; the
        // field "dl" keeps holding the test data.
        DataLoader trainLoader = new DataLoader();
        trainLoader.loadMouseChenTrain();

        // Exponents of the SVM parameters, shared by all negative methods.
        // Alternatives left commented out in earlier revisions:
        // PWM and RANDOM C=6,g=3; SHUFFLED C=4,g=3 and C=3,g=3.
        double C = 2;
        double g = 4;
        if (negMethod.equals(Constants.PWM))
        {
            System.out.println("\tPWM, [g,C] = ["+g+","+C+"])");
            trainLoader.loadHumanTrain(Constants.PWM);
        }
        else if (negMethod.equals(Constants.RANDOM))
        {
            System.out.println("\tRANDOM, [g,C] = ["+g+","+C+"])");
            trainLoader.loadHumanTrain(Constants.RANDOM);
        }
        else if (negMethod.equals(Constants.SHUFFLED))
        {
            System.out.println("\tSHUFFLED, [g,C] = ["+g+","+C+"])");
            trainLoader.loadHumanTrain(Constants.SHUFFLED);
        }
        else if (negMethod.equals(Constants.RANDOM_SEL))
        {
            System.out.println("\tRANDOM SEL, [g,C] = ["+g+","+C+"])");
            trainLoader.loadHumanTrain(Constants.RANDOM_SEL);
        }
        else
        {
            throw new IllegalArgumentException("Unknown negative method: " + negMethod);
        }

        List posTrainProfileList = trainLoader.getPosTrainProfileList();
        List negTrainProfileList = trainLoader.getNegTrainProfileList();

        svm_parameter svmParams = new svm_parameter();
        svmParams.setDefaults();
        svmParams.C = Math.exp(C);
        // gamma = exp(-ln(2) - g) = 1 / (2 * e^g)
        svmParams.gamma = Math.exp(-Math.log(2)-g);
        svmParams.data_encoding = svm_parameter.CONTACTMAP2020;

        ContactMapSVMPredictor predictor = new ContactMapSVMPredictor(posTrainProfileList, negTrainProfileList, svmParams);
        predictor.train();
        return predictor.predict(posTestProfileList, negTestProfileList);
    }

}

