package weka.estimators;

import java.io.Serializable;
import java.util.Random;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.matrix.DoubleVector;
import weka.core.matrix.Matrix;

/* loaded from: input_file:weka/estimators/MultivariateNormalHMMEstimator.class */
/**
 * HMM estimator that keeps one {@link MultivariateNormalEstimator} per hidden
 * state as the output (emission) model, next to the state-transition and
 * initial-state estimators inherited from {@link AbstractHMMEstimator}.
 *
 * <p>The output estimators are only allocated by the {@code (int, boolean)}
 * constructor, the copy constructor, {@link #setNumStates(int)} and
 * {@link #copyOutputParameters(MultivariateNormalHMMEstimator)}; after the
 * no-argument constructor {@code m_outputEstimators} is {@code null}.
 */
public class MultivariateNormalHMMEstimator extends AbstractHMMEstimator implements HMMEstimator, Serializable {
    private static final long serialVersionUID = -1123497102759147327L;

    /** When true, {@link #calculateParameters()} uses the tied computation over all output estimators. */
    protected boolean m_Tied = true;

    /**
     * Integer code forwarded to {@link MultivariateNormalEstimator#setCovarianceType};
     * the meaning of the individual values is defined by that class.
     */
    protected int m_CovarianceType = 0;

    /** One output estimator per state; {@code null} until the number of states is known. */
    protected MultivariateNormalEstimator[] m_outputEstimators;

    /** @return whether the output estimators are treated as tied in {@link #calculateParameters()} */
    public boolean isTied() {
        return this.m_Tied;
    }

    /** @param tied whether {@link #calculateParameters()} should use the tied computation */
    public void setTied(boolean tied) {
        this.m_Tied = tied;
    }

    /** @return the covariance type code currently configured */
    public int getCovarianceType() {
        return this.m_CovarianceType;
    }

    /**
     * Stores the covariance type and pushes it to every existing output estimator.
     *
     * @param covarianceType code forwarded to each {@link MultivariateNormalEstimator}
     */
    public void setCovarianceType(int covarianceType) {
        this.m_CovarianceType = covarianceType;
        if (this.m_outputEstimators == null) {
            return;
        }
        // Walk the array itself rather than 0..getNumStates(): the two can differ
        // (e.g. after copyOutputParameters from an estimator with another state count),
        // and indexing by the state count could then run past the end of the array.
        for (MultivariateNormalEstimator estimator : this.m_outputEstimators) {
            estimator.setCovarianceType(getCovarianceType());
        }
    }

    /**
     * @return the dimension reported by the output estimator of state 0
     * @throws IllegalStateException if no output estimators have been allocated yet
     */
    @Override // weka.estimators.AbstractHMMEstimator, weka.estimators.HMMEstimator
    public int getOutputDimension() {
        if (this.m_outputEstimators == null || this.m_outputEstimators.length == 0) {
            throw new IllegalStateException("No output estimators available; set the number of states first");
        }
        return this.m_outputEstimators[0].getDimension();
    }

    /** Creates an estimator without output estimators (tied, covariance type 0). */
    public MultivariateNormalHMMEstimator() {
    }

    /**
     * Creates an estimator with one fresh output estimator per state.
     *
     * @param numStates number of states; also passed to the superclass constructor
     * @param laplace   flag passed unchanged to the superclass constructor
     *                  (NOTE(review): name assumed from Weka estimator conventions - confirm)
     */
    public MultivariateNormalHMMEstimator(int numStates, boolean laplace) {
        super(numStates, laplace);
        this.m_outputEstimators = newOutputEstimators(numStates);
    }

