/*
 * Decompiled with CFR 0.152.
 */
package org.ujmp.core.doublematrix.calculation.general.statistical;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.ujmp.core.Matrix;
import org.ujmp.core.MatrixFactory;
import org.ujmp.core.doublematrix.DoubleMatrix2D;
import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
import org.ujmp.core.enums.ValueType;
import org.ujmp.core.exceptions.MatrixException;
import org.ujmp.core.intmatrix.IntMatrix2D;
import org.ujmp.core.intmatrix.impl.DefaultDenseIntMatrix2D;
import org.ujmp.core.util.MathUtil;

/**
 * Mutual information between every pair of columns of a matrix.
 * <p>
 * Each row of the source matrix is one joint observation. Entry {@code (i, j)}
 * is the mutual information between columns {@code i} and {@code j}, estimated
 * from the relative frequencies of the values those columns take. The
 * calculation is therefore square, with one row and one column per source
 * column.
 */
public class MutualInformation extends AbstractDoubleCalculation {
    private static final long serialVersionUID = -4891250637894943873L;

    /**
     * Creates a calculation whose entries are computed on demand from
     * {@code matrix} by {@link #getDouble(long...)}.
     *
     * @param matrix the source data; its columns are the variables compared
     */
    public MutualInformation(Matrix matrix) {
        super(matrix);
    }

    /**
     * Returns the mutual information between two source columns.
     *
     * @param coordinates {@code [column1, column2]} of the source matrix
     * @return the value computed by {@link #calculate(long, long, Matrix)}
     * @throws MatrixException as declared by the calculation API
     */
    public double getDouble(long ... coordinates) throws MatrixException {
        return MutualInformation.calculate(coordinates[0], coordinates[1], this.getSource());
    }

    /**
     * Returns the shape of this calculation: source columns x source columns.
     *
     * @return {@code {columns, columns}} of the source matrix
     */
    public long[] getSize() {
        long columns = this.getSource().getColumnCount();
        return new long[]{columns, columns};
    }

    /**
     * Computes the mutual information between two columns of {@code matrix}.
     * <p>
     * Values are grouped by exact {@code double} equality (as defined by
     * {@link Double#equals}, so {@code 0.0} and {@code -0.0} are distinct
     * values). Probabilities are relative frequencies over all rows, and each
     * term uses {@link MathUtil#log2}. Slightly negative totals caused by
     * rounding are clamped to {@code 0}, matching {@link #calcNew(IntMatrix2D)}.
     *
     * @param var1 index of the first column
     * @param var2 index of the second column
     * @param matrix the data, one observation per row
     * @return the mutual information of the two columns; {@code 0} when the
     *         matrix has no rows
     */
    public static final double calculate(long var1, long var2, Matrix matrix) {
        long rows = matrix.getRowCount();
        double count = rows;
        HashMap<Double, Double> counts1 = new HashMap<Double, Double>();
        HashMap<Double, Double> counts2 = new HashMap<Double, Double>();
        // Joint counts nested by first-column value, so only observed pairs are stored and visited.
        HashMap<Double, HashMap<Double, Double>> joint = new HashMap<Double, HashMap<Double, Double>>();
        for (long r = 0; r < rows; r++) {
            double value1 = matrix.getAsDouble(r, var1);
            double value2 = matrix.getAsDouble(r, var2);
            increment(counts1, value1);
            increment(counts2, value2);
            HashMap<Double, Double> row = joint.get(value1);
            if (row == null) {
                row = new HashMap<Double, Double>();
                joint.put(value1, row);
            }
            increment(row, value2);
        }
        double mutualInformation = 0.0;
        for (Map.Entry<Double, HashMap<Double, Double>> outer : joint.entrySet()) {
            double p1 = counts1.get(outer.getKey()) / count;
            for (Map.Entry<Double, Double> inner : outer.getValue().entrySet()) {
                double p2 = counts2.get(inner.getKey()) / count;
                double p12 = inner.getValue() / count;
                mutualInformation += p12 * MathUtil.log2(p12 / (p1 * p2));
            }
        }
        return mutualInformation < 0.0 ? 0.0 : mutualInformation;
    }

    /** Adds one to the count stored for {@code key}, starting from zero. */
    private static void increment(HashMap<Double, Double> counts, double key) {
        Double current = counts.get(key);
        counts.put(key, current == null ? 1.0 : current + 1.0);
    }

