package org.baderlab.pdzsvmstruct.optimize;

import org.baderlab.pdzsvmstruct.data.DataLoader;
import org.baderlab.pdzsvmstruct.predictor.svm.GlobalSVMPredictor;
import org.baderlab.pdzsvmstruct.predictor.Predictor;

import java.util.*;

import libsvm.svm_parameter;
import org.baderlab.pdzsvmstruct.validation.ValidationParameters;
import org.baderlab.pdzsvmstruct.utils.Constants;

/**
 * Copyright (c) 2011 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVMStruct.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVMStruct.  If not, see <http://www.gnu.org/licenses/>.
 */

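/**
 * Grid-search program for the regular global SVM predictor (GlobalSVMPredictor) only.
 * It tunes C (and gamma for non-linear kernels) and scores each grid point by the
 * ROC AUC averaged over k-fold cross-validation.
 */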
public class OptimizeGlobalPredictor
{
    /** Number of folds used when cross-validating each grid point. */
    private static final int NUM_FOLDS = 10;

    /** How many times the k-fold cross-validation is run for each grid point. */
    private static final int NUM_CV_REPEATS = 1;

    public OptimizeGlobalPredictor()
    {
    }

    /**
     * Loads the mouse and Sidhu human PDB training data and grid-searches the SVM
     * hyperparameters of a GlobalSVMPredictor over the default grid
     * (C exponents {3,4,5,6,7}; gamma exponents {3,4,5}, used only for non-linear kernels).
     *
     * Data encodings noted by the original author (confirm against svm_parameter.dataEncoding):
     * 0 CHAR_SEQUENCE, 1 BINARY_SEQUENCE, 2 PHYSICOCHEMICAL, 3 CHEN_CONTACTMAP, 4 P0,
     * 5 PAIRWISE_POTENTIAL, 6 STRUCT, 7 SIDHU_CONTACTMAP, 8 STRUCTXBSQ, 9 STRUCTXCM,
     * 10 SIDHU_STRUCTSEQ.
     *
     * @param encoding value stored into svm_parameter.data_encoding
     * @return the grid result with the highest cross-validated ROC AUC
     */
    public GridResult optimize(int encoding)
    {
        return optimize(encoding, new double[]{3, 4, 5, 6, 7}, new double[]{3, 4, 5});
    }

    /**
     * Same as {@link #optimize(int)} but searches a caller-supplied grid.
     *
     * @param encoding value stored into svm_parameter.data_encoding
     * @param lnC      C exponents, converted to C values by GridSearchUtils.getC
     * @param lnG      gamma exponents, converted by GridSearchUtils.getG; ignored when the
     *                 default kernel is linear
     * @return the grid result with the highest cross-validated ROC AUC
     * @throws IllegalArgumentException if the grid that would be searched is empty
     */
    public GridResult optimize(int encoding, double[] lnC, double[] lnG)
    {
        if (lnC == null || lnC.length == 0)
            throw new IllegalArgumentException("lnC grid must contain at least one value");

        DataLoader dl = new DataLoader();
        dl.loadMousePDBTrain();
        dl.loadSidhuHumanPDBTrain(Constants.SIDHU_HUMAN_G_PDB, Constants.PHAGE_DISPLAY);

        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();

        System.out.println("TRAIN DOMAINS");
        System.out.println("Num Pos Domains: " + posTrainProfileList.size());
        System.out.println("Num Neg Domains: " + negTrainProfileList.size());

        System.out.println("TRAIN INTERACTIONS");
        System.out.println("Num Pos Interactions:" + dl.getNumPosTrainInteractions());
        System.out.println("Num Neg Interactions:" + dl.getNumNegTrainInteractions());

        svm_parameter svmParams = new svm_parameter();
        svmParams.setDefaults();
        svmParams.data_encoding = encoding;

        boolean linear = svmParams.kernel_type == svm_parameter.LINEAR;
        if (!linear && (lnG == null || lnG.length == 0))
            throw new IllegalArgumentException("lnG grid must contain at least one value for a non-linear kernel");

        GlobalSVMPredictor gp = new GlobalSVMPredictor(posTrainProfileList,
                negTrainProfileList,
                svmParams);

        List<GridResult> gridResultList = linear
                ? optimizeLINEAR(gp, lnC)
                : optimizeRBF(gp, lnG, lnC);

        // The list is already sorted best-first (descending ROC AUC); just report it.
        GridResult topResult = gridResultList.get(0);
        System.out.println(topResult.headerString());
        for (GridResult result : gridResultList)
            System.out.println(result.toString());

        // For a linear kernel g is never set by the search, so it holds GridResult's default.
        System.out.println("\tOptimal grid result: " + topResult.g + ", " + topResult.C);
        return topResult;
    }

    /**
     * Grid-searches C for a linear-kernel predictor using k-fold cross-validation.
     *
     * Side effect: sets C on the object returned by getSVMParams(). The code relies on that
     * object being the predictor's live parameters, so after this call the predictor is
     * left configured with the LAST grid point, not the best one.
     *
     * @param predictor must be a GlobalSVMPredictor
     * @param lnC       C exponents, converted by GridSearchUtils.getC
     * @return one result per grid point, sorted by descending ROC AUC
     * @throws ClassCastException if predictor is not a GlobalSVMPredictor
     */
    public List<GridResult> optimizeLINEAR(Predictor predictor, double[] lnC)
    {
        GlobalSVMPredictor globalSVMPredictor = (GlobalSVMPredictor) predictor;
        svm_parameter svmParams = globalSVMPredictor.getSVMParams();
        ValidationParameters validParams = createValidationParams();

        double[] C = GridSearchUtils.getC(lnC);
        List<GridResult> gridResultList = new ArrayList<GridResult>(C.length);
        for (int j = 0; j < C.length; j++)
        {
            System.out.println("\t==== Grid [" + lnC[j] + "] ===\n");

            svmParams.C = C[j];
            GridResult result = crossValidate(globalSVMPredictor, svmParams, validParams);
            result.C = GridSearchUtils.toLnC(C[j]);

            System.out.println(result.toString());
            gridResultList.add(result);
        }
        Collections.sort(gridResultList, new MyComparator());
        return gridResultList;
    }

    /**
     * Grid-searches (gamma, C) for an RBF-kernel predictor using k-fold cross-validation.
     *
     * Side effect: sets C and gamma on the object returned by getSVMParams(). After this
     * call the predictor is left at the LAST grid point, not the best one.
     *
     * @param predictor must be a GlobalSVMPredictor
     * @param lnG       gamma exponents, converted by GridSearchUtils.getG
     * @param lnC       C exponents, converted by GridSearchUtils.getC
     * @return one result per grid point, sorted by descending ROC AUC
     * @throws ClassCastException if predictor is not a GlobalSVMPredictor
     */
    public List<GridResult> optimizeRBF(Predictor predictor, double[] lnG, double[] lnC)
    {
        System.out.println("\t=== Optimizing RBF Kernel ===\n");
        GlobalSVMPredictor globalSVMPredictor = (GlobalSVMPredictor) predictor;
        svm_parameter svmParams = globalSVMPredictor.getSVMParams();
        ValidationParameters validParams = createValidationParams();

        double[] C = GridSearchUtils.getC(lnC);
        double[] g = GridSearchUtils.getG(lnG);
        List<GridResult> gridResultList = new ArrayList<GridResult>(g.length * C.length);
        for (int i = 0; i < g.length; i++)
        {
            for (int j = 0; j < C.length; j++)
            {
                System.out.println("\t==== Grid [" + lnG[i] + "," + lnC[j] + "] ===\n");

                svmParams.C = C[j];
                svmParams.gamma = g[i];
                GridResult result = crossValidate(globalSVMPredictor, svmParams, validParams);
                result.C = GridSearchUtils.toLnC(C[j]);
                result.g = GridSearchUtils.toLnG(g[i]);

                System.out.println(result.toString());
                gridResultList.add(result);
            }
        }
        Collections.sort(gridResultList, new MyComparator());
        return gridResultList;
    }

    /** Builds the shared k-fold settings (NUM_FOLDS folds, NUM_CV_REPEATS run) used per grid point. */
    private static ValidationParameters createValidationParams()
    {
        ValidationParameters validParams = new ValidationParameters();
        validParams.k = NUM_FOLDS;
        validParams.numTimes = NUM_CV_REPEATS;
        validParams.type = ValidationParameters.K_FOLD;
        return validParams;
    }

    /** Prints the current parameters, runs k-fold cross-validation and returns the fold-averaged result. */
    private static GridResult crossValidate(GlobalSVMPredictor predictor,
                                            svm_parameter svmParams,
                                            ValidationParameters validParams)
    {
        svmParams.print();
        HashMap cvResultsMap = predictor.kFoldCrossValidation(validParams);
        return GridSearchUtils.computeAvgGridResult(cvResultsMap);
    }

    /**
     * Orders grid results best-first by descending ROC AUC.
     *
     * Results whose AUC is NaN are placed last. Plain less-than/greater-than comparisons
     * would treat NaN as equal to every value, which breaks the Comparator total-order
     * contract and could let a NaN result be chosen as the optimum.
     */
    public class MyComparator implements Comparator<GridResult>
    {
        @Override
        public int compare(GridResult first, GridResult second)
        {
            double firstAuc = first.rocAUC;
            double secondAuc = second.rocAUC;
            boolean firstNaN = Double.isNaN(firstAuc);
            boolean secondNaN = Double.isNaN(secondAuc);
            if (firstNaN || secondNaN)
                return firstNaN == secondNaN ? 0 : (firstNaN ? 1 : -1);
            if (firstAuc < secondAuc) return 1;   // higher AUC sorts first
            if (firstAuc > secondAuc) return -1;
            return 0;
        }
    }

    /**
     * Command-line entry point.
     *
     * Expects one argument: an index into svm_parameter.dataEncoding. On a missing or
     * invalid argument it prints usage to stderr and exits with status 1.
     */
    public static void main(String[] args)
    {
        int encoding = -1;
        if (args.length >= 1)
        {
            try
            {
                encoding = Integer.parseInt(args[0].trim());
            }
            catch (NumberFormatException e)
            {
                encoding = -1; // reported as invalid below
            }
        }
        if (encoding < 0 || encoding >= svm_parameter.dataEncoding.length)
        {
            printUsage();
            System.exit(1);
        }

        System.out.println("\tEncoding used: " + svm_parameter.dataEncoding[encoding]);
        new OptimizeGlobalPredictor().optimize(encoding);
    }

    /** Prints command-line usage, listing every valid encoding index, to stderr. */
    private static void printUsage()
    {
        System.err.println("Usage: java " + OptimizeGlobalPredictor.class.getName() + " <encoding>");
        System.err.println("where <encoding> is one of:");
        for (int i = 0; i < svm_parameter.dataEncoding.length; i++)
            System.err.println("  " + i + " = " + svm_parameter.dataEncoding[i]);
    }
}
