package org.baderlab.pdzsvmstruct.data.manager;


import java.util.*;
import java.util.List;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;

import org.baderlab.pdzsvmstruct.utils.Constants;
import org.baderlab.brain.*;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.ProteinTools;
import org.biojava.bio.seq.db.HashSequenceDB;
import org.baderlab.pdzsvmstruct.predictor.pwm.PWM;
import org.baderlab.pdzsvmstruct.utils.PDZSVMUtils;
import org.baderlab.pdzsvmstruct.utils.BindingSiteUtils;
import weka.core.Utils;

/**
 * Copyright (c) 2011 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVMStruct.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVMStruct.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Manager for the generation of artificial negatives: PWM-based negatives
 * selected from a shared peptide sequence pool, and proteome-scan negatives
 * loaded from predicted-negatives files.
 */
public class ArtificialNegativesDataManager
{
    /** Only every PREDICTED_NEG_STRIDE-th peptide of a predicted-negatives file is sampled. */
    private static final int PREDICTED_NEG_STRIDE = 10;

    /** Maximum number of proteome-scan negatives added per domain. */
    private static final int MAX_PREDICTED_NEGS = 10;

    /**
     * Mouse domains whose predicted-negatives file name carries a "-M" suffix.
     * NOTE(review): presumably this disambiguates them from same-named human files — confirm.
     */
    private static final List MOUSE_SUFFIXED_DOMAINS = Arrays.asList("LRRC7-1", "PDZK1-1", "SHANK3-1");

    /** Whether PWM negatives are trimmed toward the number of positives per profile. */
    private boolean balance = false;

    /**
     * Creates the manager. Only logs an initialization message.
     */
    public ArtificialNegativesDataManager()
    {
        System.out.println("\n\tInitializing artificial negative data manager...\n");
    }

    /**
     * Enables or disables balancing of generated PWM negatives.
     *
     * @param balanceNegs if true, each profile's PWM negatives are reduced to the
     *                    lowest-scoring ones, roughly matching its number of positives
     *                    (only when at least that many negatives are available)
     */
    public void balance(boolean balanceNegs)
    {
        balance = balanceNegs;
    }

    /**
     * Selects PWM-based artificial negative peptides for one profile.
     *
     * <p>Candidates come from the sequence pool built from {@code refProfileList}, visited in
     * descending PWM-score order. A candidate is kept when both of these hold:
     * <ul>
     *   <li>it scores strictly below the lowest PWM score of the profile's own peptides;</li>
     *   <li>{@code SequencePoolManager.isLike} reports it is not too similar to the candidates
     *       already kept.</li>
     * </ul>
     *
     * <p>When balancing is enabled and there are at least as many candidates as positives,
     * whole equal-score groups are taken in ascending score order. This stops once at least
     * as many negatives as positives are collected; ties can overshoot that count.
     *
     * @param refProfileList profiles used to build the sequence pool
     * @param profile        profile whose PWM is used for scoring
     * @param numPeptideSim  similarity threshold passed to {@code SequencePoolManager.isLike}
     * @return the selected negative peptides, as strings
     */
    private List getNegatives(List refProfileList, ProteinProfile profile, int numPeptideSim)
    {
        SequencePoolManager sm = new SequencePoolManager(refProfileList);
        List sequencePoolList = sm.getSequencePool();

        PWM pwm = new PWM(profile);
        List sortedSequencePoolList = sm.sortSequencePool(SequencePoolManager.DESC, pwm);

        // The cutoff is the lowest PWM score among the profile's own (positive) peptides.
        List seqList = new ArrayList(profile.getSequenceMap());
        double cutoff = Double.MAX_VALUE;
        for (int ii = 0; ii < seqList.size(); ii++)
        {
            Sequence seq = (Sequence) seqList.get(ii);
            double score = pwm.score(seq.seqString());
            if (score < cutoff)
                cutoff = score;
        }

        // Keep pool peptides below the cutoff that are not too similar (per isLike) to those
        // already kept. Group them by score; TreeMap keeps the scores in ascending order.
        List reducedSequencePoolList = new ArrayList();
        TreeMap scoreToSeqMap = new TreeMap();
        for (int ii = 0; ii < sortedSequencePoolList.size(); ii++)
        {
            String seq = (String) sortedSequencePoolList.get(ii);
            if (SequencePoolManager.isLike(reducedSequencePoolList, seq, numPeptideSim))
                continue;
            double score = pwm.score(seq);
            if (score < cutoff)
            {
                reducedSequencePoolList.add(seq);
                Double key = Double.valueOf(score);
                List scoreSeqList = (List) scoreToSeqMap.get(key);
                if (scoreSeqList == null)
                {
                    scoreSeqList = new ArrayList();
                    scoreToSeqMap.put(key, scoreSeqList);
                }
                scoreSeqList.add(seq);
            }
        }

        if (balance)
        {
            int numToGet = seqList.size();
            System.out.println("\n\tBalancing negatives = " +numToGet+ "+," + reducedSequencePoolList.size() + "-...");
            if (numToGet > reducedSequencePoolList.size())
            {
                System.out.println("\tFewer negatives than positives. Skipping...");
            }
            else
            {
                // Take whole score groups, lowest score first, until numToGet is reached.
                List balancedList = new ArrayList();
                Iterator groups = scoreToSeqMap.values().iterator();
                while (balancedList.size() < numToGet && groups.hasNext())
                {
                    balancedList.addAll((List) groups.next());
                }
                reducedSequencePoolList = balancedList;
            }
        }

        System.out.println("\t=== " +profile.getName() +" (" +reducedSequencePoolList.size() +" of "+sequencePoolList.size()+ ") ===");

        return reducedSequencePoolList;
    }

