package org.baderlab.pdzsvm.evaluation;

import weka.core.*;
import weka.core.Utils;
import weka.classifiers.evaluation.TwoClassStats;

import java.util.*;
import java.util.List;
import java.awt.*;

import org.jfree.chart.JFreeChart;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.chart.renderer.xy.StandardXYItemRenderer;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

import javax.swing.*;
import org.baderlab.pdzsvm.utils.Constants;
import org.baderlab.pdzsvm.utils.PDZSVMUtils;
import auc.Confusion;
import auc.AUCCalculator;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 * Code copied in part from weka 3.6.1
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Evaluation class to calculate predictor evaluation measures.
 * Inspired by the Weka Evaluation class but not an exact copy.  We use
 * Weka TwoClassStats only to build Weka-ready (ARFF) threshold curve data.
 * ROC and PR calculations are done using AUC calculator package:
 * Davis, J. and Goadrich, M. (2006) The relationship between Precision-Recall
 *   and ROC curves. Proc. Int. Conf. Machine Learning. ACM, Pittsburgh,
 *   PA, pp. 233-240.
 * ROC and PR Curve plot examples according to:
 * Fawcett, T. (2006) An introduction to ROC analysis,
 *   Pattern Recogn. Lett., 27, 861-874.
 */
public class Evaluation
{
    // CLASSINDEX =0 is POSITIVE
    // CLASSINDEX =1 is NEGATIVE
    // (Class indices address m_ConfusionMatrix; the constructor maps a 0/1 label
    //  to index 1 - label, so an actual label of 1 lands at index 0.)
    /** The name of the relation used in threshold curve datasets */
    public static final String RELATION_NAME = "ThresholdCurve";
    /** attribute name: True Positives */
    public static final String TRUE_POS_NAME  = "True Positives";
    /** attribute name: False Negatives */
    public static final String FALSE_NEG_NAME = "False Negatives";
    /** attribute name: False Positives */
    public static final String FALSE_POS_NAME = "False Positives";
    /** attribute name: True Negatives */
    public static final String TRUE_NEG_NAME  = "True Negatives";
    /** attribute name: False Positive Rate */
    public static final String FP_RATE_NAME   = "False Positive Rate";
    /** attribute name: True Positive Rate */
    public static final String TP_RATE_NAME   = "True Positive Rate";

    /** attribute name: False Negative Rate (NOTE(review): not part of makeHeader()) */
    public static final String FN_RATE_NAME   = "False Negative Rate";
    /** attribute name: True Negative Rate (NOTE(review): not part of makeHeader()) */
    public static final String TN_RATE_NAME   = "True Negative Rate";
    /** attribute name: Precision */
    public static final String PRECISION_NAME = "Precision";
    /** attribute name: Recall */
    public static final String RECALL_NAME    = "Recall";
    /** attribute name: Fallout */
    public static final String FALLOUT_NAME   = "Fallout";
    /** attribute name: FMeasure */
    public static final String FMEASURE_NAME  = "FMeasure";
    /** attribute name: Threshold */
    public static final String THRESHOLD_NAME = "Threshold";

    /** attribute name: actual label of the instance at a threshold-curve point */
    public static final String ACTUAL_NAME = "Actual";

    /** Decision values, stored in decreasing order by the constructor. */
    private double[] decValues;
    /** Actual labels aligned with decValues (1 is counted as positive, anything else as negative). */
    private int[] actual;
    /** Predicted labels aligned with decValues. */
    private double[] pred;

    /** 2x2 counts indexed [predicted class index][actual class index], class index = 1 - label. */
    private double[][] m_ConfusionMatrix;
    /** Number of classes; the accessors below use 1 - classIndex and so assume exactly two. */
    private int m_NumClasses = 2;
    //private String[] m_ClassNames = new String[] {ClassIndex.YES,ClassIndex.NO};
    /** Display names per class index, used by toString()/toMatrixString(). */
    private String[] m_ClassNames = Constants.CLASSES;

