package org.baderlab.pdzsvmstruct.data;

import org.baderlab.pdzsvmstruct.data.manager.ArtificialNegativesDataManager;
import org.baderlab.pdzsvmstruct.data.manager.SequencePoolManager;
import org.baderlab.pdzsvmstruct.utils.BindingSiteUtils;
import org.baderlab.brain.ProteinProfile;
import java.util.*;

import org.baderlab.pdzsvmstruct.utils.Constants;
import org.baderlab.pdzsvmstruct.utils.PDZSVMUtils;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.ProteinTools;
import org.biojava.bio.seq.db.HashSequenceDB;

/**
 * Copyright (c) 2011 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVMStruct.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVMStruct.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Loads the data
 */
public class DataLoader {
    private List posTrainProfileList = null;
    private List negTrainProfileList= null;
    private List posTestProfileList= null;
    private List negTestProfileList= null;
    private DataRepository dr = DataRepository.getInstance();
    private ArtificialNegativesDataManager am;


    public List getPosTrainProfileList()
    {
        return posTrainProfileList;
    }
    public List getNegTrainProfileList()
    {
        return negTrainProfileList;
    }
    public List getPosTestProfileList()
    {
        return posTestProfileList;
    }
    public List getNegTestProfileList()
    {
        return negTestProfileList;
    }
    public static List filterOutNonGenomicInteractions(List profileList, String organism)
    {
        SequencePoolManager s = new SequencePoolManager(organism);
        List seqPoolList = s.getSequencePool(4);
        List filteredProfileList = new ArrayList();
        int totalNonGenomicRemoved = 0;
        for (int i=0; i < profileList.size();i++)
        {
            ProteinProfile profile = (ProteinProfile) profileList.get(i);
            //System.out.println("\tFiltering " + profile.getName());
            Collection seqCollection = profile.getSequenceMap();
            List seqList = new ArrayList(seqCollection);
            List newSeqList = new ArrayList();
            HashSequenceDB seqDB = new HashSequenceDB();
            for (int j = 0; j < seqList.size();j++)
            {
                Sequence seq = (Sequence)seqList.get(j);
                String seqString = seq.seqString();
                String addSeqString = seqString.substring(1,seqString.length());
                if (seqPoolList.contains(addSeqString))
                {
                    try
                    {
                        Sequence addSeq = ProteinTools.createProteinSequence(seqString,seqString+j);
                        seqDB.addSequence(addSeq);
                        newSeqList.add(addSeqString);
                    }
                    catch(Exception e)
                    {
                        System.out.println("Exception: " + e);
                    }
                }

            }
            int numFiltered =  seqList.size() - newSeqList.size();
            if (numFiltered>0)
                System.out.println("\t"+profile.getName() + ": Filtered " + numFiltered + " sequences...");
            totalNonGenomicRemoved = totalNonGenomicRemoved+ numFiltered;

            ProteinProfile filteredProfile = PDZSVMUtils.makeProfile(
                    profile.getName(),
                    profile.getDomainNumber(),
                    profile.getDomainSequence(),
                    profile.getOrganism(),
                    seqDB);
            profile.setExperimentalMethod(Constants.PHAGE_DISPLAY_HIGH_STRING);

            filteredProfileList.add(filteredProfile);
        }
        //System.out.println("\tTotal sequences with X removed: " + totalXRemoved);
        return filteredProfileList;
    }
    public static List filterOutXInteractions(List profileList)
    {
        List filteredProfileList = new ArrayList();
        int totalXRemoved = 0;
        for (int i=0; i < profileList.size();i++)
        {
            ProteinProfile profile = (ProteinProfile) profileList.get(i);
            Collection seqCollection = profile.getSequenceMap();
            List seqList = new ArrayList(seqCollection);
            List newSeqList = new ArrayList();
            int numRetained =0;
            HashSequenceDB seqDB = new HashSequenceDB();

            for (int j = 0; j < seqList.size();j++)
            {
                Sequence seq = (Sequence) seqList.get(j);
                String seqString = seq.seqString();

                //if (seqString.charAt(0)!='X')
                if (seqString.indexOf("X")==-1)
                {
                    try
                    {
                        Sequence addSeq = ProteinTools.createProteinSequence(seqString,seqString);
                        seqDB.addSequence(addSeq);
                        numRetained = numRetained +1;
                        newSeqList.add(seq);
                    }
                    catch(Exception e)
                    {
                        System.out.println("Exception: " + e);
                    }
                }
            }
            int numFiltered =  seqList.size() - newSeqList.size();
            if (numFiltered>0)
                System.out.println("\t"+profile.getName() + ": Filtered " + numFiltered + " sequences...");
            totalXRemoved = totalXRemoved+ numFiltered;

            try
            {
                profile.setSequenceMap(newSeqList);
            }
            catch(Exception e)
            {
                System.out.println("Exception: " + e);
            }
            filteredProfileList.add(profile);
        }
        //System.out.println("\tTotal sequences with X removed: " + totalXRemoved);
        return filteredProfileList;
    }

