package weka.estimators;

import java.io.Serializable;
import weka.core.matrix.DoubleVector;
import weka.core.matrix.Matrix;
import weka.core.matrix.SerializableDoubleVector;

/* loaded from: input_file:weka/estimators/MultivariateNormalEstimator.class */
/**
 * Estimator for a multivariate normal (Gaussian) distribution built from weighted observations.
 *
 * <p>{@link #addValue(DoubleVector, double)} accumulates the sum of weights, the weighted sum of
 * the observed vectors and the weighted sum of their outer products. {@link #calculateParameters()}
 * derives from those sums the mean, the covariance matrix (full, diagonal or spherical, see
 * {@link #setCovarianceType(int)}), its determinant, its inverse and its Cholesky factor.
 * {@link #getProbability(DoubleVector)} and the copy constructor recompute the parameters first
 * when values have been added since the last computation.
 */
public class MultivariateNormalEstimator implements Serializable {
    private static final long serialVersionUID = 5472266864312430693L;

    /** Covariance type: a full covariance matrix is estimated. */
    public static final int COVARIANCE_FULL = 0;
    /** Covariance type: only the diagonal (per-dimension variances) is estimated. */
    public static final int COVARIANCE_DIAGONAL = 1;
    /** Covariance type: a single variance shared by all dimensions is estimated. */
    public static final int COVARIANCE_SPHERICAL = 2;

    /** Total weight at or below which calculateParameters() leaves the parameters untouched. */
    private static final double MIN_SUM_OF_WEIGHTS = 1.0E-5d;
    /** Determinants below this value are treated as zero (singular covariance). */
    private static final double MIN_DETERMINANT = 1.0E-200d;

    /** Sum of the weights of all added values. */
    double m_SumOfWeights;
    /** Weighted sum of the added vectors. */
    SerializableDoubleVector m_SumOfValues;
    /** Weighted sum of the outer products x_i * x_j of the added vectors. */
    Matrix m_SumOfSquareValues;
    /** Number of values added with a non-zero weight. */
    int m_NumObservations;
    /** Estimated mean vector. */
    SerializableDoubleVector m_Mean;
    /** Estimated covariance matrix. */
    Matrix m_Var;
    /** Inverse of m_Var; only refreshed when the determinant is not (numerically) zero. */
    Matrix m_InvVar;
    /** Cholesky factor of m_Var (m_Var.chol().getL()), used by sample(); refreshed like m_InvVar. */
    Matrix m_CholeskyL;
    /** Determinant of m_Var. */
    double m_DetVar;
    /** One of COVARIANCE_FULL, COVARIANCE_DIAGONAL or COVARIANCE_SPHERICAL. */
    protected int m_CovarianceType;
    /** True when values were added since the parameters were last computed. */
    boolean m_Dirty;

    /**
     * Returns the covariance type.
     *
     * @return one of COVARIANCE_FULL, COVARIANCE_DIAGONAL or COVARIANCE_SPHERICAL
     */
    public int getCovarianceType() {
        return this.m_CovarianceType;
    }

    /**
     * Sets the covariance type used by the next call to {@link #calculateParameters()}.
     *
     * @param type one of COVARIANCE_FULL, COVARIANCE_DIAGONAL or COVARIANCE_SPHERICAL; other
     *     values make calculateParameters() throw
     */
    public void setCovarianceType(int type) {
        this.m_CovarianceType = type;
    }

    /** Creates an empty estimator with a full covariance matrix. */
    public MultivariateNormalEstimator() {
        this.m_CovarianceType = COVARIANCE_FULL;
        this.m_Dirty = true;
    }

    /**
     * Creates a copy of the estimated parameters of another estimator.
     *
     * <p>The source is brought up to date first if it is dirty. Only the derived parameters
     * (mean, covariance, inverse, Cholesky factor, determinant, covariance type) are copied; the
     * accumulated sums are not, so values added to the copy start a fresh accumulation.
     * Parameters the source does not have yet (null) stay null in the copy.
     *
     * @param other the estimator to copy
     * @throws Exception if recomputing the source's parameters fails
     */
    public MultivariateNormalEstimator(MultivariateNormalEstimator other) throws Exception {
        if (other.m_Dirty) {
            other.calculateParameters();
        }
        this.m_Mean = other.m_Mean == null ? null : new SerializableDoubleVector(other.m_Mean.copy());
        this.m_Var = copyOrNull(other.m_Var);
        this.m_InvVar = copyOrNull(other.m_InvVar);
        this.m_CholeskyL = copyOrNull(other.m_CholeskyL);
        this.m_DetVar = other.m_DetVar;
        this.m_CovarianceType = other.m_CovarianceType;
        this.m_Dirty = false;
    }

