/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.shaded.org.apache.commons.statistics.distribution;

import com.hazelcast.shaded.org.apache.commons.numbers.gamma.Erf;
import com.hazelcast.shaded.org.apache.commons.numbers.gamma.ErfDifference;
import com.hazelcast.shaded.org.apache.commons.numbers.gamma.Erfcx;
import com.hazelcast.shaded.org.apache.commons.rng.UniformRandomProvider;
import com.hazelcast.shaded.org.apache.commons.rng.sampling.distribution.ZigguratSampler;
import com.hazelcast.shaded.org.apache.commons.statistics.distribution.AbstractContinuousDistribution;
import com.hazelcast.shaded.org.apache.commons.statistics.distribution.ArgumentUtils;
import com.hazelcast.shaded.org.apache.commons.statistics.distribution.ContinuousDistribution;
import com.hazelcast.shaded.org.apache.commons.statistics.distribution.DistributionException;
import com.hazelcast.shaded.org.apache.commons.statistics.distribution.NormalDistribution;
import java.util.function.DoubleSupplier;

/**
 * A normal distribution restricted to the closed interval {@code [lower, upper]}.
 *
 * <p>All probabilities are those of the parent {@link NormalDistribution} rescaled by
 * the parent's probability mass inside the interval. Instances are immutable and are
 * created with {@link #of(double, double, double, double)}.
 */
public final class TruncatedNormalDistribution extends AbstractContinuousDistribution {
    /**
     * Magnitude at or beyond which a standardised bound is treated as unbounded by the
     * moment computations. The value is approximately {@code sqrt(Double.MAX_VALUE)},
     * so {@code x * x} stays finite for any {@code |x|} below it.
     */
    private static final double MAX_X = 1.3407807929942596E154;
    /** Smallest parent probability mass that is still rejected as an empty truncation. */
    private static final double MIN_P = 0.0;
    /** {@code sqrt(2)}. */
    private static final double ROOT2 = 1.4142135623730951;
    /** {@code sqrt(2 / pi)}. */
    private static final double ROOT_2_PI = 0.7978845608028654;
    /** {@code sqrt(pi / 2)}. */
    private static final double ROOT_PI_2 = 1.2533141373155003;
    /**
     * Fraction of the parent probability mass above which {@link #createSampler} uses
     * rejection sampling from a standard Gaussian instead of the inherited sampler.
     */
    private static final double REJECTION_THRESHOLD = 0.2;

    /** Untruncated parent distribution. */
    private final NormalDistribution parentNormal;
    /** Lower bound of the support. */
    private final double lower;
    /** Upper bound of the support. */
    private final double upper;
    /** {@code parentNormal.probability(lower, upper)}; the normalisation factor. */
    private final double cdfDelta;
    /** {@code log(cdfDelta)}. */
    private final double logCdfDelta;
    /** {@code parentNormal.cumulativeProbability(lower)}; offset for the inverse CDF. */
    private final double cdfAlpha;
    /** {@code parentNormal.survivalProbability(upper)}; offset for the inverse SF. */
    private final double sfBeta;

    /**
     * @param parent Parent normal distribution.
     * @param z Probability mass of the parent inside {@code [lower, upper]}.
     * @param lower Lower bound of the support.
     * @param upper Upper bound of the support.
     */
    private TruncatedNormalDistribution(NormalDistribution parent, double z, double lower, double upper) {
        this.parentNormal = parent;
        this.lower = lower;
        this.upper = upper;
        this.cdfDelta = z;
        this.logCdfDelta = Math.log(z);
        // Offsets used to map a probability of this distribution onto the parent.
        this.cdfAlpha = parent.cumulativeProbability(lower);
        this.sfBeta = parent.survivalProbability(upper);
    }