    public int getNumPosTrainInteractions()
    {
        int num= 0;
        if (posTrainProfileList == null)
            return num;
        for (int i=0;i < posTrainProfileList.size();i++)
        {
            ProteinProfile profile = (ProteinProfile)posTrainProfileList.get(i);
            num = num + profile.getNumSequences();
        }
        return num;
    }
    public int getNumPosTestInteractions()
    {
        int num= 0;
        if (posTestProfileList == null)
            return num;
        for (int i=0;i < posTestProfileList.size();i++)
        {
            ProteinProfile profile = (ProteinProfile)posTestProfileList.get(i);
            num = num + profile.getNumSequences();
        }
        return num;
    }

    public int getNumNegTestInteractions()
    {
        int num= 0;
        if (negTestProfileList == null)
            return num;
        for (int i=0;i < negTestProfileList.size();i++)
        {
            ProteinProfile profile = (ProteinProfile)negTestProfileList.get(i);
            num = num + profile.getNumSequences();
        }
        return num;
    }
    public int getNumNegTrainInteractions()
    {
        int num= 0;
        if (negTrainProfileList == null)
            return num;
        for (int i=0;i < negTrainProfileList.size();i++)
        {
            ProteinProfile profile = (ProteinProfile)negTrainProfileList.get(i);
            num = num + profile.getNumSequences();
        }
        return num;
    }
    public DataLoader()
    {
        am = new ArtificialNegativesDataManager();
    }

    public static void main(String[] args)
    {
        DataLoader dl = new DataLoader();
        dl.loadMousePDBTrain();
        dl.loadSidhuHumanPDBTrain(Constants.SIDHU_HUMAN_G_PDB, Constants.PHAGE_DISPLAY);
        dl.loadWormPDBTest(Constants.CHEN_WORM_PDB);
        
        List posTestProfileList = dl.getPosTestProfileList();
        if (posTestProfileList == null) posTestProfileList = new ArrayList();

        List negTestProfileList = dl.getNegTestProfileList();
        if (negTestProfileList == null) negTestProfileList = new ArrayList();

        List posTrainProfileList = dl.getPosTrainProfileList();
        if (posTrainProfileList == null) posTrainProfileList = new ArrayList();


        List negTrainProfileList = dl.getNegTrainProfileList();
        if (negTrainProfileList == null) negTrainProfileList = new ArrayList();

        Data data = new Data();
        data.addRawData(posTrainProfileList,Constants.CLASS_YES);
        data.addRawData(negTrainProfileList,Constants.CLASS_NO);
        HashMap peptideMap = data.getPeptideNumToRawMap();
        data.printSummary();

        Data testdata = new Data();
        testdata.addRawData(posTestProfileList,Constants.CLASS_YES);
        testdata.addRawData(negTestProfileList,Constants.CLASS_NO);
        HashMap testpeptideMap = testdata.getPeptideNumToRawMap();

        System.out.println("TRAIN DOMAINS");
        System.out.println("Num Pos Domains: " + posTrainProfileList.size());
        System.out.println("Num Neg Domains: " + negTrainProfileList.size());
        System.out.println("TRAIN PEPTIDES");

        System.out.println("Num total: " + peptideMap.size());

        System.out.println("TRAIN INTERACTIONS");
        System.out.println("Num Pos Interactions:" + dl.getNumPosTrainInteractions());
        System.out.println("Num Neg Interactions:" + dl.getNumNegTrainInteractions());
        System.out.println("TEST DOMAINS");
        System.out.println("Num Pos Domains: " + posTestProfileList.size());
        System.out.println("Num Neg Domains: " + negTestProfileList.size());
        System.out.println("TEST PEPTIDES");

        System.out.println("Num total: " + testpeptideMap.size());
        System.out.println("TEST INTERACTIONS");
        System.out.println("Num Pos Interactions:" + dl.getNumPosTestInteractions());
        System.out.println("Num Neg Interactions:" + dl.getNumNegTestInteractions());

    }


    public void clearTrain()
    {
        posTrainProfileList = new ArrayList();
        negTrainProfileList = new ArrayList();
    }
    public void clearTest()
    {
        posTestProfileList = new ArrayList();
        negTestProfileList = new ArrayList();

    }
    public void clearAll()
    {
        clearTrain();
        clearTest();
    }

