/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.actions;

import java.io.File;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.dto.AttributeDto;
import org.genemania.dto.InteractionDto;
import org.genemania.dto.NetworkDto;
import org.genemania.dto.NodeDto;
import org.genemania.dto.RelatedGenesEngineRequestDto;
import org.genemania.dto.RelatedGenesEngineResponseDto;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.AttributeData;
import org.genemania.engine.core.data.AttributeGroups;
import org.genemania.engine.core.data.DataSupport;
import org.genemania.engine.core.data.Network;
import org.genemania.engine.core.data.NetworkIds;
import org.genemania.engine.core.data.NodeIds;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureWeightMap;
import org.genemania.engine.core.mania.CoreMania;
import org.genemania.engine.core.utils.Logging;
import org.genemania.engine.exception.CancellationException;
import org.genemania.engine.labels.LabelVectorGenerator;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;
import org.genemania.type.CombiningMethod;

/**
 * Engine action that executes a "find related genes" request.
 *
 * <p>{@link #process()} validates the request, runs {@link CoreMania} over the requested
 * interaction networks and attribute groups, converts the resulting discriminant into the requested
 * score (discriminant or z-score), and assembles a {@link RelatedGenesEngineResponseDto} holding
 * the top-scoring nodes, the interactions among those nodes in every network with non-zero weight,
 * and the positively weighted attributes those nodes carry.
 *
 * <p>Per-request state (timings, validated counts, whether the query references user data) is kept
 * in instance fields and filled in while {@link #process()} runs.
 */
public class FindRelated {
    private static final Logger logger = Logger.getLogger(FindRelated.class);
    private DataCache cache;
    private RelatedGenesEngineRequestDto request;
    private int numRequestNetworks;
    private int numRequestAttributeGroups;
    private boolean hasUserNetworks;
    private boolean hasUserAttributes;
    /** Label value given to query (positive) nodes. */
    static final double posLabelValue = 1.0;
    /** Label value given to negative nodes. */
    static final double negLabelValue = -1.0;
    /** Label value given to unlabeled nodes by {@link #process()}. */
    static final double unLabeledValueProduction = -1.0;
    /** Alternative label value for unlabeled nodes; not referenced within this class. */
    static final double unLabeledValueValidation = 0.0;
    private long requestStartTimeMillis;
    private long requestEndTimeMillis;

    /**
     * @param cache   source of node ids, networks and attribute data
     * @param request the find-related request to execute
     */
    public FindRelated(DataCache cache, RelatedGenesEngineRequestDto request) {
        this.cache = cache;
        this.request = request;
    }

    /**
     * Runs the full find-related computation for the request given at construction.
     *
     * @return the populated response, or {@code null} if the request was cancelled through its
     *         progress reporter
     * @throws ApplicationException if the request fails validation or the computation fails
     */
    public RelatedGenesEngineResponseDto process() throws ApplicationException {
        try {
            this.requestStartTimeMillis = System.currentTimeMillis();
            this.logStart();
            this.checkQuery();
            this.logQuery();

            // No explicit negative nodes are supplied to the label generator.
            ArrayList<Long> negativeNodes = new ArrayList<Long>();
            DenseVector labels = LabelVectorGenerator.createLabelsFromIds(
                    this.cache.getNodeIds(this.request.getOrganismId()), this.request.getPositiveNodes(),
                    negativeNodes, posLabelValue, negLabelValue, unLabeledValueProduction);

            // No GO category is passed to CoreMania for this request type.
            String goCategory = null;
            Constants.CombiningMethod combiningMethod = Constants.convertCombiningMethod(
                    this.request.getCombiningMethod(), this.request.getPositiveNodes().size());
            Constants.ScoringMethod scoringMethod = Constants.convertScoringMethod(this.request.getScoringMethod());
            Collection idList = this.request.getInteractionNetworks();

            CoreMania coreMania = new CoreMania(this.cache, this.request.getProgressReporter());
            coreMania.compute(this.safeGetNamespace(), this.request.getOrganismId(), (Vector) labels,
                    combiningMethod, idList, this.request.getAttributeGroups(),
                    this.request.getAttributesLimit(), goCategory, "average");

            SymMatrix partiallyCombinedKernel = coreMania.getPartiallyCombinedKernel();
            FeatureWeightMap featureWeights = coreMania.getFeatureWeights();
            Vector discriminant = coreMania.getDiscriminant();
            Vector score = this.convertScore(scoringMethod, discriminant, partiallyCombinedKernel,
                    (Vector) labels, posLabelValue, negLabelValue);
            double scoreThreshold = this.selectScoreThreshold(scoringMethod);
            RelatedGenesEngineResponseDto response = this.prepareResponse(score, discriminant, featureWeights,
                    partiallyCombinedKernel, scoreThreshold, scoringMethod,
                    Constants.convertCombiningMethod(combiningMethod));

            this.requestEndTimeMillis = System.currentTimeMillis();
            this.logEnd();
            return response;
        }
        catch (CancellationException e) {
            logger.info("request was cancelled");
            return null;
        }
    }

