package org.baderlab.pdzsvm.predictor.svm;

import org.baderlab.pdzsvm.evaluation.Evaluation;
import org.baderlab.pdzsvm.evaluation.Prediction;

import java.util.*;
import java.util.List;

import org.baderlab.pdzsvm.predictor.Predictor;
import org.baderlab.pdzsvm.data.*;
import libsvm.*;
import weka.core.Instances;
import org.baderlab.pdzsvm.validation.ValidationParameters;
import org.baderlab.pdzsvm.utils.Constants;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * SVM predictor using a contact map feature encoding as described in the paper.
 *
 * <p>Training profiles are collected into a {@link Data} set at construction
 * time; the actual learning and prediction work is delegated to the static
 * helpers of {@code SVM}. Typical use is: construct, {@link #train()}, then
 * {@link #predict(List, List)}.
 *
 * TODO: This should be implemented as a generic svm predictor not associated with
 * any specific feature encoding
 */
public class ContactMapSVMPredictor extends Predictor {

    /** Identifier for the CHEN16 contact map (not referenced inside this class). */
    public final static int CHEN16 = 0;

    /** Model produced by {@link #train()}; {@code null} until training has run. */
    private svm_model svmModel;

    /** SVM settings supplied by the caller and handed to every {@code SVM} call. */
    private final svm_parameter svmParams;

    /** Positive and negative training profiles, labelled and merged. */
    private final Data trainData;

    /**
     * Builds a predictor over the given training profiles.
     *
     * @param posTrainProfileList profiles added to the training data as {@code Constants.CLASS_YES}
     * @param negTrainProfileList profiles added to the training data as {@code Constants.CLASS_NO}
     * @param svmParams           SVM parameters used for training, prediction and cross validation
     */
    public ContactMapSVMPredictor(List posTrainProfileList, List negTrainProfileList,svm_parameter svmParams)
    {
        super(posTrainProfileList,negTrainProfileList);
        trainData = new Data();
        trainData.addRawData(posTrainProfileList,Constants.CLASS_YES);
        trainData.addRawData(negTrainProfileList,Constants.CLASS_NO);

        this.svmParams = svmParams;

        predictorName = "SVM";
        // NOTE(review): this message is printed unconditionally, whatever
        // encoding svmParams selects -- confirm it is still accurate.
        System.out.println("\tUsing CHEN 16 contact map...");
    }

    /**
     * @return the number of positive examples reported by the training data
     */
    public int getNumTrainPositive()
    {
        return trainData.getNumPositive();
    }

    /**
     * @return the number of negative examples reported by the training data
     */
    public int getNumTrainNegative()
    {
        return trainData.getNumNegative();
    }

    /**
     * @return the SVM parameter object passed to the constructor (not a copy)
     */
    public svm_parameter getSVMParams()
    {
        return svmParams;
    }

    /**
     * Runs the SVM on the given test profiles and stores the result in
     * {@code predictionList}.
     *
     * <p>The current {@code svmModel} is passed straight to {@code SVM.predict};
     * it is {@code null} if {@link #train()} has not been called.
     * NOTE(review): how {@code SVM.predict} treats a null model is not visible
     * here -- callers should train first.
     *
     * @param posTestProfileList test profiles labelled {@code Constants.CLASS_YES}
     * @param negTestProfileList test profiles labelled {@code Constants.CLASS_NO};
     *                           may be {@code null} or empty, in which case no
     *                           negatives are added
     * @return the list returned by {@code SVM.predict}
     */
    public List predict(List posTestProfileList, List negTestProfileList)
    {
        Data testData = new Data();
        testData.addRawData(posTestProfileList, Constants.CLASS_YES);
        // The null check must come first so that it short-circuits before
        // isEmpty() is invoked on a null reference.
        if (negTestProfileList != null && !negTestProfileList.isEmpty())
        {
            testData.addRawData(negTestProfileList, Constants.CLASS_NO);
        }
        predictionList = SVM.predict(trainData, testData, svmModel, svmParams);

        return predictionList;
    }

    /**
     * Trains the SVM on the training data and keeps the resulting model for
     * later calls to {@link #predict(List, List)}.
     */
    public void train()
    {
        svmModel = SVM.train(trainData, svmParams);
    }

    /**
     * Stand-alone demo: trains on the mouse (Chen) plus human training sets,
     * predicts the worm protein-microarray test set, prints the evaluation,
     * plots the curves and dumps the predictions as R vectors.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        // Currently never read; kept so the alternative test-set lines below
        // remain drop-in replacements.
        String dirName;
        DataLoader dl = new DataLoader();
        dl.loadMouseChenTrain();

        String predictorName = "SVM";
        boolean useHumanTrain = true;
        if (useHumanTrain)
        {
            dl.loadHumanTrain(Constants.PWM);
        }
        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();

        String testName = "";

        // Exactly one of the following test-set lines should be active.
        dl.loadWormTest(Constants.PROTEIN_MICROARRAY); testName = "WORM PM";  dirName = "PMWormG";
        //dl.loadMouseTest("ORPHAN"); testName = "MOUSE ORPHAN PM";  dirName = "PMMouseOrphanG";
        //dl.loadFlyTest(); testName = "FLY PM";  dirName = "PMFlyG";

        svm_parameter svmParams = new svm_parameter();
        svmParams.setDefaults();
        svmParams.C = Math.exp(2);
        svmParams.gamma = Math.exp(-Math.log(2)-3);
        svmParams.data_encoding = svm_parameter.CONTACTMAP2020;

        List posTestProfileList = dl.getPosTestProfileList();
        List negTestProfileList = dl.getNegTestProfileList();

        ContactMapSVMPredictor c = new ContactMapSVMPredictor(posTrainProfileList, negTrainProfileList, svmParams);
        c.train();
        List predictions = c.predict(posTestProfileList, negTestProfileList);

        List instList = new ArrayList();
        Evaluation eval = new Evaluation(predictions);
        System.out.println(eval.toString());
        // NOTE(review): the meaning of the argument 1 is defined by
        // Evaluation.getCurve and is not visible here.
        Instances inst = eval.getCurve(1);
        instList.add(inst);

        // plotCurves takes parallel lists; this demo has a single predictor,
        // so each list holds one entry.
        List rocAUCList = new ArrayList();
        rocAUCList.add(eval.getROCAUC());
        List prAUCList = new ArrayList();
        prAUCList.add(eval.getPRAUC());
        List aucLabelList = new ArrayList();
        aucLabelList.add(predictorName);

        plotCurves(instList,rocAUCList, prAUCList, aucLabelList,predictorName + " (" +testName+")");
        System.out.println(toRString(predictions));
    }

    /**
     * Delegates k-fold cross validation on the training data to
     * {@code SVM.kFoldCrossValidation}.
     *
     * @param validParams validation settings forwarded unchanged
     * @return the map returned by {@code SVM.kFoldCrossValidation}
     */
    public HashMap kFoldCrossValidation(ValidationParameters validParams)
    {
        return SVM.kFoldCrossValidation(trainData, svmParams, validParams);
    }

    /**
     * Delegates leave-out cross validation on the training data to
     * {@code SVM.leaveOutCrossValidation}.
     *
     * @param validParams validation settings forwarded unchanged
     * @return the map returned by {@code SVM.leaveOutCrossValidation}
     */
    public HashMap leaveOutCrossValidation(ValidationParameters validParams)
    {
        return SVM.leaveOutCrossValidation(trainData,svmParams, validParams);
    }

    /**
     * Formats predictions as two R vector assignments,
     * {@code actual[[1]]=c(...)} and {@code dec[[1]]=c(...)}, holding each
     * prediction's actual value and decision value respectively.
     *
     * <p>Side effect: both assignments are also printed to standard output
     * before the combined string is returned.
     *
     * @param predictionList list of {@link Prediction} objects; a {@code null}
     *                       or empty list yields empty vectors ({@code c()})
     * @return the "actual" assignment, a blank line, then the "dec" assignment
     */
    private static String toRString(List predictionList)
    {
        StringBuilder actualBuf = new StringBuilder("actual[[1]]=c(");
        StringBuilder decBuf = new StringBuilder("dec[[1]]=c(");

        int count = (predictionList == null) ? 0 : predictionList.size();
        for (int j = 0; j < count; j++)
        {
            Prediction pred = (Prediction)predictionList.get(j);
            if (j > 0)
            {
                actualBuf.append(",");
                decBuf.append(",");
            }
            actualBuf.append(pred.getActual());
            decBuf.append(pred.getDecValue());
        }
        actualBuf.append(")\n");
        decBuf.append(")\n");

        String actual = actualBuf.toString();
        String dec = decBuf.toString();

        System.out.println(actual);
        System.out.println(dec);
        return actual + "\n" + dec;
    }

}