/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.util.infotheory.impl;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tribuo.util.infotheory.WeightedInformationTheory;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.WeightCountTuple;

/**
 * A joint distribution over two variables where every observation carries a weight.
 * <p>
 * Holds a {@link WeightCountTuple} (summed/normalised weight plus observation count) for each
 * observed {@code (first, second)} pair, the marginal tuples for each variable on its own, and the
 * total number of observations.
 *
 * @param <T1> The type of the first variable.
 * @param <T2> The type of the second variable.
 */
public class WeightedPairDistribution<T1, T2> {
    /** Initial capacity of the maps allocated by the static factory methods. */
    private static final int DEFAULT_MAP_SIZE = 20;

    /** The total number of observations, as supplied to a constructor or computed by a factory. */
    public final long count;

    private final Map<CachedPair<T1, T2>, WeightCountTuple> jointCounts;
    private final Map<T1, WeightCountTuple> firstCount;
    private final Map<T2, WeightCountTuple> secondCount;

    /**
     * Constructs a distribution from insertion-ordered copies of the supplied maps.
     * <p>
     * Changes made to the argument maps afterwards (adding/removing keys) are not reflected here.
     * The {@link WeightCountTuple} values themselves are shared, not cloned.
     *
     * @param count The total number of observations.
     * @param jointCounts The per-pair tuples.
     * @param firstCount The marginal tuples for the first variable.
     * @param secondCount The marginal tuples for the second variable.
     */
    public WeightedPairDistribution(long count, Map<CachedPair<T1, T2>, WeightCountTuple> jointCounts, Map<T1, WeightCountTuple> firstCount, Map<T2, WeightCountTuple> secondCount) {
        this.count = count;
        this.jointCounts = new LinkedHashMap<>(jointCounts);
        this.firstCount = new LinkedHashMap<>(firstCount);
        this.secondCount = new LinkedHashMap<>(secondCount);
    }

    /**
     * Constructs a distribution that stores the supplied maps directly, WITHOUT copying them.
     * <p>
     * The compiler selects this overload whenever all three map arguments are statically typed as
     * {@link LinkedHashMap}. Any later mutation of those maps is visible through this object, so
     * callers should hand over ownership of them.
     *
     * @param count The total number of observations.
     * @param jointCounts The per-pair tuples (stored as-is).
     * @param firstCount The marginal tuples for the first variable (stored as-is).
     * @param secondCount The marginal tuples for the second variable (stored as-is).
     */
    public WeightedPairDistribution(long count, LinkedHashMap<CachedPair<T1, T2>, WeightCountTuple> jointCounts, LinkedHashMap<T1, WeightCountTuple> firstCount, LinkedHashMap<T2, WeightCountTuple> secondCount) {
        this.count = count;
        this.jointCounts = jointCounts;
        this.firstCount = firstCount;
        this.secondCount = secondCount;
    }

    /**
     * Returns the per-pair tuples.
     *
     * @return The internal joint map itself (not a copy); mutating it mutates this distribution.
     */
    public Map<CachedPair<T1, T2>, WeightCountTuple> getJointCounts() {
        return this.jointCounts;
    }

    /**
     * Returns the marginal tuples for the first variable.
     *
     * @return The internal map itself (not a copy); mutating it mutates this distribution.
     */
    public Map<T1, WeightCountTuple> getFirstCount() {
        return this.firstCount;
    }

    /**
     * Returns the marginal tuples for the second variable.
     *
     * @return The internal map itself (not a copy); mutating it mutates this distribution.
     */
    public Map<T2, WeightCountTuple> getSecondCount() {
        return this.secondCount;
    }