    /**
     * Fills the response's attribute list and node-to-attribute map. When the request has no
     * attribute groups, an empty node-to-attribute map is set instead.
     *
     * @param indicesForTopScores currently unused; kept for signature stability
     */
    private void encodeAttributes(RelatedGenesEngineResponseDto response, int[] indicesForTopScores, FeatureWeightMap featureWeights) throws ApplicationException {
        if (this.request.getAttributeGroups() == null || this.request.getAttributeGroups().size() == 0) {
            this.setEmptyAttributeResponse(response);
            return;
        }
        Map<Long, AttributeDto> allAttributeDtos = this.makeAllAttributeDtos(response, featureWeights);
        this.addAttributesForSelectedNodes(response, allAttributeDtos, featureWeights);
    }

    /** Sets an empty node-to-attribute map on the response. */
    private void setEmptyAttributeResponse(RelatedGenesEngineResponseDto response) {
        HashMap<Long, Collection<AttributeDto>> nodeToAttributes = new HashMap<Long, Collection<AttributeDto>>();
        response.setNodeToAttributes(nodeToAttributes);
    }

    /**
     * For every attribute feature that has a DTO in {@code allAttributeDtos}, records that DTO
     * against each response node whose attribute value is non-zero.
     *
     * @param allAttributeDtos DTOs keyed by attribute id, as built by {@link #makeAllAttributeDtos}
     * @throws CancellationException if the progress reporter reports cancellation
     */
    private void addAttributesForSelectedNodes(RelatedGenesEngineResponseDto response, Map<Long, AttributeDto> allAttributeDtos, FeatureWeightMap features) throws ApplicationException {
        NodeIds nodeIds = this.cache.getNodeIds(this.request.getOrganismId());
        AttributeGroups attributeGroups = this.cache.getAttributeGroups(this.safeGetNamespace(), this.request.getOrganismId());
        HashMap<Long, Collection<AttributeDto>> nodeToAttributes = new HashMap<Long, Collection<AttributeDto>>();
        for (Feature feature : features.keySet()) {
            if (feature.getType() != Constants.NetworkType.ATTRIBUTE_VECTOR) {
                continue;
            }
            if (this.request.getProgressReporter().isCanceled()) {
                throw new CancellationException();
            }
            long groupId = feature.getGroupId();
            long attributeId = feature.getId();
            AttributeDto attributeDto = allAttributeDtos.get(attributeId);
            if (attributeDto == null) {
                // makeAllAttributeDtos only creates DTOs for positively weighted attributes; without
                // this guard a null would be added to the attribute set of every matching node.
                continue;
            }
            AttributeData attributeSet = this.cache.getAttributeData(this.safeGetNamespace(), this.request.getOrganismId(), groupId);
            int attributeIndex = attributeGroups.getIndexForAttributeId(groupId, attributeId);
            org.genemania.engine.matricks.Matrix data = attributeSet.getData();
            for (NodeDto node : response.getNodes()) {
                int nodeIndex = nodeIds.getIndexForId(node.getId());
                if (data.get(nodeIndex, attributeIndex) == 0.0) {
                    continue;
                }
                this.updateNodeAttributes(node.getId(), nodeToAttributes, attributeDto);
            }
        }
        response.setNodeToAttributes(nodeToAttributes);
    }