    public void loadMousePDBTrain()
    {
        System.out.println("\n\tLoading mouse (CHEN) protein microarray training data...");
        if (posTrainProfileList == null ||
                posTrainProfileList.isEmpty())
            posTrainProfileList = new ArrayList();
        // Randomly remove positives
        List addPosList = dr.mousePosPMPDBList;
        posTrainProfileList.addAll(addPosList);

        if (negTrainProfileList == null ||
                negTrainProfileList.isEmpty())
            negTrainProfileList = new ArrayList();

        List negToAddList = dr.mouseNegPMPDBList;
        List domainNamesList = new ArrayList();
        domainNamesList.add("A1-SYNTROPHIN-1");
        domainNamesList.add("CHAPSYN-110-2");
        domainNamesList.add("CIPP-3");
        domainNamesList.add("CIPP-5");
        domainNamesList.add("CIPP-8");
        domainNamesList.add("DVL1-1");
        domainNamesList.add("GRIP1-6");
        domainNamesList.add("MAGI-3-1");
        domainNamesList.add("MAGI-3-5");
        domainNamesList.add("MALS2-1");
        domainNamesList.add("MUPP1-5");
        domainNamesList.add("NHERF-2-2");
        domainNamesList.add("OMP25-1");
        domainNamesList.add("PDZK1-1");
        domainNamesList.add("PDZK1-3");
        domainNamesList.add("PSD95-2");
        domainNamesList.add("SAP102-2");
        domainNamesList.add("SCRB1-3");
        domainNamesList.add("SHANK1-1");
        domainNamesList.add("SHANK3-1");
        domainNamesList.add("ZO-1-1");
        
        am.addProteomeScanNegatives(negToAddList, domainNamesList);
        negTrainProfileList.addAll(negToAddList);

    }

    public void loadSidhuHumanPDBTrain(String organism, String method)
    {
        if (posTrainProfileList == null ||
                posTrainProfileList.isEmpty())
            posTrainProfileList = new ArrayList();
        List posToAddList = new ArrayList();
        if (method.equals(Constants.PHAGE_DISPLAY))
        {
            if (organism.equals(Constants.SIDHU_HUMAN_G_PDB))
                posToAddList = dr.humanPosSidhuGenomicPDBList;

        }

        if (posToAddList.size()>0)
            //CLASS 2            
            posToAddList = removeSmallProfiles(posToAddList,3);//, Constants.MIN_NUM_PEPTIDES);

        List negToAddList = new ArrayList();
        if (negTrainProfileList == null ||
                negTrainProfileList.isEmpty())
            negTrainProfileList = new ArrayList();
        am.balance(true);
        negToAddList = am.getPWMNegatives(posToAddList, 3);
        List domainNamesList = new ArrayList();
        domainNamesList.add("DLG2-3");
        domainNamesList.add("DLG1-1");
        domainNamesList.add("DLG1-2");
        domainNamesList.add("DLG1-3");
        domainNamesList.add("DLG3-2");
        domainNamesList.add("DLG4-3");
        domainNamesList.add("DVL2-1");
        domainNamesList.add("INADL-2");
        domainNamesList.add("LRRC7-1");
        domainNamesList.add("MAGI1-5");
        domainNamesList.add("MAGI3-5");
        domainNamesList.add("MPDZ-1");
        domainNamesList.add("MPDZ-3");
        domainNamesList.add("MPDZ-10");
        domainNamesList.add("MPP6-1");
        domainNamesList.add("PDLIM4-1");
        domainNamesList.add("PDZK1-1");
        domainNamesList.add("PSCDBP-1");
        domainNamesList.add("SLC9A3R2-2");
        domainNamesList.add("SNTA1-1");

        am.addProteomeScanNegatives(negToAddList, domainNamesList);

        negTrainProfileList.addAll(negToAddList);
        posTrainProfileList.addAll(posToAddList);
        
    }


    public static List removeSmallProfiles(List profileList, int size)
    {
        List newProfileList = new ArrayList();
        for (int i =0; i < profileList.size();i++)
        {
            ProteinProfile profile= (ProteinProfile)profileList.get(i);
            if (profile.getNumSequences() >= size)// || profile.getName().equals("CASK-1"))
            {
                newProfileList.add(profile);
            }
            else
            {
                System.out.println("\tRemoving profile: " + profile.getName() + "\t" +  profile.getNumSequences());
            }

        }
        return newProfileList;
    }

    public void loadMousePDBTest(String type)
    {
        if (posTestProfileList == null ||
                posTestProfileList.isEmpty())
            posTestProfileList = new ArrayList();

        if (negTestProfileList == null ||
                negTestProfileList.isEmpty())
            negTestProfileList = new ArrayList();
        List addPosProfileList= new ArrayList();
        List addNegProfileList= new ArrayList();

        if (type.equals(Constants.CHEN_MOUSE_ORPHAN_PDB))
        {
            addPosProfileList = dr.mousePosPMOrphanPDBList;
            addNegProfileList = dr.mouseNegPMOrphanPDBList;
        }
        else if (type.equals(Constants.CHEN_MOUSE_PDB))
        {
            addPosProfileList = dr.mousePosPMPDBList;
            addNegProfileList = dr.mouseNegPMPDBList;
        }
        else
        {
            System.out.println("\tDid not load mouse test.  No data for " + type + "...");
        }

        posTestProfileList.addAll(addPosProfileList);
        negTestProfileList.addAll(addNegProfileList);

    }
    public void loadFlyPDBTest()
    {
        posTestProfileList = dr.flyPosPMPDBList;
        negTestProfileList = dr.flyNegPMPDBList;

    }
    public void loadWormPDBTest(String type)
    {
        if (type.equals(Constants.CHEN_WORM_PDB))
        {

            posTestProfileList = dr.wormPosPMPDBList;
            negTestProfileList = dr.wormNegPMPDBList;
        }

    }
}
