/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.core.integration.attribute;

import java.util.ArrayList;
import java.util.Collection;
import no.uib.cipr.matrix.DenseVector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.actions.ComputeEnrichment;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.AttributeData;
import org.genemania.engine.core.data.AttributeGroups;
import org.genemania.engine.core.data.NodeIds;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.integration.attribute.IAttributeScorer;
import org.genemania.engine.core.utils.ObjectSelector;
import org.genemania.engine.matricks.Matrix;
import org.genemania.engine.matricks.Vector;
import org.genemania.exception.ApplicationException;

/**
 * Scores the attributes of one attribute group by their enrichment among the
 * query genes.
 *
 * <p>For every attribute (a column of the group's gene-by-attribute matrix) a
 * p-value is obtained from {@link ComputeEnrichment#computeCumulHyperGeo},
 * and attributes that pass the filters in {@code buildList} are returned as
 * {@link Feature}s scored by that p-value.
 */
public class QueryEnrichedAttributeScorer implements IAttributeScorer {
    private static final Logger logger = Logger.getLogger(QueryEnrichedAttributeScorer.class);

    /** An attribute is kept only if strictly more than this many genes carry it overall. */
    private static final int MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE = 2;

    /** Attributes whose p-value is not strictly below this cutoff are dropped. */
    private static final double PVAL_CUTOFF = 0.99999;

    DataCache cache;
    no.uib.cipr.matrix.Vector labels;
    int minQueryGenesPerAttribute;

    /**
     * @param cache source of attribute data, attribute groups and node ids
     * @param labels per-node label vector; entries equal to {@code 1.0} mark query genes
     * @param minQueryGenesPerAttribute minimum number of query genes that must carry an
     *        attribute for it to be reported
     */
    public QueryEnrichedAttributeScorer(DataCache cache, no.uib.cipr.matrix.Vector labels, int minQueryGenesPerAttribute) {
        this.cache = cache;
        this.labels = labels;
        this.minQueryGenesPerAttribute = minQueryGenesPerAttribute;
    }

    /**
     * Debug helper: logs the column sum of every attribute column whose sum is positive.
     *
     * <p>Does nothing (and does not touch the cache) when debug logging is disabled, so
     * that the data load, the column sums and the message formatting are not paid for
     * output nobody will see.
     *
     * @throws ApplicationException if the attribute data cannot be obtained from the cache
     */
    void logAttributeCounts(String namespace, long organismId, long groupId) throws ApplicationException {
        if (!logger.isDebugEnabled()) {
            return;
        }
        AttributeData attributeSet = this.cache.getAttributeData(namespace, organismId, groupId);
        Vector colSums = attributeSet.getData().columnSums();
        int numColumns = colSums.getSize();
        for (int i = 0; i < numColumns; ++i) {
            double colSum = colSums.get(i);
            if (colSum > 0.0) {
                logger.debug(String.format("attribute %d has col sum %f", i, colSum));
            }
        }
    }

    /**
     * Computes an enrichment p-value for every attribute in the group and returns the
     * attributes that pass the filters applied by {@code buildList}.
     *
     * @param namespace cache namespace to read from
     * @param organismId organism whose nodes and attributes are used
     * @param attributeGroupId attribute group to score
     * @return selector holding one {@link Feature} per retained attribute, scored by p-value
     * @throws ApplicationException if required data cannot be obtained from the cache
     */
    @Override
    public ObjectSelector<Feature> scoreAttributes(String namespace, long organismId, long attributeGroupId) throws ApplicationException {
        AttributeData attributeData = this.cache.getAttributeData(namespace, organismId, attributeGroupId);
        Matrix annotations = attributeData.getData();
        int numCategories = annotations.numCols();
        int numGenes = annotations.numRows();

        DenseVector selection = this.computeSelectionMask(organismId, this.labels);

        // Per-attribute totals over all genes (presumably written into the backing array
        // by columnSums -- confirm against Matrix).
        DenseVector backgroundCounts = new DenseVector(numCategories);
        annotations.columnSums(backgroundCounts.getData());

        // Per-attribute totals restricted to the selected genes (presumably
        // annotations^T * selection -- confirm against Matrix.transMult).
        DenseVector sampleCounts = new DenseVector(numCategories);
        annotations.transMult(selection.getData(), sampleCounts.getData());

        DenseVector pvals = new DenseVector(numCategories);
        this.computePVals(annotations, numCategories, numGenes, backgroundCounts, sampleCounts, selection, pvals);
        return this.buildList(namespace, organismId, attributeGroupId, pvals, sampleCounts, backgroundCounts);
    }