    /** Adds {@code attributeDto} to the set stored for {@code nodeId}, creating the set on first use. */
    private void updateNodeAttributes(long nodeId, Map<Long, Collection<AttributeDto>> nodeToAttribute, AttributeDto attributeDto) {
        Collection<AttributeDto> attributes = nodeToAttribute.get(nodeId);
        if (attributes == null) {
            attributes = new HashSet<AttributeDto>();
            nodeToAttribute.put(nodeId, attributes);
        }
        attributes.add(attributeDto);
    }

    /**
     * Builds an {@link AttributeDto} for every attribute feature with a non-null, strictly
     * positive weight, sets them as the response's attribute list, and returns them keyed by id.
     *
     * @throws CancellationException if the progress reporter reports cancellation
     */
    private Map<Long, AttributeDto> makeAllAttributeDtos(RelatedGenesEngineResponseDto response, FeatureWeightMap featureWeights) throws ApplicationException {
        HashMap<Long, AttributeDto> allMap = new HashMap<Long, AttributeDto>();
        ArrayList<AttributeDto> all = new ArrayList<AttributeDto>();
        for (Feature feature : featureWeights.keySet()) {
            if (feature.getType() != Constants.NetworkType.ATTRIBUTE_VECTOR) {
                continue;
            }
            Double weight = (Double) featureWeights.get(feature);
            // Null weights are skipped, matching getSourceInteractions, instead of failing on unboxing.
            if (weight == null || weight <= 0.0) {
                continue;
            }
            if (this.request.getProgressReporter().isCanceled()) {
                throw new CancellationException();
            }
            AttributeDto attributeDto = new AttributeDto();
            attributeDto.setId(feature.getId());
            attributeDto.setGroupId(feature.getGroupId());
            attributeDto.setWeight(weight.doubleValue());
            allMap.put(attributeDto.getId(), attributeDto);
            all.add(attributeDto);
        }
        response.setAttributes(all);
        return allMap;
    }

    /**
     * Chooses the minimum score a node needs in order to be returned.
     *
     * @return negative infinity for z-scores (no cutoff), otherwise {@code 0.0}
     */
    private double selectScoreThreshold(Constants.ScoringMethod scoringMethod) {
        return scoringMethod == Constants.ScoringMethod.ZSCORE ? Double.NEGATIVE_INFINITY : 0.0;
    }

    /**
     * Converts the raw discriminant into the score requested by {@code scoringMethod}.
     *
     * <p>For {@code DISCRIMINANT}, the discriminant vector is rescaled in place and returned
     * itself. For {@code ZSCORE}, a new vector from {@link #computeZScore} is returned.
     *
     * @throws ApplicationException for {@code CONTEXT} (no longer supported) or any other method
     */
    private Vector convertScore(Constants.ScoringMethod scoringMethod, Vector discriminant, SymMatrix combinedKernel, Vector labels, double posLabelValue, double negLabelValue) throws ApplicationException {
        if (scoringMethod == Constants.ScoringMethod.DISCRIMINANT) {
            discriminant.set(MatrixUtils.rescale(discriminant));
            return discriminant;
        }
        if (scoringMethod == Constants.ScoringMethod.CONTEXT) {
            throw new ApplicationException("context score no longer supported");
        }
        if (scoringMethod == Constants.ScoringMethod.ZSCORE) {
            return this.computeZScore(discriminant, combinedKernel, labels, posLabelValue, negLabelValue);
        }
        throw new ApplicationException("Unexpected scoring method: " + scoringMethod);
    }