    /**
     * Creates a truncated normal distribution.
     *
     * @param mean Mean of the parent normal distribution.
     * @param sd Standard deviation of the parent normal distribution.
     * @param lower Lower bound (inclusive) of the support.
     * @param upper Upper bound (inclusive) of the support.
     * @return the distribution
     * @throws DistributionException if {@code sd <= 0}; if {@code lower >= upper}; or if the
     * parent distribution has no computable positive probability mass inside the bounds.
     */
    public static TruncatedNormalDistribution of(double mean, double sd, double lower, double upper) {
        if (sd <= 0.0) {
            throw new DistributionException("Number %s is not greater than 0", sd);
        }
        if (lower >= upper) {
            throw new DistributionException("Lower bound %s >= upper bound %s", lower, upper);
        }
        final NormalDistribution parent = NormalDistribution.of(mean, sd);
        final double z = parent.probability(lower, upper);
        // Negated comparison so that a NaN mass is rejected as well: a NaN normaliser
        // would otherwise silently turn every later result into NaN.
        if (!(z > MIN_P)) {
            // Report the bounds in units of the standard normal.
            final double a = (lower - mean) / sd;
            final double b = (upper - mean) / sd;
            throw new DistributionException("Excess truncation of standard normal : CDF(%s, %s) = %s", a, b, z);
        }
        return new TruncatedNormalDistribution(parent, z, lower, upper);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns the parent density divided by the retained probability mass, or zero
     * outside {@code [lower, upper]}.
     */
    @Override
    public double density(double x) {
        if (x < lower || x > upper) {
            return 0.0;
        }
        return parentNormal.density(x) / cdfDelta;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Both arguments are clipped to the support before the parent is queried.
     *
     * @throws DistributionException if {@code x0 > x1}
     */
    @Override
    public double probability(double x0, double x1) {
        if (x0 > x1) {
            throw new DistributionException("Lower bound %s > upper bound %s", x0, x1);
        }
        return parentNormal.probability(clipToRange(x0), clipToRange(x1)) / cdfDelta;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns the parent log-density minus {@code log(cdfDelta)}, or negative infinity
     * outside {@code [lower, upper]}.
     */
    @Override
    public double logDensity(double x) {
        if (x < lower || x > upper) {
            return Double.NEGATIVE_INFINITY;
        }
        return parentNormal.logDensity(x) - logCdfDelta;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Exactly 0 at or below the lower bound and exactly 1 at or above the upper bound.
     */
    @Override
    public double cumulativeProbability(double x) {
        if (x <= lower) {
            return 0.0;
        }
        if (x >= upper) {
            return 1.0;
        }
        return parentNormal.probability(lower, x) / cdfDelta;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Exactly 1 at or below the lower bound and exactly 0 at or above the upper bound.
     */
    @Override
    public double survivalProbability(double x) {
        if (x <= lower) {
            return 1.0;
        }
        if (x >= upper) {
            return 0.0;
        }
        return parentNormal.probability(x, upper) / cdfDelta;
    }

    /**
     * {@inheritDoc}
     *
     * <p>{@code p} is validated by {@code ArgumentUtils.checkProbability}. The bounds are
     * returned exactly for {@code p == 0} and {@code p == 1}; any other result is clipped
     * to the support.
     */
    @Override
    public double inverseCumulativeProbability(double p) {
        ArgumentUtils.checkProbability(p);
        if (p == 0.0) {
            return lower;
        }
        if (p == 1.0) {
            return upper;
        }
        // Map p linearly onto the parent's CDF range [cdfAlpha, cdfAlpha + cdfDelta].
        final double x = parentNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta);
        return clipToRange(x);
    }

    /**
     * {@inheritDoc}
     *
     * <p>{@code p} is validated by {@code ArgumentUtils.checkProbability}. The bounds are
     * returned exactly for {@code p == 1} and {@code p == 0}; any other result is clipped
     * to the support.
     */
    @Override
    public double inverseSurvivalProbability(double p) {
        ArgumentUtils.checkProbability(p);
        if (p == 1.0) {
            return lower;
        }
        if (p == 0.0) {
            return upper;
        }
        // Map p linearly onto the parent's SF range [sfBeta, sfBeta + cdfDelta].
        final double x = parentNormal.inverseSurvivalProbability(sfBeta + p * cdfDelta);
        return clipToRange(x);
    }

    /**
     * {@inheritDoc}
     *
     * <p>When the truncation keeps enough of the parent's probability mass, samples are
     * drawn from a standard Gaussian and rejected until one falls inside the standardised
     * bounds; otherwise the sampler inherited from the super class is used.
     */
    @Override
    public ContinuousDistribution.Sampler createSampler(UniformRandomProvider rng) {
        // Bounds expressed in units of the standard normal.
        final double u = parentNormal.getMean();
        final double s = parentNormal.getStandardDeviation();
        final double a = (lower - u) / s;
        final double b = (upper - u) / s;

        // The generator below produces standard-normal deviates, so whether the
        // truncation sits in a single half of the Gaussian has to be decided on the
        // standardised bounds (i.e. relative to the parent mean). Testing the raw bounds
        // against zero is only equivalent when the mean is zero; for any other mean it
        // drops part of the support or makes the rejection loop spin forever.
        final boolean upperHalfOnly = a >= 0.0;
        final boolean lowerHalfOnly = b <= 0.0;

        double threshold = REJECTION_THRESHOLD;
        if (upperHalfOnly || lowerHalfOnly) {
            // Mirroring folds both halves onto one, doubling the acceptance rate.
            threshold *= 0.5;
        }

        if (cdfDelta > threshold) {
            final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
            final DoubleSupplier gen;
            if (upperHalfOnly) {
                gen = () -> Math.abs(sampler.sample());
            } else if (lowerHalfOnly) {
                gen = () -> -Math.abs(sampler.sample());
            } else {
                gen = sampler::sample;
            }
            return () -> {
                double x = gen.getAsDouble();
                while (x < a || x > b) {
                    x = gen.getAsDouble();
                }
                // Clip to guard against rounding when mapping back to the parent scale.
                return clipToRange(u + x * s);
            };
        }
        return super.createSampler(rng);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Computed as {@code mean + moment1(a, b) * sd} where {@code a} and {@code b} are
     * the bounds standardised by the parent mean and standard deviation.
     */
    @Override
    public double getMean() {
        final double u = parentNormal.getMean();
        final double s = parentNormal.getStandardDeviation();
        final double a = (lower - u) / s;
        final double b = (upper - u) / s;
        return u + moment1(a, b) * s;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Computed as {@code variance(a, b) * sd^2} where {@code a} and {@code b} are the
     * bounds standardised by the parent mean and standard deviation.
     */
    @Override
    public double getVariance() {
        final double u = parentNormal.getMean();
        final double s = parentNormal.getStandardDeviation();
        final double a = (lower - u) / s;
        final double b = (upper - u) / s;
        return variance(a, b) * s * s;
    }

    /**
     * {@inheritDoc}
     *
     * @return the lower truncation bound
     */
    @Override
    public double getSupportLowerBound() {
        return lower;
    }

    /**
     * {@inheritDoc}
     *
     * @return the upper truncation bound
     */
    @Override
    public double getSupportUpperBound() {
        return upper;
    }

    /**
     * Clips a value to the support of this distribution.
     *
     * @param x Value.
     * @return {@code x} limited to {@code [lower, upper]}
     */
    private double clipToRange(double x) {
        return clip(x, lower, upper);
    }

    /**
     * Clips a value to {@code [lower, upper]}.
     *
     * <p>The comparisons are arranged so that a NaN {@code x} yields {@code upper}.
     *
     * @param x Value.
     * @param lower Lower limit.
     * @param upper Upper limit.
     * @return the clipped value
     */
    private static double clip(double x, double lower, double upper) {
        if (x <= lower) {
            return lower;
        }
        return x < upper ? x : upper;
    }

    /**
     * Computes the first raw moment (mean) of the standard normal truncated to
     * {@code [a, b]}. Assumes {@code a <= b}.
     *
     * @param a Lower bound, in units of the standard normal.
     * @param b Upper bound, in units of the standard normal.
     * @return the first moment, clipped to {@code [a, b]}
     */
    static double moment1(double a, double b) {
        if (a == b) {
            return a;
        }
        if (Math.abs(a) > Math.abs(b)) {
            // Solve the mirrored problem so that |a| <= |b| below. Written as 0.0 - x
            // rather than -x so that a zero result stays +0.0.
            return 0.0 - moment1(-b, -a);
        }
        if (a <= -MAX_X) {
            // Both bounds are beyond MAX_X in magnitude: treated as no truncation.
            return 0.0;
        }
        if (b >= MAX_X) {
            // Upper bound treated as unbounded: one-sided truncation at a.
            return ROOT_2_PI / Erfcx.value(a / ROOT2);
        }
        // dx = (b^2 - a^2) / 2, factored to limit cancellation.
        final double dx = 0.5 * (b + a) * (b - a);
        final double m;
        if (a <= 0.0) {
            // Bounds on opposite sides of (or touching) zero.
            m = ROOT_2_PI * -Math.expm1(-dx) * Math.exp(-0.5 * a * a) / ErfDifference.value(a / ROOT2, b / ROOT2);
        } else {
            // Both bounds positive: use the scaled complementary error function.
            final double z = Math.exp(-dx) * Erfcx.value(b / ROOT2) - Erfcx.value(a / ROOT2);
            if (z == 0.0) {
                // Denominator vanished; fall back to the mid-point of the interval.
                return (a + b) * 0.5;
            }
            m = ROOT_2_PI * Math.expm1(-dx) / z;
        }
        return clip(m, a, b);
    }

    /**
     * Computes the second raw moment of the standard normal truncated to {@code [a, b]}.
     * Assumes {@code a < b}; the only caller in this class, {@link #variance}, handles
     * {@code a == b} itself.
     *
     * @param a Lower bound, in units of the standard normal.
     * @param b Upper bound, in units of the standard normal.
     * @return the second moment
     */
    private static double moment2(double a, double b) {
        if (Math.abs(a) > Math.abs(b)) {
            // The second moment is symmetric under reflection; arrange |a| <= |b|.
            return moment2(-b, -a);
        }
        if (a <= -MAX_X) {
            // Both bounds are beyond MAX_X in magnitude: treated as no truncation.
            return 1.0;
        }
        if (b >= MAX_X) {
            // Upper bound treated as unbounded: one-sided truncation at a.
            return 1.0 + ROOT_2_PI * a / Erfcx.value(a / ROOT2);
        }
        double m;
        if (a <= 0.0) {
            // Bounds on opposite sides of (or touching) zero.
            final double ea = ROOT_PI_2 * Erf.value(a / ROOT2);
            final double eb = ROOT_PI_2 * Erf.value(b / ROOT2);
            final double fa = ea - a * Math.exp(-0.5 * a * a);
            final double fb = eb - b * Math.exp(-0.5 * b * b);
            m = (fb - fa) / (eb - ea);
            m = clip(m, 0.0, 1.0);
        } else {
            // Both bounds positive: use the scaled complementary error function.
            final double dx = 0.5 * (b + a) * (b - a);
            final double ex = Math.exp(-dx);
            final double ea = ROOT_PI_2 * Erfcx.value(a / ROOT2);
            final double eb = ROOT_PI_2 * Erfcx.value(b / ROOT2);
            final double fa = ea + a;
            final double fb = eb + b;
            m = (fa - fb * ex) / (ea - eb * ex);
            // x^2 lies in [a^2, b^2] for every x in [a, b] when 0 < a < b.
            m = clip(m, a * a, b * b);
        }
        return m;
    }

    /**
     * Computes the variance of the standard normal truncated to {@code [a, b]}.
     * Assumes {@code a <= b}.
     *
     * @param a Lower bound, in units of the standard normal.
     * @param b Upper bound, in units of the standard normal.
     * @return the variance
     */
    static double variance(double a, double b) {
        if (a == b) {
            return 0.0;
        }
        final double m1 = moment1(a, b);
        // variance = m2 - m1^2, evaluated as (r - m1) * (r + m1) with r = sqrt(m2).
        final double root = Math.sqrt(moment2(a, b));
        final double variance = (root - m1) * (root + m1);
        if (variance >= 1.0) {
            // Treated as floating-point error: report 1 when the bounds straddle
            // [-1, 1], otherwise report 0.
            return a < -1.0 && b > 1.0 ? 1.0 : 0.0;
        }
        if (variance <= 0.0) {
            // Rounding can push the result to zero or below.
            return 0.0;
        }
        return variance;
    }
}