    /**
     * Builds a distribution from three parallel lists, where element {@code i} of each list
     * together forms one weighted observation.
     * <p>
     * Weights and counts are summed per pair and per marginal value. Each map is then passed to
     * {@code WeightedInformationTheory.normaliseWeights}; presumably this turns the summed weight
     * into a per-observation weight (TODO confirm against that helper).
     *
     * @param first The values of the first variable.
     * @param second The values of the second variable.
     * @param weights The weight of each observation; elements are unboxed, so a {@code null}
     *                element throws {@link NullPointerException}.
     * @param <T1> The type of the first variable.
     * @param <T2> The type of the second variable.
     * @return The weighted pair distribution.
     * @throws IllegalArgumentException If the three lists do not all have the same size.
     */
    public static <T1, T2> WeightedPairDistribution<T1, T2> constructFromLists(List<T1> first, List<T2> second, List<Double> weights) {
        if (first.size() != second.size() || first.size() != weights.size()) {
            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
        }
        LinkedHashMap<CachedPair<T1, T2>, WeightCountTuple> countDist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T1, WeightCountTuple> aCountDist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T2, WeightCountTuple> bCountDist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        long count = 0L;
        for (int i = 0; i < first.size(); ++i) {
            T1 a = first.get(i);
            T2 b = second.get(i);
            double weight = weights.get(i);
            accumulate(countDist, new CachedPair<>(a, b), weight, 1L);
            accumulate(aCountDist, a, weight, 1L);
            accumulate(bCountDist, b, weight, 1L);
            ++count;
        }
        WeightedInformationTheory.normaliseWeights(countDist);
        WeightedInformationTheory.normaliseWeights(aCountDist);
        WeightedInformationTheory.normaliseWeights(bCountDist);
        // All three maps are statically LinkedHashMaps, so the non-copying constructor is used.
        return new WeightedPairDistribution<>(count, countDist, aCountDist, bCountDist);
    }

    /**
     * Builds a distribution from an existing joint map by deriving the two marginals from it.
     * <p>
     * The joint map is copied but its tuples are NOT re-normalised. Each joint {@code tuple.weight}
     * is treated as a per-observation weight: it is multiplied by {@code tuple.count} to recover
     * the pair's total weight before that total is added into the marginals. This presumably
     * expects input shaped like the normalised joint map from {@link #constructFromLists} (TODO
     * confirm). The marginals are then passed to {@code WeightedInformationTheory.normaliseWeights}.
     *
     * @param jointCount The per-pair tuples.
     * @param <T1> The type of the first variable.
     * @param <T2> The type of the second variable.
     * @return The weighted pair distribution.
     */
    public static <T1, T2> WeightedPairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, WeightCountTuple> jointCount) {
        LinkedHashMap<CachedPair<T1, T2>, WeightCountTuple> countDist = new LinkedHashMap<>(jointCount);
        LinkedHashMap<T1, WeightCountTuple> aCountDist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T2, WeightCountTuple> bCountDist = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        long count = 0L;
        for (Map.Entry<CachedPair<T1, T2>, WeightCountTuple> e : countDist.entrySet()) {
            CachedPair<T1, T2> pair = e.getKey();
            WeightCountTuple tuple = e.getValue();
            // Undo the per-observation normalisation of the joint tuple to get its total weight.
            double weight = tuple.weight * (double) tuple.count;
            accumulate(aCountDist, pair.getA(), weight, tuple.count);
            accumulate(bCountDist, pair.getB(), weight, tuple.count);
            count += tuple.count;
        }
        WeightedInformationTheory.normaliseWeights(aCountDist);
        WeightedInformationTheory.normaliseWeights(bCountDist);
        return new WeightedPairDistribution<>(count, countDist, aCountDist, bCountDist);
    }

    /**
     * Adds {@code weight} and {@code observations} to the tuple for {@code key}, creating an empty
     * tuple first if the key is absent.
     */
    private static <K> void accumulate(Map<K, WeightCountTuple> dist, K key, double weight, long observations) {
        WeightCountTuple tuple = dist.computeIfAbsent(key, k -> new WeightCountTuple());
        tuple.weight += weight;
        tuple.count += observations;
    }
}