    /**
     * Standardizes the discriminant into z-scores.
     *
     * <p>The mean and standard deviation come only from nodes that have positive degree in
     * {@code combinedKernel} and are not labelled {@code posLabelValue}. The standard deviation is
     * offset by 0.01 before dividing. If no node qualifies, every score is set to negative
     * infinity except the positively labelled nodes, which get {@code 1.0}.
     *
     * @return a new vector; {@code discriminant} is not modified
     */
    private Vector computeZScore(Vector discriminant, SymMatrix combinedKernel, Vector labels, double posLabelValue, double negLabelValue) throws ApplicationException {
        logger.debug("computing z-score");
        int size = discriminant.size();
        DenseVector degrees = new DenseVector(size);
        combinedKernel.columnSums(degrees.getData());

        // Single-column matrix of candidate values; NaN marks entries excluded from the statistics.
        DenseMatrix score = new DenseMatrix(size, 1);
        int n = 0;
        for (int i = 0; i < size; ++i) {
            if (degrees.get(i) > 0.0) {
                score.set(i, 0, discriminant.get(i));
                ++n;
            } else {
                score.set(i, 0, Double.NaN);
            }
        }
        logger.debug("# of nodes with +ve degree in combined network: " + n);
        for (int i = 0; i < labels.size(); ++i) {
            if (labels.get(i) != posLabelValue) {
                continue;
            }
            logger.debug("clearing modes with postive label value for " + i);
            score.set(i, 0, Double.NaN);
        }

        Vector counts = MatrixUtils.columnCountsIgnoreMissingData((Matrix) score);
        Vector zscores = discriminant.copy();
        if (counts.get(0) == 0.0) {
            logger.info("no nodes connected to query nodes, special casing z-scores");
            FindRelated.seteq(zscores, Double.NEGATIVE_INFINITY);
            FindRelated.setmatches(posLabelValue, labels, 1.0, zscores);
        } else {
            Vector means = MatrixUtils.columnMeanIgnoreMissingData((Matrix) score, counts);
            Vector stdevs = MatrixUtils.columnVarianceIgnoreMissingData((Matrix) score, means);
            MatrixUtils.sqrt(stdevs);
            logger.debug("count, mean, stdev: " + counts.get(0) + ", " + means.get(0) + ", " + stdevs.get(0));
            MatrixUtils.add(zscores, -means.get(0));
            zscores.scale(1.0 / (stdevs.get(0) + 0.01));
            logger.debug("max of z-scores: " + MatrixUtils.max(zscores));
        }
        return zscores;
    }

    /** Replaces every entry of {@code v} strictly below {@code thresh} with {@code newval}. */
    private static void setlt(Vector v, double thresh, double newval) {
        int n = v.size();
        for (int i = 0; i < n; ++i) {
            if (v.get(i) < thresh) {
                v.set(i, newval);
            }
        }
    }

    /** Replaces every entry of {@code v} at or above {@code thresh} with {@code newval}. */
    private static void setge(Vector v, double thresh, double newval) {
        int n = v.size();
        for (int i = 0; i < n; ++i) {
            if (v.get(i) >= thresh) {
                v.set(i, newval);
            }
        }
    }

    /** Sets every entry of {@code v} to {@code newval}. */
    private static void seteq(Vector v, double newval) {
        int n = v.size();
        for (int i = 0; i < n; ++i) {
            v.set(i, newval);
        }
    }

    /**
     * Sets {@code newhaystack[i] = newneedle} at every index {@code i} where
     * {@code haystack[i] == needle}.
     */
    private static void setmatches(double needle, Vector haystack, double newneedle, Vector newhaystack) {
        int n = haystack.size();
        for (int i = 0; i < n; ++i) {
            if (haystack.get(i) == needle) {
                newhaystack.set(i, newneedle);
            }
        }
    }

