/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.core.integration;

import java.util.Collection;
import no.uib.cipr.matrix.DenseVector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.actions.ComputeEnrichment;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.AttributeData;
import org.genemania.engine.core.data.AttributeGroups;
import org.genemania.engine.core.data.NodeIds;
import org.genemania.engine.core.integration.AbstractAttributeSelector;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.matricks.Matrix;
import org.genemania.engine.matricks.Vector;
import org.genemania.exception.ApplicationException;

/**
 * Pre-selects attributes of an attribute group by their enrichment in the query gene list.
 *
 * <p>For every attribute (column of the group's annotation matrix) a p-value is computed with
 * {@link ComputeEnrichment#computeCumulHyperGeo}. The inputs are:
 * <ul>
 *   <li>the number of selected genes annotated with the attribute;</li>
 *   <li>the total number of genes;</li>
 *   <li>the number of selected genes;</li>
 *   <li>the attribute's total annotation count.</li>
 * </ul>
 * Genes are "selected" when their entry in the label vector equals {@code 1.0}.
 * Attributes are then visited in rank order of p-value and turned into {@link Feature}s.
 * Selection is bounded by {@code maxSize} and filtered by minimum query/total gene counts.
 */
public class QueryEnrichedAttributeSelector
extends AbstractAttributeSelector {
    private static final Logger logger = Logger.getLogger(QueryEnrichedAttributeSelector.class);
    /** An attribute must annotate strictly more than this many genes in total to be selected. */
    private static final int MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE = 2;
    /** Ranked p-values at or above this value end the selection scan. */
    private static final double PVALUE_CUTOFF = 0.99999;
    /** Cache namespace used for all attribute lookups in this selector. */
    private static final String CORE_NAMESPACE = "CORE";

    DataCache cache;
    no.uib.cipr.matrix.Vector labels;
    int maxSize;
    int minQueryGenesPerAttribute;

    /**
     * @param cache                     source of attribute data, attribute groups and node ids
     * @param labels                    per-node labels; entries equal to {@code 1.0} mark selected (query) nodes
     * @param maxSize                   number of top-ranked attributes examined for selection
     * @param minQueryGenesPerAttribute minimum number of selected genes an attribute must annotate
     */
    public QueryEnrichedAttributeSelector(DataCache cache, no.uib.cipr.matrix.Vector labels, int maxSize, int minQueryGenesPerAttribute) {
        this.cache = cache;
        this.labels = labels;
        this.maxSize = maxSize;
        this.minQueryGenesPerAttribute = minQueryGenesPerAttribute;
    }

    /**
     * Logs, at debug level, the column sum of every attribute whose sum is positive.
     *
     * @throws ApplicationException if the attribute data cannot be loaded from the cache
     */
    void logAttributeCounts(String namespace, long organismId, long groupId) throws ApplicationException {
        // The lookup stays unconditional so a missing data set still surfaces as an exception.
        AttributeData attributeSet = this.cache.getAttributeData(namespace, organismId, groupId);
        if (!logger.isDebugEnabled()) {
            // Skip the column sums and per-column formatting when nothing would be logged.
            return;
        }
        Vector colSums = attributeSet.getData().columnSums();
        for (int i = 0; i < colSums.getSize(); ++i) {
            double sum = colSums.get(i);
            if (sum > 0.0) {
                logger.debug(String.format("attribute %d has col sum %f", i, sum));
            }
        }
    }

    /**
     * Selects the attributes of {@code attributeGroupId} most enriched in the labelled (query) genes.
     *
     * @return features for the selected attributes, possibly empty
     * @throws ApplicationException if cached data for the organism/group cannot be loaded
     */
    @Override
    public FeatureList selectAttributes(long organismId, long attributeGroupId) throws ApplicationException {
        AttributeData attributeData = this.cache.getAttributeData(CORE_NAMESPACE, organismId, attributeGroupId);
        Matrix annotations = attributeData.getData();
        int numCategories = annotations.numCols();
        int numGenes = annotations.numRows();
        DenseVector pvals = new DenseVector(numCategories);
        DenseVector selection = this.computeSelectionMask(organismId, this.labels);

        // Total number of genes annotated with each attribute.
        DenseVector backgroundCounts = new DenseVector(numCategories);
        annotations.columnSums(backgroundCounts.getData());
        // Number of selected genes annotated with each attribute (annotations^T * selection).
        DenseVector sampleCounts = new DenseVector(numCategories);
        annotations.transMult(selection.getData(), sampleCounts.getData());

        this.computePVals(annotations, numCategories, numGenes, backgroundCounts, sampleCounts, selection, pvals);
        return this.buildFeatureListForTopAttributes(organismId, attributeGroupId, pvals, sampleCounts, backgroundCounts, this.maxSize);
    }

    /**
     * Builds the feature list from the attributes in p-value rank order.
     *
     * <p>At most the first {@code maxSize} ranked positions are examined. Within those, an
     * attribute is kept only if it annotates at least {@code minQueryGenesPerAttribute} selected
     * genes and more than {@link #MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE} genes in total. The scan stops
     * early at the first ranked p-value {@code >= PVALUE_CUTOFF}.
     *
     * <p>NOTE(review): {@code maxSize} caps the positions examined, not the features returned.
     * Attributes rejected by the count filter still use up a position. Fewer than {@code maxSize}
     * features may therefore be returned even when eligible attributes exist further down the
     * ranking. Confirm this is intended before relying on {@code maxSize} as an output size.
     */
    private FeatureList buildFeatureListForTopAttributes(long organismId, long attributeGroupId, DenseVector pvals, DenseVector sampleCounts, DenseVector backgroundCounts, int maxSize) throws ApplicationException {
        AttributeGroups groups = this.cache.getAttributeGroups(CORE_NAMESPACE, organismId);
        int numAttributes = pvals.size();

        // rank() is applied to a copy; its values are read back below as 1-based rank positions.
        DenseVector ranks = pvals.copy();
        MatrixUtils.rank((no.uib.cipr.matrix.Vector)ranks);

        // Invert the ranking: ordered[p] is the p-value at rank p+1, and unrank[p] is the
        // attribute index holding it.
        // NOTE(review): this assumes rank() gives distinct ranks and gives rank 1 to the smallest
        // p-value. The early exit below relies on ascending order, and tied ranks would collide
        // here. Confirm against MatrixUtils.rank.
        DenseVector ordered = new DenseVector(numAttributes);
        int[] unrank = new int[numAttributes];
        for (int i = 0; i < numAttributes; ++i) {
            int position = (int)Math.round(ranks.get(i)) - 1;
            ordered.set(position, pvals.get(i));
            unrank[position] = i;
        }

        FeatureList list = new FeatureList();
        int max = Math.min(numAttributes, maxSize);
        for (int position = 0; position < max; ++position) {
            if (ordered.get(position) >= PVALUE_CUTOFF) {
                break;
            }
            int attributeIndex = unrank[position];
            if (!this.isEligible(sampleCounts.get(attributeIndex), backgroundCounts.get(attributeIndex))) {
                continue;
            }
            long attributeId = groups.getAttributeIdForIndex(attributeGroupId, attributeIndex);
            list.add(new Feature(Constants.NetworkType.ATTRIBUTE_VECTOR, attributeGroupId, attributeId));
        }
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("pre-selected %d attributes by enrichment based on query list", list.size()));
        }
        return list;
    }

    /** Returns true when an attribute has enough selected genes and enough genes in total. */
    private boolean isEligible(double sampleCount, double backgroundCount) {
        return sampleCount >= (double)this.minQueryGenesPerAttribute
                && backgroundCount > (double)MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE;
    }

    /**
     * Builds a 0/1 mask sized to the organism's node count.
     *
     * <p>The mask has {@code 1.0} at every index {@code i} where {@code labels.get(i) == 1.0}.
     * Indices of {@code labels} are assumed to line up with node indices.
     */
    DenseVector computeSelectionMask(long organismId, no.uib.cipr.matrix.Vector labels) throws ApplicationException {
        NodeIds nodeIds = this.cache.getNodeIds(organismId);
        DenseVector selection = new DenseVector(nodeIds.getNodeIds().length);
        for (int i = 0; i < labels.size(); ++i) {
            if (labels.get(i) == 1.0) {
                selection.set(i, 1.0);
            }
        }
        return selection;
    }

    /**
     * Builds a 0/1 mask sized to the organism's node count.
     *
     * <p>The mask has {@code 1.0} at the index of each given node id, as resolved by
     * {@code NodeIds.getIndexForId}.
     */
    DenseVector computeSelectionMask(long organismId, Collection<Long> nodes) throws ApplicationException {
        NodeIds nodeIds = this.cache.getNodeIds(organismId);
        DenseVector selection = new DenseVector(nodeIds.getNodeIds().length);
        for (long id : nodes) {
            selection.set(nodeIds.getIndexForId(id), 1.0);
        }
        return selection;
    }

    /**
     * Fills {@code pvals[category]} with the cumulative hypergeometric p-value of each category.
     *
     * <p>The arguments passed per category are:
     * <ul>
     *   <li>{@code x}: the category's sample count;</li>
     *   <li>{@code N}: {@code numGenes};</li>
     *   <li>{@code n}: the number of {@code 1.0} entries in {@code selection};</li>
     *   <li>{@code k}: the category's background count.</li>
     * </ul>
     * Counts are rounded to the nearest integer first. {@code annotations} is not used here.
     */
    void computePVals(Matrix annotations, int numCategories, int numGenes, DenseVector backgroundCounts, DenseVector sampleCounts, DenseVector selection, DenseVector pvals) {
        int N = numGenes;
        int n = MatrixUtils.countMatches((no.uib.cipr.matrix.Vector)selection, 1.0);
        for (int category = 0; category < numCategories; ++category) {
            long x = Math.round(sampleCounts.get(category));
            long k = Math.round(backgroundCounts.get(category));
            pvals.set(category, ComputeEnrichment.computeCumulHyperGeo(x, N, n, k));
        }
    }
}

