package eu.bandm.alea.data;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;

/* loaded from: input_file:eu/bandm/alea/data/Distribution.class */
/**
 * A finite probability distribution over outcomes of type {@code V}, with exact
 * {@link Rational} probabilities.
 *
 * <p>Each outcome in the support is mapped to a probability in {@code [0, 1]} and the
 * probabilities sum to exactly one; this invariant is checked by an assertion when
 * assertions are enabled. Outcomes may carry probability zero (for example
 * {@code bernoulli(Rational.ZERO)}). The probability map is never modified after
 * construction and is only exposed through unmodifiable views.
 *
 * @param <V> the type of the outcomes
 */
public class Distribution<V> {
    /** Probability of every outcome in the support. */
    private final Map<V, Rational> probs;

    private Distribution(Map<V, Rational> map) {
        this.probs = Objects.requireNonNull(map);
        assert isValid() : map;
    }

    /** Returns whether every probability lies in [0, 1] and they sum to exactly one. */
    private boolean isValid() {
        Rational total = Rational.ZERO;
        for (Rational p : probs.values()) {
            if (!isProbability(p)) {
                return false;
            }
            total = total.add(p);
        }
        return total.equals(Rational.ONE);
    }

    /** Returns whether {@code p} lies in the closed interval [0, 1]. */
    private static boolean isProbability(Rational p) {
        return p.compareTo(Rational.ZERO) >= 0 && p.compareTo(Rational.ONE) <= 0;
    }

    /**
     * Checks that {@code p} is a valid probability.
     *
     * @throws NullPointerException if {@code p} is null
     * @throws IllegalArgumentException if {@code p} is outside [0, 1]
     */
    private static Rational requireProbability(Rational p) {
        Objects.requireNonNull(p, "probability");
        if (!isProbability(p)) {
            throw new IllegalArgumentException("not a probability: " + p);
        }
        return p;
    }

    @Override
    public String toString() {
        return probs.toString();
    }

    /** Returns an unmodifiable view of the outcomes (including zero-probability ones). */
    public Set<V> support() {
        return Collections.unmodifiableSet(probs.keySet());
    }

    /** Returns the probability of {@code v}, or zero if it is not in the support. */
    public Rational get(V v) {
        return probs.getOrDefault(v, Rational.ZERO);
    }

    /** Returns an unmodifiable view of the outcome-to-probability map. */
    public Map<V, Rational> getProbabilities() {
        return Collections.unmodifiableMap(probs);
    }

    /** Returns the distribution that yields {@code v} with probability one. */
    public static <V> Distribution<V> delta(V v) {
        return new Distribution<>(Collections.singletonMap(v, Rational.ONE));
    }

    /**
     * Returns the uniform distribution over the given values; see
     * {@link #uniform(Collection)}.
     */
    @SafeVarargs
    public static <V> Distribution<V> uniform(V... vArr) {
        return uniform(Arrays.asList(vArr));
    }

    /**
     * Returns the distribution that picks each element of {@code collection} with
     * equal probability. Duplicate elements accumulate their shares, so
     * {@code uniform(a, a, b)} yields {@code a} with 2/3 and {@code b} with 1/3.
     *
     * @throws IllegalArgumentException if {@code collection} is empty
     */
    public static <V> Distribution<V> uniform(Collection<? extends V> collection) {
        int size = collection.size();
        if (size == 0) {
            throw new IllegalArgumentException(collection.toString());
        }
        if (size == 1) {
            return new Distribution<>(Collections.singletonMap(collection.iterator().next(), Rational.ONE));
        }
        Rational share = Rational.of(1L, size);
        Map<V, Rational> result = new HashMap<>();
        for (V value : collection) {
            // merge, not put: an overwritten duplicate would make the total fall short of one
            result.merge(value, share, Rational::add);
        }
        return new Distribution<>(result);
    }