    /**
     * Selects the top-scoring nodes and builds the response. The response contains the nodes,
     * their interactions per weighted network, their attributes, and the applied combining method.
     *
     * @throws CancellationException if the progress reporter reports cancellation
     */
    protected RelatedGenesEngineResponseDto prepareResponse(Vector score, Vector discriminant, FeatureWeightMap featureWeights, SymMatrix combinedKernel, double scoreThreshold, Constants.ScoringMethod scoringMethod, CombiningMethod combiningMethod) throws ApplicationException {
        this.logPreparingOutputs();
        RelatedGenesEngineResponseDto response = new RelatedGenesEngineResponseDto();
        NodeIds nodeIds = this.cache.getNodeIds(this.request.getOrganismId());
        List<Integer> indicesForPositiveNodes = nodeIds.getIndicesForIds(this.request.getPositiveNodes());
        // CONTEXT ranks by the raw discriminant; every other method ranks by the converted score.
        Vector rankingVector = scoringMethod == Constants.ScoringMethod.CONTEXT ? discriminant : score;
        int[] indicesForTopScores = MatrixUtils.getIndicesForTopScores(rankingVector, indicesForPositiveNodes, this.request.getLimitResults(), scoreThreshold);
        logger.debug(String.format("number of nodes available for return: %d", indicesForTopScores.length));
        if (this.request.getProgressReporter().isCanceled()) {
            throw new CancellationException();
        }
        logger.debug("extracting source interactions");
        this.getSourceInteractions(response, indicesForTopScores, score, featureWeights);
        logger.debug("extracting attributes");
        this.encodeAttributes(response, indicesForTopScores, featureWeights);
        response.setCombiningMethodApplied(combiningMethod);
        return response;
    }

    /**
     * Extracts the non-zero edges of {@code network} among the given node indices. Each unordered
     * pair is visited once.
     *
     * @param nodeVOs DTOs keyed by node id; must contain every node referenced by the indices
     * @throws ApplicationException ("mapping error") if a node DTO is missing
     */
    public Collection<InteractionDto> matrixToInteractions(SymMatrix network, int[] indicesForTopScores, HashMap<Long, NodeDto> nodeVOs) throws ApplicationException {
        ArrayList<InteractionDto> interactions = new ArrayList<InteractionDto>();
        NodeIds nodeIds = this.cache.getNodeIds(this.request.getOrganismId());
        for (int i = 0; i < indicesForTopScores.length; ++i) {
            int idx = indicesForTopScores[i];
            long from = nodeIds.getIdForIndex(idx);
            // j < i: only the lower triangle, so each unordered pair is emitted once.
            for (int j = 0; j < i; ++j) {
                int jdx = indicesForTopScores[j];
                long to = nodeIds.getIdForIndex(jdx);
                double weight = network.get(idx, jdx);
                if (weight == 0.0) {
                    continue;
                }
                NodeDto fromNodeVO = nodeVOs.get(from);
                NodeDto toNodeVO = nodeVOs.get(to);
                if (fromNodeVO == null || toNodeVO == null) {
                    throw new ApplicationException("mapping error");
                }
                InteractionDto interaction = new InteractionDto();
                interaction.setNodeVO1(fromNodeVO);
                interaction.setNodeVO2(toNodeVO);
                interaction.setWeight(weight);
                interactions.add(interaction);
            }
        }
        return interactions;
    }