    /** Returns a copy of the given matrix, or null when it is null. */
    private static Matrix copyOrNull(Matrix matrix) {
        return matrix == null ? null : matrix.copy();
    }

    /**
     * Returns the dimensionality of the estimated distribution.
     *
     * @return the size of the mean vector, or 0 if no mean has been set or computed yet
     */
    public int getDimension() {
        if (this.m_Mean != null) {
            return this.m_Mean.size();
        }
        return 0;
    }

    /**
     * Resets the accumulated sums for vectors of the given dimension.
     *
     * @param dimension length of the vectors that will be added
     */
    void init(int dimension) {
        this.m_SumOfWeights = 0.0d;
        this.m_SumOfValues = new SerializableDoubleVector(new DoubleVector(dimension, 0.0d));
        this.m_SumOfSquareValues = new Matrix(dimension, dimension, 0.0d);
    }

    /** Sets m_Var to the full covariance E[x x^T] - mean * mean^T. Requires m_Mean. */
    void calculateVarianceFull() {
        int n = this.m_Mean.size();
        this.m_Var = new Matrix(n, n);
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                double secondMoment = this.m_SumOfSquareValues.get(row, col) / this.m_SumOfWeights;
                this.m_Var.set(row, col, secondMoment - (this.m_Mean.get(row) * this.m_Mean.get(col)));
            }
        }
    }

    /** Sets m_Var to the diagonal covariance (off-diagonal entries stay 0). Requires m_Mean. */
    void calculateVarianceDiagonal() {
        int n = this.m_Mean.size();
        this.m_Var = new Matrix(n, n);
        for (int k = 0; k < n; k++) {
            double secondMoment = this.m_SumOfSquareValues.get(k, k) / this.m_SumOfWeights;
            this.m_Var.set(k, k, secondMoment - (this.m_Mean.get(k) * this.m_Mean.get(k)));
        }
    }

    /**
     * Sets m_Var to sigma^2 * I, where sigma^2 is the average of the per-dimension variances.
     * Requires m_Mean.
     */
    void calculateVarianceSpherical() {
        int n = this.m_Mean.size();
        double trace = 0.0d;
        for (int k = 0; k < n; k++) {
            trace += this.m_SumOfSquareValues.get(k, k);
        }
        double totalVariance = trace / this.m_SumOfWeights;
        for (int k = 0; k < n; k++) {
            totalVariance -= this.m_Mean.get(k) * this.m_Mean.get(k);
        }
        this.m_Var = Matrix.identity(n, n);
        this.m_Var.timesEquals(totalVariance / n);
    }

    /**
     * Recomputes the mean and covariance from the accumulated sums, then the determinant, inverse
     * and Cholesky factor.
     *
     * <p>Clears the dirty flag. If the total weight is at or below MIN_SUM_OF_WEIGHTS nothing is
     * recomputed.
     *
     * @throws Exception if the covariance type is not one of the COVARIANCE_* constants
     */
    public void calculateParameters() throws Exception {
        this.m_Dirty = false;
        if (this.m_SumOfWeights <= MIN_SUM_OF_WEIGHTS) {
            return;
        }
        this.m_Mean = new SerializableDoubleVector(this.m_SumOfValues.times(1.0d / this.m_SumOfWeights));
        switch (getCovarianceType()) {
            case COVARIANCE_FULL:
                calculateVarianceFull();
                break;
            case COVARIANCE_DIAGONAL:
                calculateVarianceDiagonal();
                break;
            case COVARIANCE_SPHERICAL:
                calculateVarianceSpherical();
                break;
            default:
                throw new Exception("Unhandled covariance type");
        }
        updateDerivedParameters();
    }

    /**
     * Refreshes m_DetVar from m_Var and, when the determinant is not numerically zero, also
     * m_InvVar and m_CholeskyL. For a singular covariance a warning is printed and the previous
     * inverse and Cholesky factor are kept.
     */
    private void updateDerivedParameters() {
        this.m_DetVar = this.m_Var.det();
        if (this.m_DetVar < MIN_DETERMINANT) {
            System.err.println("Covariance matrix has zero determinant");
        } else {
            this.m_InvVar = this.m_Var.inverse();
            this.m_CholeskyL = this.m_Var.chol().getL();
        }
    }

    /**
     * Recomputes every estimator and then gives all of them the same pooled covariance.
     *
     * <p>The pooled covariance is the weighted average of the individual covariances, weighted
     * by each estimator's total weight. Estimators that have no covariance yet contribute
     * nothing. Every estimator's determinant, inverse and Cholesky factor are refreshed from the
     * pooled matrix, so later calls to getProbability() and sample() use it consistently.
     *
     * @param estimators the estimators to tie; nothing happens if null, empty, or without any
     *     positive weight
     * @throws Exception if recomputing an estimator's parameters fails
     */
    public static void calculateTiedParameters(MultivariateNormalEstimator[] estimators) throws Exception {
        if (estimators == null || estimators.length == 0) {
            return;
        }
        Matrix pooled = null;
        double totalWeight = 0.0d;
        for (MultivariateNormalEstimator estimator : estimators) {
            // Recompute first: m_Var (and the dimension) may not exist before this call.
            estimator.calculateParameters();
            if (estimator.m_Var == null) {
                continue;
            }
            Matrix weighted = estimator.m_Var.times(estimator.m_SumOfWeights);
            if (pooled == null) {
                pooled = weighted;
            } else {
                pooled.plusEquals(weighted);
            }
            totalWeight += estimator.m_SumOfWeights;
        }
        if (pooled == null || totalWeight <= 0.0d) {
            return;
        }
        pooled.timesEquals(1.0d / totalWeight);
        for (MultivariateNormalEstimator estimator : estimators) {
            estimator.m_Var = pooled.copy();
            // Without this the inverse/determinant/Cholesky factor would still describe the
            // estimator's own (untied) covariance.
            estimator.updateDerivedParameters();
        }
    }

    /**
     * Returns the current mean vector (not a copy).
     *
     * @return the mean, or null if none has been set or computed
     */
    public DoubleVector getMean() {
        return this.m_Mean;
    }

    /**
     * Replaces the mean with a copy of the given vector.
     *
     * @param mean the new mean
     */
    public void setMean(DoubleVector mean) {
        this.m_Mean = new SerializableDoubleVector(mean.copy());
    }

    /**
     * Returns the current covariance matrix (not a copy).
     *
     * @return the covariance, or null if none has been set or computed
     */
    public Matrix getVariance() {
        return this.m_Var;
    }

    /**
     * Replaces the covariance with a copy of the given matrix and recomputes its inverse,
     * determinant and Cholesky factor.
     *
     * @param variance the new covariance matrix
     */
    public void setVariance(Matrix variance) {
        this.m_Var = variance.copy();
        this.m_InvVar = this.m_Var.inverse();
        this.m_DetVar = this.m_Var.det();
        this.m_CholeskyL = this.m_Var.chol().getL();
    }

    /**
     * Adds a weighted observation to the accumulated sums and marks the estimator dirty.
     *
     * @param value the observed vector; its size fixes the dimension on the first call
     * @param weight the observation weight; a weight of exactly 0 is ignored
     */
    public void addValue(DoubleVector value, double weight) {
        if (weight == 0.0d) {
            return;
        }
        if (this.m_SumOfValues == null) {
            init(value.size());
        }
        this.m_SumOfWeights += weight;
        this.m_SumOfValues.plusEquals(value.times(weight));
        this.m_NumObservations++;
        int n = value.size();
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                this.m_SumOfSquareValues.set(row, col,
                        this.m_SumOfSquareValues.get(row, col) + (value.get(row) * value.get(col) * weight));
            }
        }
        this.m_Dirty = true;
    }

    /**
     * Evaluates the normal density at the given point, recomputing parameters first if dirty.
     *
     * @param x the point to evaluate
     * @return the density value, or 0 if the covariance determinant is numerically zero
     * @throws Exception if recomputation fails or the result is NaN or infinite
     */
    public double getProbability(DoubleVector x) throws Exception {
        if (this.m_Dirty) {
            calculateParameters();
        }
        if (this.m_DetVar < MIN_DETERMINANT) {
            return 0.0d;
        }
        int n = this.m_Mean.size();
        double normalizer = (1.0d / Math.pow(2.0d * Math.PI, n / 2.0d)) / Math.sqrt(this.m_DetVar);

        double[] deviation = new double[n];
        for (int k = 0; k < n; k++) {
            deviation[k] = x.get(k) - this.m_Mean.get(k);
        }
        // Squared Mahalanobis distance (x - mean)^T * inv(Var) * (x - mean).
        double mahalanobis = 0.0d;
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                mahalanobis += deviation[row] * this.m_InvVar.get(row, col) * deviation[col];
            }
        }
        double density = normalizer * Math.exp(-0.5d * mahalanobis);
        if (Double.isInfinite(density) || Double.isNaN(density)) {
            throw new Exception("Calculated probability is NaN");
        }
        return density;
    }

    /**
     * Draws two independent standard normal values with the Marsaglia polar form of Box-Muller.
     *
     * @return a vector of size 2
     */
    public DoubleVector boxMuller() {
        DoubleVector point;
        double radiusSquared;
        // Map the random draw to [-1, 1]^2 (assuming DoubleVector.random yields values in [0, 1))
        // and keep only points strictly inside the unit circle and away from the origin:
        // s == 0 would make log(s) / s evaluate to NaN.
        do {
            point = DoubleVector.random(2);
            point.timesEquals(2.0d);
            point.minusEquals(1.0d);
            radiusSquared = point.sum2();
        } while (radiusSquared >= 1.0d || radiusSquared == 0.0d);
        double scale = Math.sqrt((-2.0d * Math.log(radiusSquared)) / radiusSquared);
        for (int k = 0; k < point.size(); k++) {
            point.set(k, point.get(k) * scale);
        }
        return point;
    }

    /**
     * Draws a random vector from the estimated distribution as mean + L * z, where z is a vector
     * of standard normal values and L is m_CholeskyL.
     *
     * <p>Does not recompute dirty parameters; the current mean and Cholesky factor are used.
     *
     * @return a new sample with the dimension of the mean
     */
    public DoubleVector sample() {
        int n = this.m_Mean.size();
        DoubleVector standard = new DoubleVector(n);
        for (int k = 0; k + 1 < n; k += 2) {
            DoubleVector pair = boxMuller();
            standard.set(k, pair.get(0));
            standard.set(k + 1, pair.get(1));
        }
        if (n % 2 == 1) {
            standard.set(n - 1, boxMuller().get(0));
        }
        DoubleVector result = new DoubleVector(n);
        for (int row = 0; row < n; row++) {
            double value = 0.0d;
            for (int col = 0; col < n; col++) {
                value += this.m_CholeskyL.get(row, col) * standard.get(col);
            }
            result.set(row, value);
        }
        result.plusEquals(this.m_Mean);
        return result;
    }

    /**
     * Describes the mean and covariance. The covariance is printed as the full matrix, the
     * diagonal entries, or the single shared variance, depending on the covariance type.
     * Parameters that do not exist yet print as "null" or empty.
     */
    @Override
    public String toString() {
        StringBuilder covariance = new StringBuilder();
        if (this.m_Var != null) {
            switch (getCovarianceType()) {
                case COVARIANCE_FULL:
                    covariance.append(this.m_Var.toString());
                    break;
                case COVARIANCE_DIAGONAL:
                    for (int k = 0; k < getDimension(); k++) {
                        covariance.append(this.m_Var.get(k, k)).append(' ');
                    }
                    break;
                case COVARIANCE_SPHERICAL:
                    covariance.append(this.m_Var.get(0, 0));
                    break;
                default:
                    break;
            }
        }
        return "Mean\n" + String.valueOf(this.m_Mean) + "\nCovariance\n" + covariance;
    }
}