    /**
     * Copy constructor. Copies the tied flag and covariance type of {@code other}
     * and copy-constructs each of its output estimators.
     *
     * @param other estimator to copy
     * @throws Exception declared by the original signature; propagated from the copy
     */
    public MultivariateNormalHMMEstimator(MultivariateNormalHMMEstimator other) throws Exception {
        super(other);
        // Take the configuration from the source instead of silently resetting it
        // to the defaults; assigned directly so no overridable setter runs here.
        this.m_Tied = other.m_Tied;
        this.m_CovarianceType = other.m_CovarianceType;
        this.m_outputEstimators = copyOutputEstimators(other);
    }

    /**
     * Replaces this estimator's output configuration (covariance type, tied flag
     * and output estimators) with copies taken from {@code other}.
     *
     * @param other estimator whose output parameters are copied
     * @throws Exception declared by the original signature; propagated from the copy
     */
    public void copyOutputParameters(MultivariateNormalHMMEstimator other) throws Exception {
        setCovarianceType(other.getCovarianceType());
        setTied(other.isTied());
        this.m_outputEstimators = copyOutputEstimators(other);
    }

    /**
     * Sets the number of states in the superclass and allocates a fresh output
     * estimator for each of them, discarding any previous output estimators.
     *
     * @param numStates requested number of states
     */
    @Override // weka.estimators.AbstractHMMEstimator, weka.estimators.HMMEstimator
    public void setNumStates(int numStates) {
        super.setNumStates(numStates);
        // Size from the field the superclass maintains, not from the argument.
        this.m_outputEstimators = newOutputEstimators(this.m_NumStates);
    }

    /**
     * Feeds {@code probabilities[s]} as the weight of state {@code s} into the
     * initial-state estimator via {@code addValue}.
     *
     * @param probabilities one weight per state index
     */
    public void setState0Probabilities(double[] probabilities) {
        for (int state = 0; state < probabilities.length; state++) {
            this.m_state0Estimator.addValue(state, probabilities[state]);
        }
    }

    /**
     * Feeds {@code probabilities[from][to]} as the weight of state {@code to}
     * into the transition estimator of state {@code from} via {@code addValue}.
     *
     * @param probabilities row per source state, column per destination state
     */
    public void setStateProbabilities(double[][] probabilities) {
        for (int from = 0; from < probabilities.length; from++) {
            double[] row = probabilities[from];
            for (int to = 0; to < row.length; to++) {
                this.m_stateEstimators[from].addValue(to, row[to]);
            }
        }
    }

    /** @param means mean vector for each state, indexed like the output estimators */
    public void setOutputMeans(DoubleVector[] means) {
        for (int state = 0; state < means.length; state++) {
            setOutputMean(state, means[state]);
        }
    }

    /**
     * @param state index of the output estimator to update
     * @param mean  mean vector handed to that estimator
     */
    public void setOutputMean(int state, DoubleVector mean) {
        this.m_outputEstimators[state].setMean(mean);
    }

    /** @param variances variance matrix for each state, indexed like the output estimators */
    public void setOutputVariances(Matrix[] variances) {
        for (int state = 0; state < variances.length; state++) {
            setOutputVariance(state, variances[state]);
        }
    }

    /**
     * @param state    index of the output estimator to update
     * @param variance variance matrix handed to that estimator
     */
    public void setOutputVariance(int state, Matrix variance) {
        this.m_outputEstimators[state].setVariance(variance);
    }

    /**
     * Draws the next state given the previous one and appends one sampled
     * output to {@code instances}.
     *
     * <p>A candidate state is drawn uniformly and accepted when a uniform draw
     * does not exceed the transition probability of that candidate; otherwise a
     * new candidate is drawn.
     *
     * @param instances dataset that receives the sampled output as a new instance
     * @param prevState index of the previous state
     * @param random    source of randomness for the state draw
     * @return the state that was drawn
     */
    @Override // weka.estimators.HMMEstimator
    public int Sample(Instances instances, int prevState, Random random) {
        int state;
        do {
            state = random.nextInt(getNumStates());
        } while (random.nextDouble() > this.m_stateEstimators[prevState].getProbability(state));
        appendSample(instances, state);
        return state;
    }