    /**
     * Populates the response's nodes (with their scores) and its networks. Networks are built
     * from every sparse-matrix feature that has a non-null, non-zero weight, each holding its
     * interactions among the selected nodes.
     *
     * @throws CancellationException if the progress reporter reports cancellation
     */
    public void getSourceInteractions(RelatedGenesEngineResponseDto response, int[] indicesForTopScores, Vector scores, FeatureWeightMap featureWeights) throws ApplicationException {
        ArrayList<NetworkDto> sourceNetworks = new ArrayList<NetworkDto>();
        HashMap<Long, NodeDto> nodeVOs = new HashMap<Long, NodeDto>();
        NodeIds nodeIds = this.cache.getNodeIds(this.request.getOrganismId());
        for (int i = 0; i < indicesForTopScores.length; ++i) {
            int index = indicesForTopScores[i];
            NodeDto nodeVO = new NodeDto();
            long nodeId = nodeIds.getIdForIndex(index);
            nodeVO.setId(nodeId);
            nodeVO.setScore(scores.get(index));
            nodeVOs.put(nodeId, nodeVO);
        }
        for (Feature feature : featureWeights.keySet()) {
            if (feature.getType() != Constants.NetworkType.SPARSE_MATRIX) {
                continue;
            }
            if (this.request.getProgressReporter().isCanceled()) {
                throw new CancellationException();
            }
            Double weight = (Double) featureWeights.get(feature);
            long networkId = feature.getId();
            if (weight == null) {
                continue;
            }
            if (weight == 0.0) {
                logger.debug(String.format("network %s has zero weight, excluding from results", networkId));
                continue;
            }
            NetworkDto sourceNetwork = new NetworkDto();
            sourceNetwork.setWeight(weight.doubleValue());
            sourceNetwork.setId(networkId);
            Network network = this.cache.getNetwork(this.safeGetNamespace(), this.request.getOrganismId(), networkId);
            Collection<InteractionDto> sourceInteractions = this.matrixToInteractions(network.getData(), indicesForTopScores, nodeVOs);
            sourceNetwork.setInteractions(sourceInteractions);
            sourceNetworks.add(sourceNetwork);
        }
        response.setNetworks(sourceNetworks);
        ArrayList<NodeDto> nodes = new ArrayList<NodeDto>(nodeVOs.values());
        response.setNodes(nodes);
    }

    /**
     * Logs a one-line summary of the validated query. Interaction networks may legitimately be
     * null when only attribute groups are requested (see {@link #checkQuery()}), so their count
     * is read null-safely.
     */
    private void logQuery() {
        Collection<?> networkGroups = this.request.getInteractionNetworks();
        int numNetworkGroups = networkGroups == null ? 0 : networkGroups.size();
        logger.info(String.format("findRelated query using combining method %s for organism %d contains %d nodes, %d network groups, %d networks, %d attribute groups, and requests a maximum of %d related nodes using a maximum of %d attributes per group", this.request.getCombiningMethod(), this.request.getOrganismId(), this.request.getPositiveNodes().size(), numNetworkGroups, this.numRequestNetworks, this.numRequestAttributeGroups, this.request.getLimitResults(), this.request.getAttributesLimit()));
    }

    /** Logs the start of processing and resets the progress reporter (maximum 5, progress 0). */
    private void logStart() {
        logger.info("processing findRelated() request");
        this.request.getProgressReporter().setMaximumProgress(5);
        this.request.getProgressReporter().setStatus("starting");
        this.request.getProgressReporter().setProgress(0);
    }

    /** Logs and reports (progress 4) that output preparation has begun. */
    private void logPreparingOutputs() {
        logger.info("preparing outputs for findRelated() request");
        this.request.getProgressReporter().setDescription("preparing outputs");
        this.request.getProgressReporter().setProgress(4);
    }

    /** Logs the request duration and marks the progress reporter as done (progress 5). */
    private void logEnd() {
        logger.info("completed processing request, duration = " + Logging.duration(this.requestStartTimeMillis, this.requestEndTimeMillis));
        this.request.getProgressReporter().setStatus("done");
        this.request.getProgressReporter().setProgress(5);
    }

    /** At debug level, logs the node id and discriminant value for each selected index. */
    private void logNodeScores(int[] indicesForTopScores, Vector discriminant) throws ApplicationException {
        if (!logger.isDebugEnabled()) {
            return;
        }
        NodeIds nodeIds = this.cache.getNodeIds(this.request.getOrganismId());
        for (int index : indicesForTopScores) {
            long nodeId = nodeIds.getIdForIndex(index);
            double nodeScore = discriminant.get(index);
            logger.debug(String.format("Node %d as a score of %f", nodeId, nodeScore));
        }
    }