    /** Number of predictions whose actual label is 1. */
    private double numPos = 0.0;
    /** Number of predictions whose actual label is not 1. */
    private double numNeg = 0.0;
    /**
     * Builds an evaluation from a list of predictions.
     * <p/>
     * The predictions are re-ordered by decreasing decision value. A 2x2
     * confusion matrix is filled with rows indexed by the predicted class and
     * columns by the actual class, where class index = 1 - label. Label 1
     * therefore maps to index 0.
     *
     * @param predictions list whose elements are {@link Prediction}s; the actual
     *        and predicted labels are assumed to be 0 or 1 (other values would
     *        index outside the 2x2 matrix)
     */
    public Evaluation(List predictions)
    {
        m_ConfusionMatrix = new double[m_NumClasses][m_NumClasses];
        final int n = predictions.size();

        // Unpack the predictions in list order.
        double[] unsortedDec = new double[n];
        int[] unsortedActual = new int[n];
        double[] unsortedPred = new double[n];
        int next = 0;
        for (Object item : predictions)
        {
            Prediction p = (Prediction) item;
            unsortedDec[next] = p.getDecValue();
            unsortedActual[next] = (int) p.getActual();
            unsortedPred[next] = p.getPrediction();
            next++;
        }

        // Utils.sort yields indices in ascending order of decision value;
        // reading it from the back gives decreasing order.
        int[] ascending = Utils.sort(unsortedDec);
        double[] byDec = new double[n];
        int[] byActual = new int[n];
        double[] byPred = new double[n];
        for (int rank = 0; rank < n; rank++)
        {
            int source = ascending[n - 1 - rank];
            byDec[rank] = unsortedDec[source];
            byActual[rank] = unsortedActual[source];
            byPred[rank] = unsortedPred[source];
        }
        decValues = byDec;
        actual = byActual;
        pred = byPred;

        // Tally the confusion matrix and the per-class totals.
        for (int k = 0; k < n; k++)
        {
            // Row = predicted class index, column = actual class index. For a
            // correct prediction both indices coincide, so no special case is needed.
            m_ConfusionMatrix[(int) (1 - pred[k])][(int) (1 - actual[k])] += 1.0;
            if (actual[k] == 1)
                numPos += 1;
            else
                numNeg += 1;
        }
    }
    /**
     * Returns the fixed column-header line of the per-domain summary table.
     *
     * @return the header text, terminated by a newline
     */
    public static String headerbyDomainString()
    {
        final String header =
                "Domain         #Pos #Neg   Full Sim BS Sim  rAUC    prAUC    Sens.   Spec.   tPrec   fPrec   tFMeas  fFMeas  MCC     FMeas   #Results\n";
        return header;
    }
    /**
     * Returns how many predictions had an actual label of 1.
     *
     * @return the positive count, as a double
     */
    public double getNumPositives()
    {
        final double positives = numPos;
        return positives;
    }
    /**
     * Returns how many predictions had an actual label other than 1.
     *
     * @return the negative count, as a double
     */
    public double getNumNegatives()
    {
        final double negatives = numNeg;
        return negatives;
    }
  
    /*public static String summaryByDomainString(List predictionByDomainList)
    {
        Evaluation eval = new Evaluation(predictionByDomainList);
        double rocAUC =  eval.getROCAUC();
        double prAUC =  eval.getPRAUC();

        double sens = eval.recall(0);
        double spec = eval.recall(1);
        double tPrec = eval.precision(0);
        double fPrec = eval.precision(1);
        double mcc = eval.mcc(0);
        double fMeasure = eval.fMeasure();
        double tFMeasure = eval.fMeasure(0);
        double fFMeasure = eval.fMeasure(1);
        String domainName = "";
        double bindingSiteSeqSim = 0.0;
        double fullSeqSim = 0.0;

        int numPos = 0;
        int numNeg = 0;
        for (int i=0;i < predictionByDomainList.size();i++)
        {
            Prediction pred = (Prediction)predictionByDomainList.get(i);
            domainName = pred.name;
            bindingSiteSeqSim = pred.bindingSiteSeqSim;
            fullSeqSim = pred.fullSeqSim;

            if (pred.getActual()==1.0)
                numPos = numPos+1;
            else
                numNeg = numNeg +1;
        }
        StringBuffer text = new StringBuffer();
        String d = Utils.padRight(domainName,15);
        //String n = Utils.padRight(nn,15);
        String pos = Utils.padRight(String.valueOf(numPos),4);
        String neg = Utils.padRight(String.valueOf(numNeg),4);

        text.append(d).append(" ");
        //text.append(n).append(" ");
        text.append(pos).append(" ");
        text.append(neg).append(" ");
        text.append(Utils.doubleToString(fullSeqSim, 8, 3)).append(" ");
        text.append(Utils.doubleToString(bindingSiteSeqSim, 7, 3)).append(" ");
        text.append(Utils.doubleToString(rocAUC, 7, 3)).append(" ");
        text.append(Utils.doubleToString(prAUC, 7, 3)).append(" ");

        text.append(Utils.doubleToString(sens, 7, 3)).append(" ");
        text.append(Utils.doubleToString(spec, 7, 3)).append(" ");
        text.append(Utils.doubleToString(tPrec, 7, 3)).append(" ");
        text.append(Utils.doubleToString(fPrec, 7, 3)).append(" ");
        text.append(Utils.doubleToString(tFMeasure, 7, 3)).append(" ");
        text.append(Utils.doubleToString(fFMeasure, 7, 3)).append(" ");
        text.append(Utils.doubleToString(mcc, 7, 3)).append(" ");
        text.append(Utils.doubleToString(fMeasure, 7, 3)).append(" ");
        text.append(Utils.doubleToString(predictionByDomainList.size(), 7, 4));
        return text.toString();
    }
    */
    // Output predictions by domain
    /*public String summaryByDomain(List predictionList)
    {
        HashMap byDomainMap = new HashMap();
        for (int i=0; i < predictionList.size();i++)
        {
            Prediction pred = (Prediction)predictionList.get(i);
            List byDomainPredictionList = (List)byDomainMap.get(pred.name);
            if (byDomainPredictionList == null)
                byDomainPredictionList = new ArrayList();
            byDomainPredictionList.add(pred);
            byDomainMap.put(pred.name,byDomainPredictionList);
        }
        Set keys = byDomainMap.keySet();
        List keysList = new ArrayList(keys);
        Collections.sort(keysList);
        String summaryString = headerbyDomainString();
        for (int i=0; i < keysList.size();i++)
        {
            String domainName = (String)keysList.get(i);
            List predictionByDomainList = (List)byDomainMap.get(domainName);
            String output = summaryByDomainString(predictionByDomainList);
            summaryString = summaryString+output +"\n";
        }
        return summaryString;
    }
    */
    /**
     * Packs one threshold-curve point into a Weka instance with weight 1.0.
     * Values follow the attribute order declared by makeHeader().
     *
     * @param tc     confusion statistics at this threshold
     * @param prob   the threshold value (getCurve passes the decision value),
     *               stored under THRESHOLD_NAME
     * @param actual the actual label of the instance at this point, stored
     *               under ACTUAL_NAME
     * @return the generated instance
     */
    private Instance makeInstance(TwoClassStats tc, double prob, double actual) {
        double[] vals = {
            tc.getTruePositive(),
            tc.getFalseNegative(),
            tc.getFalsePositive(),
            tc.getTrueNegative(),
            tc.getFalsePositiveRate(),
            tc.getTruePositiveRate(),
            tc.getPrecision(),
            tc.getRecall(),
            tc.getFallout(),
            tc.getFMeasure(),
            actual,
            prob
        };
        return new Instance(1.0, vals);
    }