    /** Returns {@code true} with probability {@code rational}, else {@code false}. */
    public static Distribution<Boolean> bernoulli(Rational rational) {
        return bernoulli(rational, true, false);
    }

    /** Returns a fair coin over {@code true} and {@code false}. */
    public static Distribution<Boolean> bernoulli() {
        return bernoulli(Rational.of(1L, 2L));
    }

    /**
     * Returns {@code t} with probability {@code rational} and {@code t2} with the
     * complementary probability. If {@code t} equals {@code t2}, it gets probability one.
     *
     * @throws IllegalArgumentException if {@code rational} is outside [0, 1]
     */
    public static <T> Distribution<T> bernoulli(Rational rational, T t, T t2) {
        requireProbability(rational);
        Map<T, Rational> result = new HashMap<>();
        result.put(t, rational);
        result.merge(t2, Rational.ONE.subtract(rational), Rational::add);
        return new Distribution<>(result);
    }

    /**
     * Returns the mixture that yields {@code v} with probability {@code rational} and
     * otherwise samples from this distribution.
     *
     * @throws IllegalArgumentException if {@code rational} is outside [0, 1]
     */
    public Distribution<V> except(Rational rational, V v) {
        requireProbability(rational);
        Rational rest = Rational.ONE.subtract(rational);
        Map<V, Rational> result = new HashMap<>();
        for (Map.Entry<V, Rational> entry : probs.entrySet()) {
            result.put(entry.getKey(), entry.getValue().multiply(rest));
        }
        result.merge(v, rational, Rational::add);
        return new Distribution<>(result);
    }

    /** Returns the image of this distribution under {@code function}; colliding images add up. */
    public <W> Distribution<W> map(Function<? super V, ? extends W> function) {
        if (probs.size() == 1) {
            Map.Entry<V, Rational> only = probs.entrySet().iterator().next();
            return new Distribution<>(Collections.singletonMap(function.apply(only.getKey()), only.getValue()));
        }
        Map<W, Rational> result = new HashMap<>();
        for (Map.Entry<V, Rational> entry : probs.entrySet()) {
            result.merge(function.apply(entry.getKey()), entry.getValue(), Rational::add);
        }
        return new Distribution<>(result);
    }

    /**
     * Combines this distribution with an independent {@code distribution} by applying
     * {@code biFunction} to every pair of outcomes, weighting by the product of probabilities.
     */
    public <W, X> Distribution<X> zipWith(BiFunction<? super V, ? super W, ? extends X> biFunction, Distribution<W> distribution) {
        Map<X, Rational> result = new HashMap<>();
        for (Map.Entry<V, Rational> left : probs.entrySet()) {
            for (Map.Entry<W, Rational> right : distribution.probs.entrySet()) {
                result.merge(biFunction.apply(left.getKey(), right.getKey()),
                        left.getValue().multiply(right.getValue()), Rational::add);
            }
        }
        return new Distribution<>(result);
    }

    /**
     * Monadic bind: samples an outcome from this distribution, then samples from the
     * distribution {@code function} returns for it.
     */
    public <W> Distribution<W> flatMap(Function<? super V, Distribution<W>> function) {
        if (probs.size() == 1) {
            Map.Entry<V, Rational> only = probs.entrySet().iterator().next();
            assert only.getValue().equals(Rational.ONE);
            return function.apply(only.getKey());
        }
        Map<W, Rational> result = new HashMap<>();
        for (Map.Entry<V, Rational> outer : probs.entrySet()) {
            Rational weight = outer.getValue();
            for (Map.Entry<W, Rational> inner : function.apply(outer.getKey()).probs.entrySet()) {
                result.merge(inner.getKey(), weight.multiply(inner.getValue()), Rational::add);
            }
        }
        return new Distribution<>(result);
    }

    /** Combines two independent copies of this distribution with {@code binaryOperator}. */
    public Distribution<V> squareWith(BinaryOperator<V> binaryOperator) {
        return this.<V, V>zipWith(binaryOperator, this);
    }