    /**
     * Adds proteome-scan negatives to each profile whose name appears in {@code domainNameList}.
     *
     * <p>For each such profile, the domain's predicted-negatives file is read. The peptide is
     * taken from the second tab-separated column of each line. Every
     * {@value #PREDICTED_NEG_STRIDE}-th peptide (indices 0, 10, 20, ...), up to
     * {@value #MAX_PREDICTED_NEGS} in total, is put into the profile's sequence hash map under
     * the key {@code domainName + index}.
     *
     * <p>Errors for one domain are printed and that domain's remaining peptides are skipped.
     * Processing then continues with the next profile.
     *
     * @param negProfileList profiles to extend (their sequence maps are mutated)
     * @param domainNameList names of the domains to process
     */
    public void addProteomeScanNegatives(List negProfileList, List domainNameList)
    {
        System.out.println("Adding negatives...");
        for (int i = 0; i < negProfileList.size(); i++)
        {
            ProteinProfile negProfile = (ProteinProfile) negProfileList.get(i);
            String domainName = negProfile.getName();
            if (!domainNameList.contains(domainName))
                continue;

            Map negSeqMap = negProfile.getSequenceHashMap();
            int before = negSeqMap.size();
            String organism = PDZSVMUtils.organismLongToShortForm(negProfile.getOrganism());
            try
            {
                List peptideList = readPeptides(getPredictedNegativesFilename(domainName, organism));

                int count = 0;
                for (int j = 0; j < peptideList.size() && count < MAX_PREDICTED_NEGS; j += PREDICTED_NEG_STRIDE)
                {
                    String peptide = (String) peptideList.get(j);
                    Sequence prot = ProteinTools.createProteinSequence(peptide, domainName + j);
                    negSeqMap.put(domainName + j, prot);
                    count = count + 1;
                }
            }
            catch (Exception e)
            {
                System.out.println("Exception: " + e);
                e.printStackTrace();
            }
            System.out.println(domainName + ": Adding negatives, " + before +"->" + negSeqMap.size());
        }
    }

    /**
     * Builds the path of a domain's predicted-negatives file.
     *
     * <p>The organism only selects the "-M" suffix for the mouse domains listed in
     * {@link #MOUSE_SUFFIXED_DOMAINS}. It is not part of the directory path.
     *
     * @param domainName domain name used as the file name stem
     * @param organism   short organism name (compared against {@code Constants.MOUSE})
     * @return the file path
     */
    private static String getPredictedNegativesFilename(String domainName, String organism)
    {
        boolean mouseSuffixed = organism.equals(Constants.MOUSE) && MOUSE_SUFFIXED_DOMAINS.contains(domainName);
        String suffix = mouseSuffixed ? "-M.neg.predictions.txt" : ".neg.predictions.txt";
        return DataFileManager.DATA_ROOT_DIR + "/Data/PredictedNegatives/" + domainName + suffix;
    }

    /**
     * Reads the peptides (second tab-separated column) from a predictions file.
     *
     * <p>Lines with fewer than two columns, such as a trailing blank line, are skipped rather
     * than aborting the whole file. The reader is always closed.
     *
     * @param filename path of the predictions file
     * @return the peptides, in file order
     * @throws IOException if the file cannot be opened or read
     */
    private static List readPeptides(String filename) throws IOException
    {
        List peptideList = new ArrayList();
        BufferedReader br = new BufferedReader(new FileReader(new File(filename)));
        try
        {
            String line;
            while ((line = br.readLine()) != null)
            {
                String[] splitLine = line.split("\t");
                if (splitLine.length < 2)
                    continue;
                peptideList.add(splitLine[1]);
            }
        }
        finally
        {
            br.close();
        }
        return peptideList;
    }