    /**
     * Builds the empty dataset header used for threshold curves. It has twelve
     * numeric attributes, in the same order makeInstance() fills its values.
     *
     * @return an empty {@link Instances} named RELATION_NAME, initial capacity 100
     */
    public static Instances makeHeader() {
        String[] attributeNames = {
            TRUE_POS_NAME, FALSE_NEG_NAME, FALSE_POS_NAME, TRUE_NEG_NAME,
            FP_RATE_NAME, TP_RATE_NAME, PRECISION_NAME, RECALL_NAME,
            FALLOUT_NAME, FMEASURE_NAME, ACTUAL_NAME, THRESHOLD_NAME
        };
        FastVector attributes = new FastVector();
        for (String attributeName : attributeNames) {
            attributes.addElement(new Attribute(attributeName));
        }
        return new Instances(RELATION_NAME, attributes, 100);
    }

    /**
     * Builds a threshold curve with one point per prediction.
     * <p/>
     * Predictions are visited from highest to lowest decision value when
     * classIndex is not 0, and from lowest to highest when it is 0. After each
     * one, the cumulative TP/FP counts (and the derived TN/FN) are recorded with
     * that prediction's decision value as the threshold. Tied decision values
     * are not merged, and no extra origin point is added.
     *
     * @param classIndex the actual-label value treated as "positive" for this
     *                   curve (0.0 = neg class, 1.0 = pos class)
     * @return instances in makeHeader() layout, one per prediction
     */
    public Instances getCurve(int classIndex )
    {
        // Split the labels into the target class and everything else.
        double totPos = 0;
        double totNeg = 0;
        for (int label : actual)
        {
            if (label == classIndex)
                totPos += 1;
            else
                totNeg += 1;
        }

        // Visiting order: ascending indices for class 0, descending otherwise.
        int[] ascending = Utils.sort(decValues);
        int n = ascending.length;
        int[] order;
        if (classIndex == 0)
        {
            order = ascending;
        }
        else
        {
            order = new int[n];
            for (int r = 0; r < n; r++)
                order[r] = ascending[n - 1 - r];
        }

        Instances curve = makeHeader();
        double tp = 0;
        double fp = 0;
        for (int idx : order)
        {
            if (actual[idx] == classIndex)
                tp += 1;
            else
                fp += 1;
            TwoClassStats stats = new TwoClassStats(tp, fp, totNeg - fp, totPos - tp);
            curve.add(makeInstance(stats, decValues[idx], actual[idx]));
        }
        return curve;
    }