    /**
     * Combines {@code i} independent copies of this distribution with
     * {@code binaryOperator} (assumed associative), using repeated squaring.
     *
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public Distribution<V> powerWith(BinaryOperator<V> binaryOperator, int i) {
        if (i <= 0) {
            throw new IllegalArgumentException(i + " <= 0");
        }
        if (i == 1) {
            return this;
        }
        Distribution<V> square = powerWith(binaryOperator, i / 2).squareWith(binaryOperator);
        return i % 2 == 1 ? this.<V, V>zipWith(binaryOperator, square) : square;
    }

    /** Returns the joint distribution of independent draws, one per key of {@code map}. */
    public static <K, V> Distribution<Map<K, V>> product(Map<K, Distribution<V>> map) {
        Distribution<Map<K, V>> result = delta(Collections.emptyMap());
        for (Map.Entry<K, Distribution<V>> entry : map.entrySet()) {
            K key = entry.getKey();
            Distribution<V> factor = entry.getValue();
            result = result.flatMap(partial -> factor.map(value -> {
                Map<K, V> extended = new HashMap<>(partial);
                extended.put(key, value);
                return extended;
            }));
        }
        return result;
    }

    /** Returns the joint distribution of independent draws, in list order. */
    public static <V> Distribution<List<V>> product(List<Distribution<V>> list) {
        Distribution<List<V>> result = delta(Collections.emptyList());
        for (Distribution<V> factor : list) {
            result = result.flatMap(prefix -> factor.map(value -> {
                List<V> extended = new ArrayList<>(prefix);
                extended.add(value);
                return extended;
            }));
        }
        return result;
    }

    /** Returns the distribution of the set of outcomes of independent draws from each element. */
    public static <V> Distribution<Set<V>> product(Set<Distribution<V>> set) {
        Distribution<Set<V>> result = delta(Collections.emptySet());
        for (Distribution<V> factor : set) {
            result = result.flatMap(partial -> factor.map(value -> {
                Set<V> extended = new HashSet<>(partial);
                extended.add(value);
                return extended;
            }));
        }
        return result;
    }

    /** Returns the expected value of a distribution over rationals. */
    public static Rational mean(Distribution<Rational> distribution) {
        Rational sum = Rational.ZERO;
        for (Map.Entry<Rational, Rational> entry : distribution.probs.entrySet()) {
            sum = sum.add(entry.getKey().multiply(entry.getValue()));
        }
        return sum;
    }

    /**
     * Draws one outcome using a single {@code random.nextDouble()}. Outcomes with
     * probability zero are never returned, even when double rounding makes the
     * cumulative sum fall short of the drawn value.
     */
    public V sample(Random random) {
        double threshold = random.nextDouble();
        double cumulative = 0.0d;
        V chosen = null;
        for (Map.Entry<V, Rational> entry : probs.entrySet()) {
            Rational p = entry.getValue();
            if (p.compareTo(Rational.ZERO) == 0) {
                continue;
            }
            cumulative += p.doubleValue();
            chosen = entry.getKey();
            if (cumulative > threshold) {
                break;
            }
        }
        return chosen;
    }

    /** Returns a reusable sampler for this distribution, built as a tree of binary choices. */
    public RandomVariable<V> sample() {
        return sample(positivePart(probs));
    }

    /**
     * Returns {@code map} itself if all its probabilities are positive, otherwise a copy
     * without the zero-probability entries. A zero-mass half would make {@link #norm}
     * divide by zero.
     */
    private static <V> Map<V, Rational> positivePart(Map<V, Rational> map) {
        Map<V, Rational> positive = new HashMap<>();
        for (Map.Entry<V, Rational> entry : map.entrySet()) {
            if (entry.getValue().compareTo(Rational.ZERO) > 0) {
                positive.put(entry.getKey(), entry.getValue());
            }
        }
        return positive.size() == map.size() ? map : positive;
    }