    /**
     * Draws an initial state and appends one sampled output to {@code instances}.
     * Uses the same accept/reject scheme as {@link #Sample(Instances, int, Random)},
     * but against the initial-state estimator.
     *
     * @param instances dataset that receives the sampled output as a new instance
     * @param random    source of randomness for the state draw
     * @return the state that was drawn
     */
    @Override // weka.estimators.HMMEstimator
    public int Sample0(Instances instances, Random random) {
        int state;
        do {
            state = random.nextInt(getNumStates());
        } while (random.nextDouble() > this.m_state0Estimator.getProbability(state));
        appendSample(instances, state);
        return state;
    }

    /**
     * Records a weighted transition and the output observed in the new state.
     *
     * @param prevState previous state index (truncated to int)
     * @param state     new state index (truncated to int for the output estimator)
     * @param output    observed output vector
     * @param weight    weight of the observation
     */
    @Override // weka.estimators.HMMEstimator
    public void addValue(double prevState, double state, DoubleVector output, double weight) {
        this.m_stateEstimators[(int) prevState].addValue(state, weight);
        this.m_outputEstimators[(int) state].addValue(output, weight);
    }

    /**
     * Records a weighted initial state and the output observed in it.
     *
     * @param state  initial state index (truncated to int for the output estimator)
     * @param output observed output vector
     * @param weight weight of the observation
     */
    @Override // weka.estimators.HMMEstimator
    public void addValue0(double state, DoubleVector output, double weight) {
        this.m_state0Estimator.addValue(state, weight);
        this.m_outputEstimators[(int) state].addValue(output, weight);
    }

    /**
     * @param prevState previous state index
     * @param state     new state index
     * @param output    observed output vector
     * @return transition probability times the output estimator's value for {@code output}
     * @throws Exception if the product is NaN or infinite
     */
    @Override // weka.estimators.HMMEstimator
    public double getProbability(double prevState, double state, DoubleVector output) throws Exception {
        double transition = this.m_stateEstimators[(int) prevState].getProbability(state);
        double emission = this.m_outputEstimators[(int) state].getProbability(output);
        return requireFinite(transition * emission);
    }

    /**
     * @param state  initial state index
     * @param output observed output vector
     * @return initial-state probability times the output estimator's value for {@code output}
     * @throws Exception if the product is NaN or infinite
     */
    @Override // weka.estimators.HMMEstimator
    public double getProbability0(double state, DoubleVector output) throws Exception {
        double initial = this.m_state0Estimator.getProbability(state);
        double emission = this.m_outputEstimators[(int) state].getProbability(output);
        // Validated the same way as getProbability so both entry points agree.
        return requireFinite(initial * emission);
    }

    /**
     * Scalar convenience overload of
     * {@link #addValue(double, double, DoubleVector, double)}.
     *
     * @throws Exception if the output dimension is not 1
     */
    @Override // weka.estimators.HMMEstimator
    public void addValue(double prevState, double state, double output, double weight) throws Exception {
        if (getOutputDimension() != 1) {
            throw new Exception("Trying to add a single value to a multivariate output");
        }
        addValue(prevState, state, new DoubleVector(1, output), weight);
    }

    /**
     * Scalar convenience overload of {@link #addValue0(double, DoubleVector, double)}.
     *
     * @throws Exception if the output dimension is not 1
     */
    @Override // weka.estimators.HMMEstimator
    public void addValue0(double state, double output, double weight) throws Exception {
        if (getOutputDimension() != 1) {
            throw new Exception("Trying to add a single value to a multivariate output");
        }
        addValue0(state, new DoubleVector(1, output), weight);
    }

    /**
     * Scalar convenience overload of
     * {@link #getProbability(double, double, DoubleVector)}.
     *
     * @throws Exception if the output dimension is not 1, or the result is NaN or infinite
     */
    @Override // weka.estimators.HMMEstimator
    public double getProbability(double prevState, double state, double output) throws Exception {
        if (getOutputDimension() != 1) {
            throw new Exception("Trying to get the probability of a multivariate output with a single value");
        }
        return getProbability(prevState, state, new DoubleVector(1, output));
    }

