/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.core.integration.calculators;

import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.CoAnnotationSet;
import org.genemania.engine.core.data.KtKFeatures;
import org.genemania.engine.core.integration.CombineNetworksOnly;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.integration.Solver;
import org.genemania.engine.core.integration.attribute.QueryEnrichedAttributeScorer;
import org.genemania.engine.core.integration.calculators.AbstractNetworkWeightCalculator;
import org.genemania.engine.core.integration.calculators.AverageByNetworkCalculator;
import org.genemania.engine.core.integration.gram.BasicGramBuilder;
import org.genemania.engine.core.integration.gram.GramEditor;
import org.genemania.engine.exception.WeightingFailedException;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/**
 * Network weight calculator parameterised by a {@link Constants.CombiningMethod}.
 *
 * <p>Starts from a cached Gram system ({@code KtK}, {@code KtT}) for the organism,
 * trims it down to the features requested by the current query, and extends it with
 * any requested features the cache does not yet cover. It then solves for weights and
 * builds the combined network from them. If the solver fails, it falls back to
 * averaging over the features.
 */
public class BranchSpecificCalculator
extends AbstractNetworkWeightCalculator {
    private static final Logger logger = Logger.getLogger(BranchSpecificCalculator.class);
    /** Combining method whose name selects the target ({@code KtT}), co-annotation set and scaling. */
    Constants.CombiningMethod method;
    /** Threshold handed to {@link QueryEnrichedAttributeScorer} (presumably min query genes per attribute). */
    private static final int MIN_QUERY_GENES_PER_ATTRIBUTE = 1;
    /** Format of the cache key: {@code <combining method>-<formatted network list>}. */
    public static final String PARAM_KEY_FORMAT = "%s-%s";

    /**
     * Creates a calculator; all arguments except {@code method} are forwarded to the superclass.
     *
     * @param method combining method used to pick targets, co-annotations and scaling
     * @throws ApplicationException if the superclass constructor fails
     */
    public BranchSpecificCalculator(String namespace, DataCache cache, Collection<Collection<Long>> networkIds, Collection<Long> attributeGroupIds, long organismId, Vector label, int attributesLimit, Constants.CombiningMethod method, ProgressReporter progress) throws ApplicationException {
        super(namespace, cache, networkIds, attributeGroupIds, organismId, label, attributesLimit, progress);
        this.method = method;
    }

    /**
     * Reports progress and computes weights plus the combined network.
     *
     * @throws ApplicationException if weight computation or network combination fails
     */
    @Override
    public void process() throws ApplicationException {
        this.progress.setStatus("computing network weights");
        this.progress.setProgress(1);
        boolean hasUserNetworks = this.queryHasUserNetworks();
        this.computeNewResult(hasUserNetworks);
    }

    /**
     * Builds the Gram system for the requested features, solves for {@code weights},
     * and stores the resulting {@code combinedMatrix}.
     *
     * @param hasUserNetworks forwarded to {@code getKtK}/{@code getKtT}
     * @throws ApplicationException if the bias feature is missing or not first among the
     *         cached features shared with the request, or if any downstream step fails
     */
    void computeNewResult(boolean hasUserNetworks) throws ApplicationException {
        String methodName = this.method.toString();
        DenseMatrix KtK2 = this.getKtK(hasUserNetworks);
        DenseMatrix KtT2 = this.getKtT(methodName, hasUserNetworks);
        KtKFeatures ktkFeatures = this.cache.getKtKFeatures(this.namespace, this.organismId);
        FeatureList KtKFeatureList = ktkFeatures.getFeatures();

        QueryEnrichedAttributeScorer attributeScorer = new QueryEnrichedAttributeScorer(this.cache, this.label, MIN_QUERY_GENES_PER_ATTRIBUTE);
        FeatureList featureList = this.buildFeatureList(attributeScorer, false);
        featureList.addBias();

        // Requested features already present in the cached Gram vs. ones still to be added.
        FeatureList haveThem = this.intersect(KtKFeatureList, featureList);
        FeatureList needThem = this.setdiff(KtKFeatureList, featureList);

        // The solver expects the bias in column 0. Check for emptiness first so a missing
        // overlap surfaces as this error rather than an IndexOutOfBoundsException.
        if (haveThem.isEmpty() || ((Feature)haveThem.get(0)).getType() != Constants.NetworkType.BIAS) {
            throw new ApplicationException("internal error: bias must be first column");
        }

        KtK2 = GramEditor.RemoveNetworkKtK(KtK2, KtKFeatureList, haveThem);
        KtT2 = GramEditor.RemoveNetworkKtT(KtT2, KtKFeatureList, haveThem);
        if (needThem.size() > 0) {
            logger.debug(String.format("need to update gram for %d features", needThem.size()));
            BasicGramBuilder builder = new BasicGramBuilder(this.cache, this.namespace, this.organismId, this.progress);
            KtK2 = builder.updateBasicKtK(KtK2, haveThem, needThem, this.progress);
            CoAnnotationSet annoSet = this.cache.getCoAnnotationSet(this.organismId, methodName);
            KtT2 = builder.updateKtT(KtT2, haveThem, needThem, annoSet, this.progress);
            // Appended without re-sorting so the list order stays in step with the updated
            // matrices (assumes the builder appends new columns last — TODO confirm).
            haveThem.addAll(needThem);
        }
        this.scaleKtK(KtK2, methodName);
        try {
            this.weights = Solver.solve((Matrix)KtK2, MatrixUtils.extractColumnToVector((Matrix)KtT2, 0), haveThem, this.progress);
        }
        catch (WeightingFailedException e) {
            // Pass the exception itself so the cause and stack trace reach the log.
            logger.error("weighting calculation failed, falling back to average: " + e.getMessage(), e);
            this.weights = AverageByNetworkCalculator.average(haveThem);
        }
        this.progress.setStatus("building combined network");
        this.progress.setProgress(2);
        this.combinedMatrix = CombineNetworksOnly.combine(this.weights, this.namespace, this.organismId, this.cache, this.progress);
    }

    /**
     * Returns the features present in both lists, sorted by their natural order.
     *
     * @param a first feature list (not modified)
     * @param b second feature list (not modified)
     * @return a new sorted list holding {@code a ∩ b}
     */
    private FeatureList intersect(FeatureList a, FeatureList b) {
        HashSet<Feature> common = new HashSet<Feature>(a);
        common.retainAll(new HashSet<Feature>(b));
        FeatureList features = new FeatureList();
        features.addAll(common);
        Collections.sort(features);
        return features;
    }

    /**
     * Returns the features of {@code b} that are not in {@code a}, sorted by their
     * natural order.
     *
     * @param a features to exclude (not modified)
     * @param b features to keep unless present in {@code a} (not modified)
     * @return a new sorted list holding {@code b \ a}
     */
    private FeatureList setdiff(FeatureList a, FeatureList b) {
        HashSet<Feature> remaining = new HashSet<Feature>(b);
        // Hash the exclusion set so removeAll does constant-time lookups.
        remaining.removeAll(new HashSet<Feature>(a));
        FeatureList features = new FeatureList();
        features.addAll(remaining);
        Collections.sort(features);
        return features;
    }

    /**
     * Returns the cache key for this calculation, formatted as
     * {@code <combining method>-<formatted network list>}.
     *
     * @return the parameter key
     * @throws ApplicationException if attribute groups were requested, because such
     *         results are not cacheable
     */
    @Override
    public String getParameterKey() throws ApplicationException {
        if (this.attributeGroupIds != null && this.attributeGroupIds.size() > 0) {
            throw new ApplicationException("not cacheable");
        }
        String networks = BranchSpecificCalculator.formattedNetworkList(this.networkIds);
        return String.format(PARAM_KEY_FORMAT, this.method.toString(), networks);
    }
}