    /**
     * Validates the request and records derived per-request state. That state covers the
     * user-network/attribute flags and the network and attribute-group counts. A null
     * attribute-group collection is replaced by an empty list.
     *
     * @throws ApplicationException if there are no query nodes, if neither networks nor attribute
     *         groups are given, or if any node, network or attribute group is invalid
     */
    public void checkQuery() throws ApplicationException {
        if (this.request.getPositiveNodes() == null || this.request.getPositiveNodes().size() == 0) {
            throw new ApplicationException("No query nodes given");
        }
        boolean hasNetworks = this.request.getInteractionNetworks() != null && this.request.getInteractionNetworks().size() != 0;
        boolean hasAttributeGroups = this.request.getAttributeGroups() != null && this.request.getAttributeGroups().size() != 0;
        if (!hasNetworks && !hasAttributeGroups) {
            throw new ApplicationException("No query networks or attributes given");
        }
        this.checkNodes(this.request.getOrganismId(), this.request.getPositiveNodes());
        // These flags must be set before safeGetNamespace() is consulted below.
        this.hasUserNetworks = DataSupport.queryHasUserNetworks(this.request.getInteractionNetworks());
        this.hasUserAttributes = DataSupport.queryHasUserAttributes(this.request.getAttributeGroups());
        this.numRequestNetworks = this.checkNetworks(this.safeGetNamespace(), this.request.getOrganismId(), this.request.getInteractionNetworks());
        if (this.request.getAttributeGroups() == null) {
            this.request.setAttributeGroups(new ArrayList<Long>());
        }
        this.numRequestAttributeGroups = this.checkAttributeGroups(this.safeGetNamespace(), this.request.getOrganismId(), this.request.getAttributeGroups());
    }

    /**
     * Verifies that {@code nodes} is non-empty, contains no duplicate ids, and that every id is
     * known for the organism.
     *
     * @throws ApplicationException on an empty list, a duplicate id or an unknown id
     */
    protected void checkNodes(long organismId, Collection<Long> nodes) throws ApplicationException {
        if (nodes.size() == 0) {
            throw new ApplicationException("the list of nodes in the request is empty");
        }
        HashSet<Long> uniqueNodeIds = new HashSet<Long>();
        NodeIds nodeIds = this.cache.getNodeIds(organismId);
        for (Long nodeId : nodes) {
            // add() returns false when the id was already seen; previously the set was never
            // populated, so duplicates were silently accepted.
            if (!uniqueNodeIds.add(nodeId)) {
                throw new ApplicationException(String.format("the node id %d was passed multiple times in request", nodeId));
            }
            long n = nodeId;
            try {
                nodeIds.getIndexForId(n);
            }
            catch (ApplicationException e) {
                throw new ApplicationException(String.format("node id %d is not valid for organism id %d", nodeId, organismId));
            }
        }
    }