    /**
     * Computes the area under the ROC curve for the stored labels and decision
     * values, using the auc package (Davis &amp; Goadrich).
     *
     * @return the ROC AUC reported by the auc package
     */
    public double getROCAUC()
    {
        return AUCCalculator.readArrays(actual, decValues).calculateAUCROC();
    }
    /**
     * Computes the area under the precision-recall curve for the stored labels
     * and decision values, using the auc package.
     * NOTE(review): the argument 0 is presumably the minimum recall for the
     * integration; confirm against the auc package documentation.
     *
     * @return the PR AUC reported by the auc package
     */
    public double getPRAUC()
    {
        return AUCCalculator.readArrays(actual, decValues).calculateAUCPR(0);
    }

    /**
     * Matthews correlation coefficient with respect to one class:
     * (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)).
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return the MCC, or 0 when the denominator is 0
     */
    public double mcc(int classIndex)
    {
        double tp = numTruePositives(classIndex);
        double tn = numTrueNegatives(classIndex);
        double fn = numFalseNegatives(classIndex);
        double fp = numFalsePositives(classIndex);

        double denominator = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn);
        if (denominator == 0)
        {
            return 0;
        }
        return ((tp * tn) - (fp * fn)) / Math.sqrt(denominator);
    }


    /**
     * Formats a per-class summary table followed by the confusion matrix.
     * <p/>
     * One row is printed per class index ("Pos" for index 0, "Neg" otherwise).
     * Each row holds the TP and FP counts, TP/FP rates, precision, recall,
     * F-measure, ROC AUC, PR AUC, MCC and the class name.
     * <p/>
     * NOTE(review): the rAUC and prAUC columns show the same value on both rows,
     * because getROCAUC()/getPRAUC() take no class argument. ROC AUC does not
     * change when the class roles are swapped. PR AUC generally does, so read
     * the Neg-row prAUC with care.
     *
     * @return the formatted report
     */
    @Override
    public String toString()
    {
        // Each AUC call rebuilds a Confusion object. The previous code computed
        // each AUC twice (once per row); the printed values are identical.
        double rocAUC = getROCAUC();
        double prAUC = getPRAUC();

        StringBuilder text = new StringBuilder(
                "        T       F     T Rate    F Rate"
                        + "   Precision   Recall"
                        + "   F-Measure   rAUC       prAUC      MCC   Class\n");
        for (int i = 0; i < m_NumClasses; i++) {
            text.append(i == 0 ? "Pos " : "Neg ");
            text.append(Utils.doubleToString(numTruePositives(i), 5, 0)).append("   ");
            text.append(Utils.doubleToString(numFalsePositives(i), 5, 0)).append("   ");
            text.append(Utils.doubleToString(truePositiveRate(i), 7, 3)).append("   ");
            text.append(Utils.doubleToString(falsePositiveRate(i), 7, 3)).append("    ");
            text.append(Utils.doubleToString(precision(i), 7, 3)).append("   ");
            text.append(Utils.doubleToString(recall(i), 7, 3)).append("   ");
            text.append(Utils.doubleToString(fMeasure(i), 7, 3)).append("    ");
            text.append(Utils.doubleToString(rocAUC, 7, 3)).append("    ");
            text.append(Utils.doubleToString(prAUC, 7, 3)).append("    ");
            text.append(Utils.doubleToString(mcc(i), 7, 3)).append("    ");
            text.append(m_ClassNames[i]).append('\n');
        }

        String confusionMatrixString;
        try
        {
            confusionMatrixString = toMatrixString();
        }
        catch (Exception e)
        {
            // toMatrixString() declares Exception; report it and omit the matrix.
            System.out.println("Exception: " + e);
            confusionMatrixString = "";
        }
        return text.toString() + "\n" + confusionMatrixString + "\n";
    }

    /**
     * Builds an ROC plot (FPR on x, TPR on y) with one dataset/renderer per curve.
     *
     * @param instList     list of {@link Instances} in makeHeader() layout (e.g. from
     *                     getCurve). Each entry is sorted in place by TPR. A null entry
     *                     is plotted as the single point (0,0).
     * @param aucList      list of {@link Double} AUC values appended (rounded to 2
     *                     places) to the legend labels; may be null
     * @param aucLabelList list of {@link String} legend labels; may be null, in which
     *                     case series get an empty label
     * @param classIndex   when 0, each curve's points are added in reverse order
     * @return the plot, or null if instList is empty
     */
    public static XYPlot getROCPlot(List instList, List aucList, List aucLabelList, int classIndex)
    {
        Color[] cols = PDZSVMUtils.randomColors(instList.size());

        System.out.println("\tPloting ROC curve " + instList.size() + "...");
        XYPlot plot = null;
        for (int i = 0; i < instList.size(); i++)
        {
            // Null-safe legend label. Previously aucList.get(i) was unboxed before
            // the aucList null check, and aucLabelList was never checked.
            String legendLabel = "";
            if (aucLabelList != null)
            {
                legendLabel = (String) aucLabelList.get(i);
                if (aucList != null)
                {
                    double auc = (Double) aucList.get(i);
                    legendLabel = legendLabel + ": " + Utils.roundDouble(auc, 2);
                }
            }
            XYSeries series = new XYSeries(legendLabel);

            Instances inst = (Instances) instList.get(i);
            // Check for null before touching inst (it used to be sorted first,
            // which made this branch unreachable).
            if (inst == null)
            {
                series.add(0, 0);
            }
            else
            {
                int tpInd = inst.attribute(TP_RATE_NAME).index();
                int fpInd = inst.attribute(FP_RATE_NAME).index();
                inst.sort(tpInd);

                double[] tpVals = inst.attributeToDoubleArray(tpInd);
                double[] fpVals = inst.attributeToDoubleArray(fpInd);
                int n = tpVals.length;
                for (int k = 0; k < n; k++)
                {
                    int src = (classIndex == 0) ? n - 1 - k : k;
                    series.add(fpVals[src], tpVals[src]);
                }
            }
            XYSeriesCollection dataset = new XYSeriesCollection(series);

            XYItemRenderer renderer = new StandardXYItemRenderer();
            // Each dataset holds exactly one series (index 0) and has its own
            // renderer. Painting series i left every curve after the first uncoloured.
            renderer.setSeriesPaint(0, cols[i]);
            if (i == 0)
            {
                plot = new XYPlot(dataset, new NumberAxis("FPR"), new NumberAxis("TPR"), renderer);
            }
            else
            {
                plot.setDataset(i, dataset);
                plot.setRenderer(i, renderer);
            }
        }
        return plot;
    }

    /**
     * Builds a precision-recall plot (recall on x, precision on y) with one
     * dataset/renderer per curve.
     *
     * @param instList     list of {@link Instances} in makeHeader() layout (e.g. from
     *                     getCurve). Each entry is sorted in place by precision. A null
     *                     entry is plotted as the single point (0,0).
     * @param aucList      list of {@link Double} AUC values appended (rounded to 2
     *                     places) to the legend labels; may be null
     * @param aucLabelList list of {@link String} legend labels; may be null, in which
     *                     case series get an empty label
     * @param classIndex   when 1, each curve's points are added in reverse order
     * @return the plot, or null if instList is empty
     */
    public static XYPlot getPRPlot(List instList, List aucList, List aucLabelList, int classIndex)
    {
        System.out.println("\tPlotting PR curve" + instList.size() + "...");
        XYPlot plot = null;
        Color[] cols = PDZSVMUtils.randomColors(instList.size());
        for (int i = 0; i < instList.size(); i++)
        {
            // Null-safe legend label. Previously aucList.get(i) was unboxed before
            // the aucList null check.
            String legendLabel = "";
            if (aucLabelList != null)
            {
                legendLabel = (String) aucLabelList.get(i);
                if (aucList != null)
                {
                    double auc = (Double) aucList.get(i);
                    legendLabel = legendLabel + ": " + Utils.roundDouble(auc, 2);
                }
            }
            XYSeries series = new XYSeries(legendLabel);

            Instances inst = (Instances) instList.get(i);
            // Check for null before touching inst (it used to be sorted first,
            // which made this branch unreachable).
            if (inst == null)
            {
                series.add(0, 0);
            }
            else
            {
                int precisionInd = inst.attribute(PRECISION_NAME).index();
                int recallInd = inst.attribute(RECALL_NAME).index();
                inst.sort(precisionInd);

                double[] precisionVals = inst.attributeToDoubleArray(precisionInd);
                double[] recallVals = inst.attributeToDoubleArray(recallInd);
                int n = precisionVals.length;
                for (int k = 0; k < n; k++)
                {
                    int src = (classIndex == 1) ? n - 1 - k : k;
                    series.add(recallVals[src], precisionVals[src]);
                }
            }
            XYSeriesCollection dataset = new XYSeriesCollection(series);

            XYItemRenderer renderer = new StandardXYItemRenderer();
            // Each dataset holds exactly one series (index 0) and has its own
            // renderer. Painting series i left every curve after the first uncoloured.
            renderer.setSeriesPaint(0, cols[i]);
            if (i == 0)
            {
                plot = new XYPlot(dataset, new NumberAxis("RECALL"), new NumberAxis("PRECISION"), renderer);
            }
            else
            {
                plot.setDataset(i, dataset);
                plot.setRenderer(i, renderer);
            }
        }

        return plot;
    }

    /**
     * Renders an index as a right-aligned short ID for the confusion matrix header.
     * Digits are taken from IDChars in a bijective base-IDChars.length scheme,
     * least significant at the right. Unused leading positions are spaces.
     *
     * @param num     non-negative integer to format
     * @param IDChars the characters to use as digits
     * @param IDWidth the width of the returned string
     * @return the formatted integer, exactly IDWidth characters long
     */
    protected String num2ShortID(int num, char[] IDChars, int IDWidth) {
        char[] id = new char[IDWidth];
        Arrays.fill(id, ' ');

        int remaining = num;
        int pos = IDWidth - 1;
        while (pos >= 0) {
            id[pos] = IDChars[remaining % IDChars.length];
            remaining = remaining / IDChars.length - 1;
            if (remaining < 0) {
                break;
            }
            pos--;
        }
        return new String(id);
    }


    /**
     * Outputs the performance statistics as a classification confusion
     * matrix. For each class value, shows the distribution of
     * predicted class values.
     * <p/>
     * Layout: a header row of column IDs generated from IDChars, then one
     * row per row of m_ConfusionMatrix ending in " | " plus m_ClassNames[i]
     * under the "Predicted" heading. Because the constructor stores counts
     * at [1 - predicted][1 - actual], printed rows correspond to the
     * predicted class and columns to the actual class.
     *
     * @return the confusion matrix as a String
     * @throws Exception kept from the Weka code this was copied from;
     *         NOTE(review): no statement in this body appears to throw a
     *         checked exception
     */
    public String toMatrixString() throws Exception {
        String title = "Confusion Matrix:\nActual  ";
        StringBuffer text = new StringBuffer();
        // Column-ID alphabet; 'Y'/'N' label the first two class indices.
        char [] IDChars = {'Y','N','c','d','e','f','g','h','i','j',
                'k','l','m','n','o','p','q','r','s','t',
                'u','v','w','x','y','z'};
        int IDWidth;
        boolean fractional = false;

        // Find the maximum value in the matrix
        // and check for fractional display requirement
        double maxval = 0;
        for(int i = 0; i < m_NumClasses; i++) {
            for(int j = 0; j < m_NumClasses; j++) {
                double current = m_ConfusionMatrix[i][j];
                // Negative entries are scaled up (inherited from Weka), presumably
                // so the width computed below leaves room for a minus sign.
                if (current < 0) {
                    current *= -10;
                }
                if (current > maxval) {
                    maxval = current;
                }
                // Fractional display if the fractional part is >= 0.01
                // (log10(fract) >= -2). NOTE(review): fract is <= 0 when rint
                // rounds up (e.g. 2.7 - 3.0); the log is then NaN/-Infinity and
                // the test is false, so such values do not trigger it.
                double fract = current - Math.rint(current);
                if (!fractional
                        && ((Math.log(fract) / Math.log(10)) >= -2)) {
                    fractional = true;
                }
            }
        }

        // Column width: 1 + the larger of (log10 of the largest entry, +3 for
        // ".xx" when fractional) and the ID length needed for m_NumClasses IDs.
        IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10)
                + (fractional ? 3 : 0)),
                (int)(Math.log(m_NumClasses) /
                        Math.log(IDChars.length)));
        text.append(title).append("\n");
        // Header row of column IDs.
        for(int i = 0; i < m_NumClasses; i++) {
            if (fractional) {
                text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3))
                        .append("   ");
            } else {
                text.append(" ").append(num2ShortID(i,IDChars,IDWidth));
            }
        }
        text.append(" | Predicted\n");
        // One line per matrix row, followed by that row's class name.
        for(int i = 0; i < m_NumClasses; i++) {
            for(int j = 0; j < m_NumClasses; j++) {
                text.append(" ").append(
                        Utils.doubleToString(m_ConfusionMatrix[i][j],
                                IDWidth,
                                (fractional ? 2 : 0)));
            }
            //text.append(" | ").append(num2ShortID(i,IDChars,IDWidth)).append(" = ")
            text.append(" | ").append(m_ClassNames[i]).append("\n");
        }
        return text.toString();
    }

    /**
     * Returns the number of true positives with respect to a particular class,
     * i.e. the count of instances predicted as the class that actually belong
     * to it (the diagonal cell of the confusion matrix).
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return the true positive count
     */
    public double numTruePositives(int classIndex) {
        double[] predictedAsClass = m_ConfusionMatrix[classIndex];
        return predictedAsClass[classIndex];
    }

    /**
     * Returns the true positive rate with respect to a particular class:
     * <pre>
     * correctly classified positives
     * ------------------------------
     *       total positives
     * </pre>
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return TP / (TP + FN), or 0 when there are no positives
     */
    public double truePositiveRate(int classIndex) {
        double tp = numTruePositives(classIndex);
        double actualPositives = tp + numFalseNegatives(classIndex);
        return (actualPositives == 0) ? 0 : tp / actualPositives;
    }

    /**
     * Returns the number of true negatives with respect to a particular class,
     * i.e. the count of instances predicted as the other class that actually
     * belong to it.
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return the true negative count
     */
    public double numTrueNegatives(int classIndex) {
        int other = 1 - classIndex;
        return m_ConfusionMatrix[other][other];
    }

    /**
     * Calculate the true negative rate (specificity) with respect to a particular class.
     * This is defined as<p/>
     * <pre>
     * correctly classified negatives
     * ------------------------------
     *       total negatives
     * </pre>
     * i.e. TN / (TN + FP). When there is at least one negative, this equals
     * {@code 1 - falsePositiveRate(classIndex)}.
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return the true negative rate, or 0 if there are no negatives
     */
    public double trueNegativeRate(int classIndex) {
        // The Weka loop this replaces assumed matrix rows are the actual class.
        // This class stores rows as the predicted class (see numFalsePositives /
        // numFalseNegatives), so that loop returned TN / (TN + FN), the negative
        // predictive value. Using the accessors keeps the definition consistent.
        double tn = numTrueNegatives(classIndex);
        double actualNegatives = tn + numFalsePositives(classIndex);
        if (actualNegatives == 0) {
            return 0;
        }
        return tn / actualNegatives;
    }

    /**
     * Returns the number of false positives with respect to a particular class,
     * i.e. the count of instances predicted as the class that actually belong
     * to the other class.
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return the false positive count
     */
    public double numFalsePositives(int classIndex) {
        double[] predictedAsClass = m_ConfusionMatrix[classIndex];
        return predictedAsClass[1 - classIndex];
    }

    /**
     * Returns the false positive rate with respect to a particular class:
     * <pre>
     * incorrectly classified negatives
     * --------------------------------
     *        total negatives
     * </pre>
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return FP / (TN + FP), or 0 when there are no negatives
     */
    public double falsePositiveRate(int classIndex) {
        double fp = numFalsePositives(classIndex);
        double actualNegatives = numTrueNegatives(classIndex) + fp;
        if (actualNegatives == 0) {
            return 0;
        }
        return fp / actualNegatives;
    }

    /**
     * Returns the number of false negatives with respect to a particular class,
     * i.e. the count of instances that actually belong to the class but were
     * predicted as the other class.
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return the false negative count
     */
    public double numFalseNegatives(int classIndex) {
        int other = 1 - classIndex;
        return m_ConfusionMatrix[other][classIndex];
    }

    /**
     * Returns the false negative rate with respect to a particular class:
     * <pre>
     * incorrectly classified positives
     * --------------------------------
     *        total positives
     * </pre>
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return FN / (TP + FN), or 0 when there are no positives
     */
    public double falseNegativeRate(int classIndex) {
        double fn = numFalseNegatives(classIndex);
        double actualPositives = numTruePositives(classIndex) + fn;
        return (actualPositives == 0) ? 0 : fn / actualPositives;
    }

    /**
     * Returns the recall with respect to a particular class, which is the same
     * quantity as truePositiveRate(classIndex):
     * <pre>
     * correctly classified positives
     * ------------------------------
     *       total positives
     * </pre>
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return the recall
     */
    public double recall(int classIndex) {
        final double tpr = truePositiveRate(classIndex);
        return tpr;
    }

    /**
     * Returns the precision with respect to a particular class:
     * <pre>
     * correctly classified positives
     * ------------------------------
     *  total predicted as positive
     * </pre>
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return TP / (TP + FP), or 0 when nothing was predicted as the class
     */
    public double precision(int classIndex) {
        double tp = numTruePositives(classIndex);
        double predictedPositive = tp + numFalsePositives(classIndex);
        return (predictedPositive == 0) ? 0 : tp / predictedPositive;
    }

    /**
     * Calculate the F-Measure with respect to a particular class.
     * This is defined as<p/>
     * <pre>
     * 2 * recall * precision
     * ----------------------
     *   recall + precision
     * </pre>
     * Returns 0 when precision and recall are both 0.
     *
     * @param classIndex the index of the class to consider as "positive"
     * @return the F-Measure
     */
    public double fMeasure(int classIndex) {
        final double p = precision(classIndex);
        final double r = recall(classIndex);
        final double sum = p + r;
        return (sum == 0) ? 0 : (2 * p * r) / sum;
    }

    /**
     * Calculate an overall F-Measure combining both classes.
     * This is the harmonic mean of the two per-class F-Measures:<p/>
     * <pre>
     * 2 * fMeasure(0) * fMeasure(1)
     * -----------------------------
     *   fMeasure(0) + fMeasure(1)
     * </pre>
     * Returns 0 when both per-class F-Measures are 0.
     *
     * @return the combined F-Measure
     */
    public double fMeasure()
    {
        // Evaluate each per-class score exactly once; each call recomputes
        // precision and recall, so repeating it in the guard and the formula
        // is wasted work.
        final double f0 = fMeasure(0);
        final double f1 = fMeasure(1);
        final double denom = f0 + f1;
        if (denom == 0) {
            return 0;
        }
        return (2 * f0 * f1) / denom;
    }

   
    /**
     * Demo entry point.
     * <ol>
     * <li>Builds a fixed set of 20 sample predictions and prints them.</li>
     * <li>Prints the PR AUC and ROC AUC and the evaluation summary.</li>
     * <li>Shows the ROC and PR curves for class index 1 in a Swing window.</li>
     * </ol>
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        List<Prediction> predictions = buildSamplePredictions();

        Evaluation eval = new Evaluation(predictions);
        System.out.println("Predictions:");
        for (Prediction pred : predictions)
        {
            System.out.println(pred.getPrediction() + "\t" + pred.getActual() + "\t" + pred.getDecValue());
        }
        int classIndex = 1;

        Instances inst = eval.getCurve(classIndex);
        List<Instances> predList = new ArrayList<Instances>();
        predList.add(inst);

        List<String> aucLabelList = new ArrayList<String>();
        aucLabelList.add("Test");

        double rocAUC = eval.getROCAUC();
        List<Double> aucList = new ArrayList<Double>();
        aucList.add(rocAUC);

        double prAUC = eval.getPRAUC();
        System.out.println("PR AUC: " + prAUC);
        List<Double> prAucList = new ArrayList<Double>();
        prAucList.add(prAUC);

        System.out.println(eval.toString());
        final XYPlot rocPlot = Evaluation.getROCPlot(predList, aucList, aucLabelList, classIndex);
        final XYPlot prPlot = Evaluation.getPRPlot(predList, prAucList, aucLabelList, classIndex);

        System.out.println("ROC AUC: " + rocAUC);

        // Swing components must be created and shown on the Event Dispatch Thread.
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                showCharts(rocPlot, prPlot);
            }
        });
    }

    /**
     * Builds the fixed list of 20 demo predictions, ordered by descending decision value.
     */
    private static List<Prediction> buildSamplePredictions()
    {
        // Second constructor argument for each sample, index-aligned with decValues.
        // NOTE(review): presumably the actual class label (1 = positive, 0 = negative);
        // confirm against Prediction's constructor.
        int[] labels = {1, 1, 0, 1, 1, 1, 0, 0, 1, 0,
                        1, 0, 1, 0, 0, 0, 1, 0, 1, 0};
        // Decision values, index-aligned with labels.
        double[] decValues = {0.9, 0.8, 0.7, 0.6, 0.55, 0.54, 0.53, 0.52, 0.51, 0.505,
                              0.4, 0.39, 0.38, 0.37, 0.36, 0.35, 0.34, 0.33, 0.3, 0.1};

        List<Prediction> predictions = new ArrayList<Prediction>(labels.length);
        for (int i = 0; i < labels.length; i++)
        {
            // The first constructor argument is 1.0 for every sample.
            predictions.add(new Prediction(1.0, labels[i], decValues[i]));
        }
        return predictions;
    }

    /**
     * Shows the ROC and PR plots side by side in a frame that exits the JVM when closed.
     * Must be called on the Event Dispatch Thread.
     */
    private static void showCharts(XYPlot rocPlot, XYPlot prPlot)
    {
        JFreeChart rocChart = new JFreeChart("ROC", JFreeChart.DEFAULT_TITLE_FONT, rocPlot, true);
        ChartPanel rocChartPanel = new ChartPanel(rocChart, true, true, true, true, true);
        rocChartPanel.setPreferredSize(new java.awt.Dimension(250, 250));

        JFreeChart prChart = new JFreeChart("PR", JFreeChart.DEFAULT_TITLE_FONT, prPlot, true);
        ChartPanel prChartPanel = new ChartPanel(prChart, true, true, true, true, true);
        prChartPanel.setPreferredSize(new java.awt.Dimension(250, 250));

        JFrame frame = new JFrame("");
        frame.setLayout(new FlowLayout());
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.add(rocChartPanel);
        frame.add(prChartPanel);
        frame.pack();
        frame.setVisible(true);
    }

}