    /**
     * Computes the full pairwise mutual-information matrix of {@code matrix}
     * after converting it with {@code convert(ValueType.INT)}.
     *
     * @param matrix the data, one observation per row
     * @return a symmetric columns x columns matrix of mutual information in bits
     */
    public static DoubleMatrix2D calcNew(Matrix matrix) {
        // The cast selects the IntMatrix2D overload; with a plain Matrix argument
        // this call would resolve to calcNew(Matrix) itself and recurse without end.
        return MutualInformation.calcNew((IntMatrix2D) matrix.convert(ValueType.INT));
    }

    /**
     * Computes the full pairwise mutual-information matrix of an integer matrix.
     * <p>
     * Each column's values are binned by value after shifting by that column's
     * minimum, so negative values are supported. Logarithms are natural logs
     * divided by {@code ln 2}, so results are in bits. Slightly negative results
     * caused by rounding are clamped to {@code 0}.
     *
     * @param matrix the data, one observation per row
     * @return a symmetric columns x columns matrix; all zeros when the matrix has no rows
     * @throws IllegalArgumentException if a column's value range does not fit in an int-sized array
     */
    public static DoubleMatrix2D calcNew(IntMatrix2D matrix) {
        long count = matrix.getColumnCount();
        int columns = (int) count;
        int samples = (int) matrix.getRowCount();
        DoubleMatrix2D result = (DoubleMatrix2D) MatrixFactory.zeros(ValueType.DOUBLE, count, count);
        if (samples == 0) {
            // No observations: every frequency would be 0/0, so leave the result at zero.
            return result;
        }
        // Per-column minimum and number of bins (max - min + 1), used to index the count arrays.
        int[] offset = new int[columns];
        int[] bins = new int[columns];
        for (int c = 0; c < columns; c++) {
            int min = Integer.MAX_VALUE;
            int max = Integer.MIN_VALUE;
            for (int k = 0; k < samples; k++) {
                int value = matrix.getInt(k, c);
                if (value < min) {
                    min = value;
                }
                if (value > max) {
                    max = value;
                }
            }
            long span = (long) max - (long) min + 1L;
            if (span > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("value range of column " + c
                        + " is too large: [" + min + ", " + max + "]");
            }
            offset[c] = min;
            bins[c] = (int) span;
        }
        // Mutual information is symmetric, so compute the lower triangle and mirror it.
        for (int a = 0; a < columns; a++) {
            for (int b = 0; b <= a; b++) {
                double mutual = pairMutualInformation(matrix, a, b, samples, offset, bins);
                result.setDouble(mutual, a, b);
                result.setDouble(mutual, b, a);
            }
        }
        return result;
    }

    /**
     * Mutual information in bits between columns {@code a} and {@code b}, using
     * value bins shifted by {@code offset} and sized by {@code bins}.
     */
    private static double pairMutualInformation(IntMatrix2D matrix, int a, int b, int samples,
            int[] offset, int[] bins) {
        double[][] joint = new double[bins[a]][bins[b]];
        double[] pa = new double[bins[a]];
        double[] pb = new double[bins[b]];
        for (int k = samples - 1; k >= 0; k--) {
            int x = matrix.getInt(k, a) - offset[a];
            int y = matrix.getInt(k, b) - offset[b];
            pa[x] += 1.0;
            pb[y] += 1.0;
            joint[x][y] += 1.0;
        }
        double[] logPb = new double[bins[b]];
        for (int j = bins[b] - 1; j >= 0; j--) {
            pb[j] /= samples;
            if (pb[j] != 0.0) {
                logPb[j] = Math.log(pb[j]);
            }
        }
        double log2 = Math.log(2.0);
        double mutual = 0.0;
        for (int i = bins[a] - 1; i >= 0; i--) {
            pa[i] /= samples;
            if (pa[i] == 0.0) {
                // A value never seen in column a cannot appear in any joint cell of this row.
                continue;
            }
            double logPa = Math.log(pa[i]);
            for (int j = bins[b] - 1; j >= 0; j--) {
                double pab = joint[i][j] / samples;
                // pab != 0 implies pa[i] != 0 and pb[j] != 0, so both logs are defined.
                if (pab != 0.0) {
                    mutual += pab * (Math.log(pab) - logPa - logPb[j]) / log2;
                }
            }
        }
        return mutual < 0.0 ? 0.0 : mutual;
    }
}