    /**
     * Validates every network id across all groups. Each id must be unique, fit in int range, and
     * be known in {@code namespace} for the organism. Negative (user) ids also require a non-null
     * namespace.
     *
     * @param networks network id groups; {@code null} is treated as empty, since
     *        {@link #checkQuery()} allows attribute-only queries
     * @return the number of distinct network ids
     * @throws ApplicationException on a duplicate, out-of-range or unknown id
     */
    protected int checkNetworks(String namespace, long organismId, Collection<Collection<Long>> networks) throws ApplicationException {
        if (networks == null) {
            return 0;
        }
        HashSet<Long> uniqueNetworkIds = new HashSet<Long>();
        NetworkIds networkIds = this.cache.getNetworkIds(namespace, organismId);
        for (Collection<Long> grouping : networks) {
            for (Long networkId : grouping) {
                if (!uniqueNetworkIds.add(networkId)) {
                    throw new ApplicationException(String.format("the network id %d was passed multiple times in request", networkId));
                }
                long n = networkId;
                if (n > Integer.MAX_VALUE || n < Integer.MIN_VALUE) {
                    throw new ApplicationException(String.format("network ids must be in integer range, got id: %d", networkId));
                }
                if (n < 0L) {
                    if (namespace == null) {
                        throw new ApplicationException(String.format("no namespace provided for user network %d organism %d", networkId, organismId));
                    }
                    // NOTE(review): despite this message, the id is still validated below against
                    // the namespace's network ids — confirm which behavior is intended.
                    logger.warn("skipping validation check on user network: " + n);
                }
                try {
                    networkIds.getIndexForId(n);
                }
                catch (ApplicationException e) {
                    throw new ApplicationException(String.format("network id %d is not valid for organism id %d", networkId, organismId));
                }
            }
        }
        return uniqueNetworkIds.size();
    }

    /**
     * Verifies the attribute group ids are unique and exist for the organism in {@code namespace}.
     *
     * @return the number of attribute groups (0 for an empty collection)
     * @throws ApplicationException on duplicates or an unknown group
     */
    protected int checkAttributeGroups(String namespace, long organismId, Collection<Long> attributeGroups) throws ApplicationException {
        if (attributeGroups.size() == 0) {
            return 0;
        }
        HashSet<Long> uniqueIds = new HashSet<Long>(attributeGroups);
        if (attributeGroups.size() != uniqueIds.size()) {
            throw new ApplicationException("the list of attribute groups contains duplicates");
        }
        AttributeGroups ids = this.cache.getAttributeGroups(namespace, organismId);
        HashMap<Long, ArrayList<Long>> groupMap = ids.getAttributeGroups();
        for (long groupId : attributeGroups) {
            if (!groupMap.containsKey(groupId)) {
                throw new ApplicationException(String.format("organism %d in namspace '%s' does not contain the attribute group %d", organismId, namespace, groupId));
            }
        }
        return attributeGroups.size();
    }

    /**
     * Returns the namespace to read data from. {@code "CORE"} is used when the request has no
     * namespace, or when the query references neither user networks nor user attributes. Those
     * two flags are only set once {@link #checkQuery()} has run.
     */
    private String safeGetNamespace() {
        String namespace = this.request.getNamespace();
        boolean noNamespace = namespace == null || namespace.equals("");
        boolean noUserData = !this.hasUserNetworks && !this.hasUserAttributes;
        return noNamespace || noUserData ? "CORE" : namespace;
    }

    /**
     * Debug helper: writes a tab-separated table of node index, discriminant, label and degree
     * to {@code fileName}. Failures are logged as warnings, not propagated. The writer is closed
     * even if a write fails.
     */
    private void dumpNumbers(String fileName, Vector discriminant, Vector labels, Vector degrees) {
        try {
            logger.info("dumping to " + fileName);
            File file = new File(fileName);
            FileWriter writer = new FileWriter(file);
            try {
                int n = discriminant.size();
                String header = "node\tdiscriminant\tlabels\tdegrees\n";
                writer.write(header);
                for (int i = 0; i < n; ++i) {
                    String line = String.format("%d\t%.15e\t%.15e\t%.15e\n", i, discriminant.get(i), labels.get(i), degrees.get(i));
                    writer.write(line);
                }
            }
            finally {
                writer.close();
            }
        }
        catch (Exception e) {
            logger.warn("failed to dump data", e);
        }
    }

    /** At debug level, logs each interaction's node ids and weight for the given network. */
    public static void logInteractions(long networkId, Collection<InteractionDto> interactions) {
        logger.debug("interactions for network " + networkId);
        for (InteractionDto i : interactions) {
            logger.debug(String.format("   %d %d %f", i.getNodeVO1().getId(), i.getNodeVO2().getId(), i.getWeight()));
        }
    }
}