    /**
     * Builds a sampler for {@code map}, whose probabilities must all be positive and sum
     * to one. The entries are split into a non-empty prefix and a non-empty remainder;
     * each half is normalized and sampled recursively, and a biased coin chooses between them.
     */
    private static <V> RandomVariable<V> sample(Map<V, Rational> map) {
        if (map.size() == 1) {
            final V only = map.keySet().iterator().next();
            assert map.get(only).equals(Rational.ONE);
            return new RandomVariable<V>() {
                @Override
                public V sample(Random random) {
                    return only;
                }

                @Override
                public String toString() {
                    return String.valueOf(only);
                }
            };
        }
        assert map.size() > 1;
        Map<V, Rational> left = new HashMap<>();
        Map<V, Rational> right = new HashMap<>();
        Rational leftMass = Rational.ZERO;
        int remaining = map.size();
        for (Map.Entry<V, Rational> entry : map.entrySet()) {
            // Fill the left half until it carries at least half the mass, but always leave
            // the last entry for the right half: an empty right half would make the
            // recursive call below receive the whole map again and never terminate.
            if (leftMass.compareTo(Rational.HALF) < 0 && remaining > 1) {
                left.put(entry.getKey(), entry.getValue());
                leftMass = leftMass.add(entry.getValue());
            } else {
                right.put(entry.getKey(), entry.getValue());
            }
            remaining--;
        }
        norm(left, leftMass);
        norm(right, Rational.ONE.subtract(leftMass));
        final double threshold = leftMass.doubleValue();
        final RandomVariable<V> first = sample(left);
        final RandomVariable<V> second = sample(right);
        return new RandomVariable<V>() {
            @Override
            public V sample(Random random) {
                return (random.nextDouble() < threshold ? first : second).sample(random);
            }

            @Override
            public String toString() {
                return threshold + " ? (" + first + ") : (" + second + ")";
            }
        };
    }

    /** Divides every probability in {@code map} by {@code rational} in place. */
    private static void norm(Map<?, Rational> map, Rational rational) {
        for (Map.Entry<?, Rational> entry : map.entrySet()) {
            entry.setValue(entry.getValue().divide(rational));
        }
    }

    /**
     * Returns the binomial distribution of the number of successes in {@code i}
     * independent trials with success probability {@code rational}.
     *
     * @throws IllegalArgumentException if {@code i < 0} or {@code rational} is outside [0, 1]
     */
    public static Distribution<Integer> binomial(int i, Rational rational) {
        if (i < 0) {
            throw new IllegalArgumentException(i + " < 0");
        }
        requireProbability(rational);
        Rational failure = Rational.ONE.subtract(rational);
        // Precompute (1-p)^j by multiplication; dividing down from (1-p)^i (as before)
        // fails with a division by zero when p == 1.
        Rational[] failurePowers = new Rational[i + 1];
        failurePowers[0] = Rational.ONE;
        for (int j = 1; j <= i; j++) {
            failurePowers[j] = failurePowers[j - 1].multiply(failure);
        }
        Map<Integer, Rational> result = new HashMap<>();
        BigInteger coefficient = BigInteger.ONE; // C(i, k)
        Rational successPower = Rational.ONE;    // p^k
        for (int k = 0; k <= i; k++) {
            result.put(k, successPower.multiply(failurePowers[i - k]).multiply(Rational.of(coefficient)));
            // C(i, k+1) = C(i, k) * (i - k) / (k + 1); the division is exact
            coefficient = coefficient.multiply(BigInteger.valueOf(i - k)).divide(BigInteger.valueOf(k + 1));
            successPower = successPower.multiply(rational);
        }
        return new Distribution<>(result);
    }

    public static void main(String[] strArr) {
        System.out.println(binomial(20, Rational.HALF));
    }
}