    /**
     * Turns the per-attribute statistics into a scored feature list.
     *
     * <p>An attribute is kept when its p-value is below {@code PVAL_CUTOFF}, at least
     * {@code minQueryGenesPerAttribute} selected genes carry it, and more than
     * {@code MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE} genes carry it overall. The i-th attribute
     * id of the group is paired with the i-th entry of each vector.
     *
     * @throws ApplicationException if the group has no attribute id list, or the cache lookup fails
     */
    private ObjectSelector<Feature> buildList(String namespace, long organismId, long attributeGroupId, DenseVector pvals, DenseVector sampleCounts, DenseVector backgroundCounts) throws ApplicationException {
        ObjectSelector<Feature> list = new ObjectSelector<Feature>();
        AttributeGroups groups = this.cache.getAttributeGroups(namespace, organismId);
        ArrayList<Long> attributeIds = groups.getAttributeGroups().get(attributeGroupId);
        if (attributeIds == null) {
            // Fail with context instead of a bare NullPointerException in the loop below.
            throw new ApplicationException(String.format("no attributes found for attribute group %d of organism %d", attributeGroupId, organismId));
        }
        for (int i = 0; i < attributeIds.size(); ++i) {
            double pval = pvals.get(i);
            double sampleCount = sampleCounts.get(i);
            double backgroundCount = backgroundCounts.get(i);
            boolean keep = pval < PVAL_CUTOFF
                    && sampleCount >= (double) this.minQueryGenesPerAttribute
                    && backgroundCount > (double) MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE;
            if (keep) {
                long attributeId = attributeIds.get(i);
                Feature feature = new Feature(Constants.NetworkType.ATTRIBUTE_VECTOR, attributeGroupId, attributeId);
                list.add(feature, pval);
            }
        }
        logger.debug(String.format("ranked %d attributes by enrichment based on query list", list.size()));
        return list;
    }

    /**
     * Builds an unscored feature list containing one {@code ATTRIBUTE_VECTOR} feature for
     * every attribute of every group in {@code attributeGroups}.
     */
    public static FeatureList buildFeatureList(AttributeGroups attributeGroups) {
        FeatureList features = new FeatureList();
        for (long groupId : attributeGroups.getAttributeGroups().keySet()) {
            ArrayList<Long> attributeIds = attributeGroups.getAttributeGroups().get(groupId);
            for (long attributeId : attributeIds) {
                features.add(new Feature(Constants.NetworkType.ATTRIBUTE_VECTOR, groupId, attributeId));
            }
        }
        return features;
    }

    /**
     * Builds a 0/1 mask over the organism's nodes from a label vector: position {@code i}
     * is set to {@code 1.0} when {@code labels.get(i)} equals {@code 1.0}; all other
     * positions keep the value a freshly constructed {@link DenseVector} holds.
     *
     * <p>The mask length is the number of node ids of the organism; {@code labels} is
     * assumed not to mark positions beyond that length.
     */
    DenseVector computeSelectionMask(long organismId, no.uib.cipr.matrix.Vector labels) throws ApplicationException {
        NodeIds nodeIds = this.cache.getNodeIds(organismId);
        DenseVector selection = new DenseVector(nodeIds.getNodeIds().length);
        int numLabels = labels.size();
        for (int i = 0; i < numLabels; ++i) {
            if (labels.get(i) == 1.0) {
                selection.set(i, 1.0);
            }
        }
        return selection;
    }

    /**
     * Builds a 0/1 mask over the organism's nodes from a collection of node ids: the
     * index that {@link NodeIds#getIndexForId} reports for each id is set to {@code 1.0}.
     */
    DenseVector computeSelectionMask(long organismId, Collection<Long> nodes) throws ApplicationException {
        NodeIds nodeIds = this.cache.getNodeIds(organismId);
        DenseVector selection = new DenseVector(nodeIds.getNodeIds().length);
        for (long id : nodes) {
            selection.set(nodeIds.getIndexForId(id), 1.0);
        }
        return selection;
    }

    /**
     * Fills {@code pvals} with one value per category from
     * {@link ComputeEnrichment#computeCumulHyperGeo}, called with the rounded sample
     * count, {@code numGenes}, the number of entries of {@code selection} matching
     * {@code 1.0}, and the rounded background count for that category.
     *
     * <p>{@code annotations} is not read here; the parameter is kept so existing callers
     * continue to compile.
     */
    void computePVals(Matrix annotations, int numCategories, int numGenes, DenseVector backgroundCounts, DenseVector sampleCounts, DenseVector selection, DenseVector pvals) {
        int N = numGenes;
        int n = MatrixUtils.countMatches((no.uib.cipr.matrix.Vector)selection, 1.0);
        for (int category = 0; category < numCategories; ++category) {
            long x = Math.round(sampleCounts.get(category));
            long k = Math.round(backgroundCounts.get(category));
            pvals.set(category, ComputeEnrichment.computeCumulHyperGeo(x, N, n, k));
        }
    }
}