    /**
     * Scalar convenience overload of {@link #getProbability0(double, DoubleVector)}.
     *
     * @throws Exception if the output dimension is not 1, or the result is NaN or infinite
     */
    @Override // weka.estimators.HMMEstimator
    public double getProbability0(double state, double output) throws Exception {
        if (getOutputDimension() != 1) {
            throw new Exception("Trying to get the probability of a multivariate output with a single value");
        }
        return getProbability0(state, new DoubleVector(1, output));
    }

    /**
     * @return a header line, the superclass description, and one section per
     *         output estimator (none when no output estimators are allocated)
     */
    @Override // weka.estimators.AbstractHMMEstimator
    public String toString() {
        StringBuilder text = new StringBuilder("MultivariateNormalHMMEstimator\n");
        text.append(super.toString());
        if (this.m_outputEstimators != null) {
            for (int state = 0; state < this.m_outputEstimators.length; state++) {
                text.append("Output Estimator, state ").append(state).append("\n")
                    .append(this.m_outputEstimators[state].toString()).append("\n");
            }
        }
        return text.toString();
    }

    /**
     * @return always {@code null} in this version of the class
     */
    public String getRevision() {
        return null;
    }

    /**
     * Recomputes the output-estimator parameters: jointly through
     * {@link MultivariateNormalEstimator#calculateTiedParameters} when tied,
     * otherwise independently for each estimator.
     *
     * @throws Exception propagated from the output estimators
     */
    @Override // weka.estimators.HMMEstimator
    public void calculateParameters() throws Exception {
        if (isTied()) {
            MultivariateNormalEstimator.calculateTiedParameters(this.m_outputEstimators);
            return;
        }
        for (MultivariateNormalEstimator estimator : this.m_outputEstimators) {
            estimator.calculateParameters();
        }
    }

    /** Builds {@code count} fresh output estimators configured with the current covariance type. */
    private MultivariateNormalEstimator[] newOutputEstimators(int count) {
        MultivariateNormalEstimator[] estimators = new MultivariateNormalEstimator[count];
        for (int state = 0; state < count; state++) {
            estimators[state] = new MultivariateNormalEstimator();
            estimators[state].setCovarianceType(getCovarianceType());
        }
        return estimators;
    }

    /** Copy-constructs the output estimators of {@code other}, one per state it reports. */
    private static MultivariateNormalEstimator[] copyOutputEstimators(MultivariateNormalHMMEstimator other)
            throws Exception {
        MultivariateNormalEstimator[] copies = new MultivariateNormalEstimator[other.getNumStates()];
        for (int state = 0; state < copies.length; state++) {
            copies[state] = new MultivariateNormalEstimator(other.m_outputEstimators[state]);
        }
        return copies;
    }

    /** Samples from the output estimator of {@code state} and appends the result as a new instance. */
    private void appendSample(Instances instances, int state) {
        DoubleVector sample = this.m_outputEstimators[state].sample();
        int dimension = sample.size();
        instances.add(new DenseInstance(dimension));
        // Values are written through lastInstance() rather than the local object;
        // presumably Instances.add stores a copy - confirm against the Weka API.
        Instance added = instances.lastInstance();
        for (int attribute = 0; attribute < dimension; attribute++) {
            added.setValue(attribute, sample.get(attribute));
        }
    }

    /** Returns {@code probability} unchanged, or throws when it is NaN or infinite. */
    private static double requireFinite(double probability) throws Exception {
        if (Double.isNaN(probability)) {
            throw new Exception("Calculated probability is NaN");
        }
        if (Double.isInfinite(probability)) {
            throw new Exception("Calculated probability is infinite");
        }
        return probability;
    }
}
