package org.baderlab.pdzsvmstruct.predictor.svm;

import org.baderlab.pdzsvmstruct.data.*;

import java.util.*;
import java.util.List;

import libsvm.svm_parameter;
import libsvm.svm_model;
import org.baderlab.pdzsvmstruct.predictor.Predictor;
import org.baderlab.pdzsvmstruct.encoding.*;
import org.baderlab.pdzsvmstruct.validation.ValidationParameters;
import org.baderlab.pdzsvmstruct.utils.Constants;

/**
 * Copyright (c) 2011 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVMStruct.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVMStruct.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * SVM predictor trained on positive and negative interaction profiles.
 *
 * <p>When {@code svmParams.data_encoding == svm_parameter.STRUCT}, the training and test
 * data are encoded with binding-site structure features ({@link DomainFeatureEncoding}
 * plus {@link PeptideFeatureEncoding}). They are then scaled with
 * {@code Data.scaleData(0.0, 1.0, ...)} using the {@code getMaxMin()} ranges computed on
 * the training data. For any other {@code data_encoding} value, no encoding or scaling
 * is applied here and the raw data is handed to {@code SVM} unchanged.
 *
 * <p>TODO: This should be implemented as a generic SVM predictor not associated with
 * any specific feature encoding.
 */
public class GlobalSVMPredictor extends Predictor {

    /** Predictor name reported when the STRUCT encoding is applied. */
    private static final String STRUCT_PREDICTOR_NAME = "SVM STRUCT";

    /** Lower bound passed to {@code Data.scaleData}. */
    private static final double SCALE_MIN = 0.0;

    /** Upper bound passed to {@code Data.scaleData}. */
    private static final double SCALE_MAX = 1.0;

    private svm_parameter svmParams;
    private Data testData;
    private svm_model svmModel;
    /** Scaling ranges taken from the training data; stays null unless STRUCT encoding ran. */
    private List maxMinList = null;
    private final int excludeFeature = DomainFeatureEncoding.ZER;

    /**
     * Builds the training set from the given profiles and encodes/scales it (STRUCT
     * encoding only). It also configures two-class weights on {@code svmParams}.
     *
     * <p>Note: the supplied {@code svmParams} object is kept by reference and modified in
     * place ({@code nr_weight}, {@code weight_label}, {@code weight}).
     *
     * @param posTrainProfileList profiles added with label {@code Constants.CLASS_YES}
     * @param negTrainProfileList profiles added with label {@code Constants.CLASS_NO}
     * @param svmParams           libsvm parameters; mutated as described above
     */
    public GlobalSVMPredictor(List posTrainProfileList,
                              List negTrainProfileList,
                              svm_parameter svmParams)
    {
        super(posTrainProfileList, negTrainProfileList);
        EXCLUDE_FEATURE = excludeFeature;

        this.svmParams = svmParams;
        this.svmParams.nr_weight = 2;
        this.svmParams.weight_label = new int[]{-1, 1};

        trainData = new Data();
        trainData.addRawData(posTrainProfileList, Constants.CLASS_YES);
        trainData.addRawData(negTrainProfileList, Constants.CLASS_NO);
        if (!trainData.isEmpty() && svmParams.data_encoding == svm_parameter.STRUCT)
        {
            encodeStructFeatures(trainData);
            // Scaling ranges are computed on the training data only and reused in predict().
            maxMinList = trainData.getMaxMin();
            trainData.scaleData(SCALE_MIN, SCALE_MAX, maxMinList);

            predictorName = STRUCT_PREDICTOR_NAME;
        }

        int numPositive = trainData.getNumPositive();
        int numNegative = trainData.getNumNegative();
        // weight[i] pairs with weight_label[i], so weight[1] applies to label +1.
        // NOTE(review): if +1 is CLASS_YES, positives get weight #pos/#neg, which shrinks
        // their penalty when they are the minority class; the usual balancing weight is
        // #neg/#pos. Confirm this direction is intended.
        this.svmParams.weight = new double[]{1.0, positiveToNegativeRatio(numPositive, numNegative)};
        System.out.println(numPositive + "," + numNegative);
        System.out.println("\tClass weights: " + svmParams.weight[0] + "," + svmParams.weight[1]);

        trainData.printSummary();
    }

    /**
     * Returns {@code numPositive / numNegative} as a class weight.
     *
     * <p>When there are no negative examples, it returns 1.0 (no reweighting) instead.
     * The plain division would produce Infinity or NaN there, which is not a usable
     * libsvm penalty multiplier.
     */
    private static double positiveToNegativeRatio(int numPositive, int numNegative)
    {
        if (numNegative <= 0)
        {
            return 1.0;
        }
        return (double) numPositive / numNegative;
    }