    /**
     * Generates PWM negatives using {@code Constants.NUM_RED_PEPTIDES} as the similarity threshold.
     *
     * @param profileList positive profiles
     * @return one negative profile per positive profile for which one could be built
     */
    public List getPWMNegatives(List profileList)
    {
        return getPWMNegatives(profileList, Constants.NUM_RED_PEPTIDES);
    }

    /**
     * Generates a PWM-based negative profile for every profile in {@code profileList}.
     *
     * <p>Each profile's negatives are chosen from the sequence pool of all profiles (see
     * {@link #getNegatives}). They are wrapped as protein sequences named
     * {@code domainName-index} and turned into a profile via {@code PDZSVMUtils.makeProfile}.
     * Profiles for which {@code makeProfile} returns null are omitted. Score statistics are
     * printed at the end.
     *
     * @param profileList    positive profiles
     * @param numRedPeptides similarity threshold passed to {@code SequencePoolManager.isLike}
     * @return the generated negative profiles
     */
    public List getPWMNegatives(List profileList, int numRedPeptides)
    {
        System.out.println("\tGetting PWM negatives...");
        System.out.println("\tNum similar peptides: " + numRedPeptides);

        List artNegProfileList = new ArrayList();
        int totSeq = 0;
        for (int i = 0; i < profileList.size(); i++)
        {
            ProteinProfile trainProfile = (ProteinProfile) profileList.get(i);
            List artNegSequences = getNegatives(profileList, trainProfile, numRedPeptides);

            HashSequenceDB seqDB = new HashSequenceDB();
            String domainName = trainProfile.getName();
            for (int j = 0; j < artNegSequences.size(); j++)
            {
                String seq = (String) artNegSequences.get(j);
                try
                {
                    Sequence prot = ProteinTools.createProteinSequence(seq, domainName + "-" + j);
                    seqDB.addSequence(prot);
                    totSeq = totSeq + 1;
                }
                catch (Exception e)
                {
                    System.out.println("Exception: " + e);
                }
            }
            ProteinProfile artNegProfile = PDZSVMUtils.makeProfile(trainProfile, seqDB);
            if (artNegProfile != null)
                artNegProfileList.add(artNegProfile);
        }
        System.out.println("\tTotal number sequence generated: " + totSeq);

        System.out.println("\tPWM neg stats:");
        getStats(profileList, artNegProfileList);
        System.out.println("\t#Pos profiles: " + profileList.size());
        System.out.println("\t#Neg profiles: " + artNegProfileList.size());
        System.out.println();
        return artNegProfileList;
    }

    /**
     * Prints the min, max and mean PWM score of the negative peptides.
     *
     * <p>Each negative profile is scored with the PWM of the positive profile of the same name.
     * Negative profiles with no matching positive profile are ignored.
     *
     * @param trainProfileList  positive profiles
     * @param artNegProfileList negative profiles to score
     */
    private void getStats(List trainProfileList, List artNegProfileList)
    {
        HashMap trainProfileMap = PDZSVMUtils.profileListToHashMap(trainProfileList);
        double sum = 0.0;
        // Double.MIN_VALUE is the smallest *positive* double, not the most negative one,
        // so it cannot seed a running maximum over scores that may be negative.
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        int total = 0;
        for (int i = 0; i < artNegProfileList.size(); i++)
        {
            ProteinProfile profile = (ProteinProfile) artNegProfileList.get(i);
            ProteinProfile trainProfile = (ProteinProfile) trainProfileMap.get(profile.getName());
            if (trainProfile == null)
                continue;
            PWM pwm = new PWM(trainProfile);
            Collection seqCollection = profile.getSequenceMap();
            Iterator it = seqCollection.iterator();
            while (it.hasNext())
            {
                Sequence sequence = (Sequence) it.next();
                double score = pwm.score(sequence.seqString());
                sum = sum + score;
                if (score < min)
                    min = score;
                if (score > max)
                    max = score;
            }
            total = total + seqCollection.size();
        }
        if (total == 0)
        {
            // Avoid printing sentinel values and a 0/0 mean.
            System.out.println("\tMin, max, mean: [no sequences]");
            return;
        }
        System.out.println("\tMin, max, mean: [" +Utils.doubleToString(min,7,3) + "," +Utils.doubleToString(max,7,3) + "," + Utils.doubleToString(sum/total,7,3) + "]");
    }

} // end class
