package org.baderlab.brain.sequencelogo;

import java.util.Iterator;
import java.util.List;

import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.DistributionFactory;
import org.biojava.bio.dist.DistributionTrainerContext;
import org.biojava.bio.symbol.Alignment;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.Symbol;
import org.biojava.bio.symbol.SymbolList;

/**
 * Static helpers that build per-column {@link Distribution}s over an
 * {@link Alignment}, for use when drawing sequence logos.
 */
public class SequenceLogoDistributionTools {

  /**
   * Creates an array of distributions, one for each column of the alignment.
   *
   * <p>Element {@code i} of the returned array describes alignment column
   * {@code i + 1} (alignment columns are addressed 1-based). Every column's
   * distribution is created with {@link DistributionFactory#DEFAULT} over the
   * alphabet shared by all sequences. Each non-null symbol in the column is
   * counted once through a {@link SequenceLogoDistributionTrainerContext}, and
   * all distributions are trained together after every column has been counted.
   *
   * <p>When {@code countGaps} is true, gap symbols are <em>not</em> added to the
   * distribution. Instead, after training, every weight in a column is
   * multiplied by the fraction of that column's (non-null) symbols that are not
   * gaps, so gappy columns end up with proportionally smaller weights. That
   * fraction is computed from the raw symbol counts, without pseudo counts. A
   * column that no sequence covers is left unscaled. When {@code countGaps} is
   * false, gap symbols are handed to the trainer like any other symbol.
   *
   * @param a the {@code Alignment} to build the {@code Distribution[]} over
   * @param countGaps if true, scale each column's weights by the fraction of
   *     non-gap symbols in that column
   * @param nullWeight the pseudo-count weight passed to
   *     {@link DistributionTrainerContext#setNullModelWeight(double)}
   * @return a {@code Distribution[]} where each member of the array is a
   *     {@code Distribution} of the {@code Symbol}s found at that position of
   *     the {@code Alignment}
   * @throws IllegalAlphabetException if the sequences do not all use the same
   *     alphabet instance, or if a distribution cannot be created over it
   * @throws IllegalArgumentException if the alignment contains no sequences
   * @throws IllegalStateException wrapping any other checked exception raised
   *     while counting, training or rescaling the distributions
   * @since 1.2
   */
  public static final Distribution[] distOverAlignment(Alignment a,
                                                       boolean countGaps,
                                                       double nullWeight)
      throws IllegalAlphabetException {

    List seqs = a.getLabels();
    if (seqs.isEmpty()) {
      throw new IllegalArgumentException(
          "Cannot calculate distOverAlignment() for an alignment with no sequences");
    }

    FiniteAlphabet alpha = sharedAlphabet(a, seqs);
    // The gap symbol is the same for every column and sequence; look it up once.
    Symbol gap = alpha.getGapSymbol();

    int columns = a.length();
    Distribution[] pos = new Distribution[columns];
    DistributionTrainerContext dtc = new SequenceLogoDistributionTrainerContext();
    dtc.setNullModelWeight(nullWeight);

    // Per-column fraction of non-gap symbols; only needed when gaps are counted.
    double[] adjRatios = countGaps ? new double[columns] : null;

    try {
      for (int i = 0; i < columns; i++) {
        double gapCount = 0.0;
        double totalCount = 0.0;

        pos[i] = DistributionFactory.DEFAULT.createDistribution(alpha);
        dtc.registerDistribution(pos[i]);

        for (Iterator j = seqs.iterator(); j.hasNext();) {
          Object seqLabel = j.next();
          Symbol s = a.symbolAt(seqLabel, i + 1);

          // In a flexible alignment a sequence may not be present in this
          // region, in which case symbolAt yields null and it is skipped.
          if (s == null) {
            continue;
          }

          totalCount++;
          if (countGaps && s.equals(gap)) {
            gapCount++;
          } else {
            dtc.addCount(pos[i], s, 1.0);
          }
        }

        if (countGaps) {
          // Guard against 0/0 (NaN) when no sequence covers this column: no gaps
          // were observed, so the column's weights are left unscaled.
          adjRatios[i] = (totalCount == 0.0) ? 1.0 : 1.0 - (gapCount / totalCount);
        }
      }

      dtc.train();

      if (countGaps) {
        for (int i = 0; i < columns; i++) {
          Distribution d = pos[i];
          for (Iterator iter = ((FiniteAlphabet) d.getAlphabet()).iterator(); iter.hasNext();) {
            Symbol sym = (Symbol) iter.next();
            d.setWeight(sym, d.getWeight(sym) * adjRatios[i]);
          }
        }
      }
    } catch (RuntimeException e) {
      throw e;
    } catch (Exception e) {
      // Previously swallowed, which returned an array containing nulls and
      // untrained distributions; surface the failure with its cause instead.
      if (e instanceof IllegalAlphabetException) {
        throw (IllegalAlphabetException) e;
      }
      throw new IllegalStateException(
          "Failed to build distributions over alignment: " + e, e);
    }
    return pos;
  }

  /**
   * Returns the alphabet of the first sequence, after checking that every
   * sequence in the alignment uses that same alphabet instance (compared by
   * reference).
   *
   * @param a the alignment whose sequences are checked
   * @param seqs the alignment's labels; must be non-empty
   * @return the alphabet shared by all sequences, cast to {@code FiniteAlphabet}
   * @throws IllegalAlphabetException if any sequence uses a different alphabet
   */
  private static FiniteAlphabet sharedAlphabet(Alignment a, List seqs)
      throws IllegalAlphabetException {
    FiniteAlphabet alpha = (FiniteAlphabet) a.symbolListForLabel(seqs.get(0)).getAlphabet();
    for (int i = 1; i < seqs.size(); i++) {
      FiniteAlphabet test = (FiniteAlphabet) a.symbolListForLabel(seqs.get(i)).getAlphabet();
      if (test != alpha) {
        throw new IllegalAlphabetException("Cannot Calculate distOverAlignment() for alignments with "
            + "mixed alphabets");
      }
    }
    return alpha;
  }

}