    /**
     * Encodes {@code data} in place with binding-site structure features.
     *
     * <p>The domain encoding is built from {@code data}'s own organism list and the
     * predictor's {@code EXCLUDE_FEATURE}.
     */
    private void encodeStructFeatures(Data data)
    {
        // NOTE(review): for test data this uses the test set's organism list rather than
        // the training set's; confirm DomainFeatureEncoding produces features aligned
        // with the trained model in that case.
        List organismList = data.getOrganismList();
        DomainFeatureEncoding dEnc = new DomainFeatureEncoding(organismList, EXCLUDE_FEATURE);
        PeptideFeatureEncoding pEnc = new PeptideFeatureEncoding();
        data.encodeBindingSiteStructureData(dEnc, pEnc);
    }

    /**
     * Replaces the training data used by {@link #train()} and cross-validation.
     *
     * <p>The stored scaling ranges ({@code maxMinList}) are not recomputed, so
     * {@link #predict} keeps scaling test data with the ranges from construction time.
     */
    public void setTrainData(Data trainData)
    {
        this.trainData = trainData;
    }

    /** Replaces the libsvm parameters used for training, validation and prediction. */
    public void setSVMParams(svm_parameter svmParams)
    {
        this.svmParams = svmParams;
    }

    /** Returns the libsvm parameters currently held by this predictor. */
    public svm_parameter getSVMParams()
    {
        return svmParams;
    }

    /** Returns the (possibly encoded and scaled) training data. */
    public Data getTrainData()
    {
        return trainData;
    }

    /** Runs {@code SVM.kFoldCrossValidation} on the training data with the current parameters. */
    public HashMap kFoldCrossValidation(ValidationParameters validParams)
    {
        return SVM.kFoldCrossValidation(trainData, svmParams, validParams);
    }

    /** Runs {@code SVM.leaveOutCrossValidation} on the training data with the current parameters. */
    public HashMap leaveOutCrossValidation(ValidationParameters validParams)
    {
        return SVM.leaveOutCrossValidation(trainData, svmParams, validParams);
    }

    /** Trains the SVM model on the training data and stores it for {@link #predict}. */
    public void train ()
    {
        System.out.println("\n\tTraining ...");
        svmModel = SVM.train(trainData, svmParams);
    }

    /**
     * Builds a test set from the given profiles and predicts it with the trained model.
     *
     * <p>With STRUCT encoding, the test data is encoded the same way as the training data
     * and scaled with the training-data ranges.
     *
     * @param posTestProfileList profiles added with label {@code Constants.CLASS_YES}
     * @param negTestProfileList profiles added with label {@code Constants.CLASS_NO}
     * @return the list produced by {@code SVM.predict}, also stored in {@code predictionList}
     */
    public List predict(List posTestProfileList, List negTestProfileList)
    {
        testData = new Data();
        testData.addRawData(posTestProfileList, Constants.CLASS_YES);
        testData.addRawData(negTestProfileList, Constants.CLASS_NO);

        if (!testData.isEmpty() && svmParams.data_encoding == svm_parameter.STRUCT)
        {
            encodeStructFeatures(testData);
            // Reuse the training ranges so train and test features share one scale.
            testData.scaleData(SCALE_MIN, SCALE_MAX, maxMinList);

            predictorName = STRUCT_PREDICTOR_NAME;
        }
        System.out.println("\n\tPredicting ...");
        predictionList = SVM.predict(trainData, testData, svmModel, svmParams);
        return predictionList;
    }

    /** Returns the number of positive examples in the training data. */
    public int getNumTrainPositive()
    {
        return trainData.getNumPositive();
    }

    /** Returns the number of negative examples in the training data. */
    public int getNumTrainNegative()
    {
        return trainData.getNumNegative();
    }

    /** Example entry point: loads training data, builds a STRUCT-encoded predictor and trains it. */
    public static void main(String[] args)
    {
        DataLoader dl = new DataLoader();
        dl.loadMousePDBTrain();
        dl.loadSidhuHumanPDBTrain(Constants.SIDHU_HUMAN_G_PDB, Constants.PHAGE_DISPLAY);
        dl.loadSidhuHumanPDBTrain(Constants.SIDHU_HUMAN_G_PDB, Constants.HOMOLOGY_MODEL);

        svm_parameter svmParams = new svm_parameter();
        svmParams.setDefaults();
        svmParams.data_encoding = svm_parameter.STRUCT;
        // Hyperparameters given in log space: C = e^4 and gamma = e^(-g) / 2 with g = 3.
        double logC = 4;
        double g = 3;
        svmParams.C = Math.exp(logC);
        svmParams.gamma = Math.exp(-Math.log(2) - g);

        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();

        GlobalSVMPredictor gp = new GlobalSVMPredictor(posTrainProfileList,
                negTrainProfileList,
                svmParams);

        gp.train();
    }
}
