package weka.classifiers.bayes;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.matrix.DoubleVector;
import weka.core.matrix.Matrix;
import weka.estimators.DiscreteHMMEstimator;
import weka.estimators.HMMEstimator;
import weka.estimators.MultivariateNormalEstimator;
import weka.estimators.MultivariateNormalHMMEstimator;

/* loaded from: input_file:weka/classifiers/bayes/HMM.class */
public class HMM extends RandomizableClassifier implements OptionHandler, MultiInstanceCapabilitiesHandler {
    /** Number of discrete output symbols used by the discrete-output estimators. */
    protected int m_NumOutputs;
    /** Tag set for the covariance-type option; presumably populated in a static initializer outside this view. */
    public static final Tag[] TAGS_COVARIANCE_TYPE;
    /** One HMM estimator per class value; null until the estimators have been initialised. */
    protected HMMEstimator[] estimators;
    private static final long serialVersionUID = 1959669739718119361L;
    // Decompiled form of the compiler-generated flag that backs Java assert statements.
    static final /* synthetic */ boolean $assertionsDisabled;
    /** Number of hidden states in each class's HMM. */
    protected int m_NumStates = 6;
    /** Dimensionality of the output vectors used for random Gaussian initialisation. */
    protected int m_OutputDimension = 1;
    /** True when outputs are modelled as multivariate normal rather than discrete. */
    protected boolean m_Numeric = false;
    /** Proportional minimum change of likelihood at which to stop the EM iterations. */
    protected double m_IterationCutoff = 0.01d;
    /** Index of the relation-valued attribute holding the sequence, or -1 if none has been found. */
    protected int m_SeqAttr = -1;
    // NOTE(review): not referenced in the visible part of this class.
    protected Random m_rand = null;
    /** Forward scale factors at or below this magnitude raise ProbabilityTooSmallException. */
    protected double minScale = 1.0E-200d;
    /** Whether state probabilities are initialised randomly rather than uniformly/left-right. */
    protected boolean m_RandomStateInitializers = false;
    /** Whether Gaussian output covariances are tied across outputs (option -D). */
    protected boolean m_Tied = false;
    /** Selected covariance type: 0 = FULL, 1 = DIAGONAL, 2 = SPHERICAL. */
    protected int m_CovarianceType = 0;
    /** Whether state transitions are constrained to move only to the next state (option -L). */
    protected boolean m_LeftRight = false;

    // Decompiler note: this class was declared protected in the original bytecode.
    /**
     * Signals that a scale factor of the forward/backward recursions fell to or
     * below the minimum usable magnitude, so the sequence cannot be scored.
     */
    public class ProbabilityTooSmallException extends Exception {
        private static final long serialVersionUID = -2706223192260478060L;

        /**
         * Creates the exception.
         *
         * @param message description of the offending time step and probability
         */
        ProbabilityTooSmallException(String message) {
            super(message);
        }
    }

    /**
     * Returns the index of the relation-valued attribute that holds the sequence.
     *
     * @return the sequence attribute index, or -1 if none has been located
     */
    public int getSequenceAttribute() {
        return m_SeqAttr;
    }

    /**
     * Reports whether state probabilities are initialised randomly.
     *
     * @return true for random state initialisation
     */
    public boolean isRandomStateInitializers() {
        return m_RandomStateInitializers;
    }

    /**
     * Chooses whether state probabilities are initialised randomly.
     *
     * @param z true for random state initialisation
     */
    public void setRandomStateInitializers(boolean z) {
        m_RandomStateInitializers = z;
    }

    /**
     * Reports whether Gaussian output covariances are tied.
     *
     * @return true if covariances are tied
     */
    public boolean isTied() {
        return m_Tied;
    }

    /**
     * Sets whether Gaussian output covariances are tied.
     *
     * @param z true to tie the covariances
     */
    public void setTied(boolean z) {
        m_Tied = z;
    }

    /**
     * Returns the selected covariance type.
     *
     * @return the covariance type wrapped with its tag set
     */
    public SelectedTag getCovarianceType() {
        SelectedTag selected = new SelectedTag(m_CovarianceType, TAGS_COVARIANCE_TYPE);
        return selected;
    }

    /**
     * Sets the covariance type. Tags that belong to a different tag set are
     * silently ignored.
     *
     * @param selectedTag the new covariance type
     */
    public void setCovarianceType(SelectedTag selectedTag) {
        if (selectedTag.getTags() != TAGS_COVARIANCE_TYPE) {
            return;
        }
        m_CovarianceType = selectedTag.getSelectedTag().getID();
    }

    /**
     * Reports whether the model uses left-right state transitions.
     *
     * @return true for a left-right model
     */
    public boolean isLeftRight() {
        return m_LeftRight;
    }

    /**
     * Sets whether the model uses left-right state transitions.
     *
     * @param z true for a left-right model
     */
    public void setLeftRight(boolean z) {
        m_LeftRight = z;
    }

    /**
     * Returns the dimensionality of the output vectors.
     *
     * @return the output dimension
     */
    public int getOutputDimension() {
        return m_OutputDimension;
    }

    /**
     * Sets the dimensionality of the output vectors.
     *
     * @param i the output dimension
     */
    public void setOutputDimension(int i) {
        m_OutputDimension = i;
    }

    /**
     * Reports whether outputs are modelled as numeric (multivariate normal).
     *
     * @return true for numeric outputs
     */
    public boolean isNumeric() {
        return m_Numeric;
    }

    /**
     * Sets whether outputs are numeric. The EM iteration cutoff is reset as a
     * side effect: 1e-4 for numeric outputs, 0.01 for discrete ones.
     *
     * @param z true for numeric outputs
     */
    public void setNumeric(boolean z) {
        m_Numeric = z;
        double cutoff = z ? 1.0E-4d : 0.01d;
        setIterationCutoff(cutoff);
    }

    /**
     * Returns the proportional likelihood change at which EM stops.
     *
     * @return the iteration cutoff
     */
    public double getIterationCutoff() {
        return m_IterationCutoff;
    }

    /**
     * Sets the proportional likelihood change at which EM stops.
     *
     * @param d the iteration cutoff
     */
    public void setIterationCutoff(double d) {
        m_IterationCutoff = d;
    }

    /**
     * Returns the number of per-class estimators.
     *
     * @return the estimator count, or 0 before initialisation
     */
    public int getNumClasses() {
        return estimators == null ? 0 : estimators.length;
    }

    /**
     * Returns the number of hidden states.
     *
     * @return the state count
     */
    public int getNumStates() {
        return m_NumStates;
    }

    /**
     * Sets the number of hidden states.
     *
     * @param i the state count
     */
    public void setNumStates(int i) {
        m_NumStates = i;
    }

    /**
     * Returns the number of discrete output symbols.
     *
     * @return the output symbol count
     */
    public int getNumOutputs() {
        return m_NumOutputs;
    }

    /**
     * Sets the number of discrete output symbols.
     *
     * @param i the output symbol count
     */
    public void setNumOutputs(int i) {
        m_NumOutputs = i;
    }

    /**
     * Forwards a weighted initial-step observation to the estimator of class {@code i}.
     *
     * @param i class index selecting the estimator
     * @param d presumably the state index (confirm against HMMEstimator.addValue0)
     * @param doubleVector the observation vector
     * @param d2 the observation weight
     */
    public void setProbability0(int i, double d, DoubleVector doubleVector, double d2) {
        HMMEstimator target = estimators[i];
        target.addValue0(d, doubleVector, d2);
    }

    /**
     * Forwards a weighted transition observation to the estimator of class {@code i}.
     *
     * @param i class index selecting the estimator
     * @param d presumably the source state (confirm against HMMEstimator.addValue)
     * @param d2 presumably the destination state
     * @param doubleVector the observation vector
     * @param d3 the observation weight
     */
    public void setProbability(int i, double d, double d2, DoubleVector doubleVector, double d3) {
        HMMEstimator target = estimators[i];
        target.addValue(d, d2, doubleVector, d3);
    }

    /**
     * Converts forward scale factors into a log-likelihood by summing their
     * natural logarithms. Factors whose magnitude is at most 1e-32 are clamped
     * to 1e-32 so that the log stays finite.
     *
     * @param dArr the per-time-step scale factors
     * @return the sum of the (clamped) logs
     */
    protected double likelihoodFromScales(double[] dArr) {
        final double floor = 1.0E-32d;
        double logLikelihood = 0.0d;
        for (double scale : dArr) {
            logLikelihood += Math.log(Math.abs(scale) > floor ? scale : floor);
        }
        return logLikelihood;
    }

    /**
     * Runs the scaled forward recursion over one sequence.
     *
     * On return, row {@code t} of {@code dArr} holds the forward values of time
     * step t divided by that step's scale factor, so each row sums to one. The
     * returned array holds the scale factors themselves.
     *
     * @param hMMEstimator the model supplying initial and transition probabilities
     * @param instances the sequence, one instance per time step
     * @param dArr output buffer with at least numInstances rows of m_NumStates entries
     * @return the scale factor of every time step
     * @throws ProbabilityTooSmallException if a scale factor is at or below {@code minScale}
     * @throws Exception if the estimator fails
     */
    protected double[] forward(HMMEstimator hMMEstimator, Instances instances, double[][] dArr) throws Exception {
        final int steps = instances.numInstances();
        final double[] scale = new double[steps];
        // Written explicitly: an empty sequence fails here, as it always has.
        scale[0] = 0.0d;

        Instance row = instances.instance(0);
        DoubleVector observation = new DoubleVector(row.numAttributes());
        for (int a = 0; a < row.numAttributes(); a++) {
            observation.set(a, row.value(a));
        }
        for (int s = 0; s < m_NumStates; s++) {
            dArr[0][s] = hMMEstimator.getProbability0(s, observation);
            scale[0] += dArr[0][s];
        }
        if (Math.abs(scale[0]) <= minScale) {
            throw new ProbabilityTooSmallException("time step 0 probability " + scale[0]);
        }
        for (int s = 0; s < m_NumStates; s++) {
            dArr[0][s] /= scale[0];
        }

        for (int t = 1; t < steps; t++) {
            row = instances.instance(t);
            observation = new DoubleVector(row.numAttributes());
            for (int a = 0; a < row.numAttributes(); a++) {
                observation.set(a, row.value(a));
            }
            double total = 0.0d;
            for (int to = 0; to < m_NumStates; to++) {
                dArr[t][to] = 0.0d;
                for (int from = 0; from < m_NumStates; from++) {
                    dArr[t][to] += dArr[t - 1][from] * hMMEstimator.getProbability(from, to, observation);
                }
                total += dArr[t][to];
            }
            scale[t] = total;
            if (Math.abs(scale[t]) <= minScale) {
                throw new ProbabilityTooSmallException("time step " + t + " probability " + scale[t]);
            }
            for (int s = 0; s < m_NumStates; s++) {
                dArr[t][s] /= scale[t];
            }
        }
        return scale;
    }

    /**
     * Computes the log-likelihood of a sequence under the given model using the
     * scaled forward recursion.
     *
     * @param hMMEstimator the model
     * @param instances the sequence
     * @return the log-likelihood derived from the scale factors
     * @throws Exception if the forward pass fails, including ProbabilityTooSmallException
     */
    protected double forward(HMMEstimator hMMEstimator, Instances instances) throws Exception {
        double[][] alpha = new double[instances.numInstances()][this.m_NumStates];
        double[] scales = forward(hMMEstimator, instances, alpha);
        return likelihoodFromScales(scales);
    }

    /**
     * Runs the scaled forward pass followed by the scaled backward pass.
     *
     * {@code dArr} is filled by {@link #forward(HMMEstimator, Instances, double[][])}.
     * {@code dArr2[t][s]} receives the backward value for step t and state s.
     * That value is set to 1 at the final step and otherwise divided by the forward
     * scale factor of step t+1.
     *
     * @param hMMEstimator the model
     * @param instances the sequence, one instance per time step
     * @param dArr forward buffer (numInstances x m_NumStates), overwritten
     * @param dArr2 backward buffer (numInstances x m_NumStates), overwritten
     * @return the forward scale factors
     * @throws ProbabilityTooSmallException if a scale factor is at or below {@code minScale}
     * @throws Exception if a backward value becomes NaN or infinite
     */
    protected double[] forwardBackward(HMMEstimator hMMEstimator, Instances instances, double[][] dArr, double[][] dArr2) throws Exception {
        double[] scales = forward(hMMEstimator, instances, dArr);
        int numStates = getNumStates();
        int last = instances.numInstances() - 1;

        for (int s = 0; s < numStates; s++) {
            dArr2[last][s] = 1.0d;
        }

        for (int t = last - 1; t >= 0; t--) {
            // The observation at t+1 is the same for every state pair, so build it once per step.
            Instance next = instances.instance(t + 1);
            DoubleVector observation = new DoubleVector(next.numAttributes());
            for (int a = 0; a < next.numAttributes(); a++) {
                observation.set(a, next.value(a));
            }

            for (int from = 0; from < numStates; from++) {
                dArr2[t][from] = 0.0d;
                for (int to = 0; to < numStates; to++) {
                    double probability = hMMEstimator.getProbability(from, to, observation);
                    dArr2[t][from] += dArr2[t + 1][to] * probability;
                    if (Double.isInfinite(dArr2[t][from]) || Double.isNaN(dArr2[t][from])) {
                        throw new Exception("Unscaled Beta is NaN");
                    }
                }
            }

            if (Math.abs(scales[t + 1]) <= this.minScale) {
                throw new ProbabilityTooSmallException("time step " + (t + 1) + " probability " + scales[t + 1]);
            }
            for (int s = 0; s < numStates; s++) {
                dArr2[t][s] /= scales[t + 1];
                if (Double.isInfinite(dArr2[t][s]) || Double.isNaN(dArr2[t][s])) {
                    throw new Exception("Scaled Beta is NaN");
                }
            }
        }
        return scales;
    }

    /**
     * Runs forward-backward over a sequence and returns its log-likelihood.
     *
     * @param hMMEstimator the model
     * @param instances the sequence
     * @return the log-likelihood derived from the forward scale factors
     * @throws Exception if either pass fails
     */
    protected double forwardBackward(HMMEstimator hMMEstimator, Instances instances) throws Exception {
        int steps = instances.numInstances();
        double[][] alpha = new double[steps][this.m_NumStates];
        double[][] beta = new double[steps][this.m_NumStates];
        double[] scales = forwardBackward(hMMEstimator, instances, alpha, beta);
        return likelihoodFromScales(scales);
    }

    /**
     * Computes, for every time step and state, the product of the scaled
     * forward and backward values of the sequence under class {@code i}'s model.
     * With the scaling used by forward/forwardBackward these products should be
     * the per-step state posteriors (each row summing to one); verify if the
     * scaling scheme changes.
     *
     * @param i index of the class whose estimator is used
     * @param instance bag whose sequence attribute holds the sequence
     * @return a numInstances x m_NumStates array of alpha*beta products
     * @throws Exception if forward-backward fails
     */
    public double[][] probabilitiesForInstance(int i, Instance instance) throws Exception {
        Instances sequence = instance.relationalValue(this.m_SeqAttr);
        int length = sequence.numInstances();
        double[][] alpha = new double[length][this.m_NumStates];
        double[][] beta = new double[length][this.m_NumStates];
        // Called for its side effect of filling alpha and beta; the scale factors are not needed here.
        forwardBackward(this.estimators[i], sequence, alpha, beta);

        double[][] result = new double[length][this.m_NumStates];
        for (int t = 0; t < length; t++) {
            for (int s = 0; s < result[t].length; s++) {
                result[t][s] = alpha[t][s] * beta[t][s];
            }
        }
        return result;
    }

    /**
     * Returns the class distribution of a bag by comparing the sequence
     * likelihood under each class's HMM.
     *
     * Likelihoods are normalised in log space (shifting by the maximum), because
     * exponentiating a whole-sequence log-likelihood underflows to zero for
     * realistic sequence lengths. A class whose forward pass raises
     * ProbabilityTooSmallException gets probability 0. If no class can score the
     * sequence, an all-zero array is returned.
     *
     * @param instance the bag to classify
     * @return the class distribution; {0.5, 0.5} if the model is untrained
     * @throws Exception if scoring fails for a reason other than underflow
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.estimators == null) {
            // Untrained placeholder kept for compatibility; assumes two classes.
            return new double[]{0.5d, 0.5d};
        }
        int numClasses = this.estimators.length;
        // Without a sequence attribute every class keeps log-likelihood 0, i.e. a uniform result.
        double[] logLikelihoods = new double[numClasses];
        if (this.m_SeqAttr >= 0) {
            Instances sequence = instance.relationalValue(this.m_SeqAttr);
            for (int c = 0; c < numClasses; c++) {
                try {
                    logLikelihoods[c] = forward(this.estimators[c], sequence);
                } catch (ProbabilityTooSmallException e) {
                    logLikelihoods[c] = Double.NEGATIVE_INFINITY;
                }
            }
        }

        double maxLogLikelihood = Double.NEGATIVE_INFINITY;
        for (double ll : logLikelihoods) {
            if (ll > maxLogLikelihood) {
                maxLogLikelihood = ll;
            }
        }
        double[] dist = new double[numClasses];
        if (maxLogLikelihood == Double.NEGATIVE_INFINITY) {
            return dist;
        }

        double sum = 0.0d;
        for (int c = 0; c < numClasses; c++) {
            // NaN and -infinity both count as impossible.
            dist[c] = logLikelihoods[c] > Double.NEGATIVE_INFINITY
                    ? Math.exp(logLikelihoods[c] - maxLogLikelihood) : 0.0d;
            sum += dist[c];
        }
        // sum >= 1 because the best class contributes exp(0).
        for (int c = 0; c < numClasses; c++) {
            dist[c] /= sum;
        }
        return dist;
    }

    /**
     * Delegates to the superclass clone. No field (including the estimator
     * array) is deep-copied here.
     *
     * @return the copy produced by the superclass
     * @throws CloneNotSupportedException if the superclass refuses to clone
     */
    protected Object clone() throws CloneNotSupportedException {
        Object copy = super.clone();
        return copy;
    }

    /**
     * Returns a short description of this classifier for the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        final String info = "Class for a Hidden Markov Model classifier.";
        return info;
    }

    /**
     * Returns the current settings as command-line tokens accepted by
     * {@link #setOptions(String[])}.
     *
     * Each option name and its value are separate array elements. Boolean
     * options are emitted as bare flags, and only when they are enabled. The
     * previous implementation produced tokens setOptions could not parse (such
     * as "-S 6" and "-Dtrue"), left a null slot, and overwrote the -C value.
     *
     * @return the option tokens
     */
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<String>();
        options.add("-S");
        options.add(String.valueOf(getNumStates()));
        options.add("-I");
        options.add(String.valueOf(getIterationCutoff()));
        switch (this.m_CovarianceType) {
            case MultivariateNormalEstimator.COVARIANCE_FULL:
                options.add("-C");
                options.add("FULL");
                break;
            case MultivariateNormalEstimator.COVARIANCE_DIAGONAL:
                options.add("-C");
                options.add("DIAGONAL");
                break;
            case MultivariateNormalEstimator.COVARIANCE_SPHERICAL:
                options.add("-C");
                options.add("SPHERICAL");
                break;
            default:
                break;
        }
        if (isTied()) {
            options.add("-D");
        }
        if (isLeftRight()) {
            options.add("-L");
        }
        if (isRandomStateInitializers()) {
            options.add("-R");
        }
        return options.toArray(new String[options.size()]);
    }

    /**
     * Describes the options understood by {@link #setOptions(String[])}.
     *
     * -S, -I and -C take a value. -D, -L and -R are boolean flags read with
     * Utils.getFlag, so they are declared with zero arguments.
     *
     * @return an enumeration of the available options
     */
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>(6);
        options.addElement(new Option("\tStates: number of HMM states to use (default 6)", "S", 1, "-S <num>"));
        options.addElement(new Option("\tIteration Cutoff: the proportional minimum change of likelihood\n\tat which to stop the EM iterations ", "I", 1, "-I <cutoff>"));
        options.addElement(new Option("\tCovariance Type: whether the covariances of gaussian\n\toutputs should be full matrices or limited to diagonal\n\tor spherical matrices (FULL, DIAGONAL or SPHERICAL; default FULL)", "C", 1, "-C <type>"));
        options.addElement(new Option("\tTied Covariance: whether the covariances of gaussian\n\toutputs are tied to be the same across all outputs ", "D", 0, "-D"));
        options.addElement(new Option("\tLeft Right: whether the state transitions are constrained\n\tto go only to the next state in numerical order ", "L", 0, "-L"));
        options.addElement(new Option("\tRandom Initialisation: whether the state transition probabilities are initialised randomly\n\t(if this is false they are initialised by performing a k-means clustering on the data) ", "R", 0, "-R"));
        return options.elements();
    }

    /**
     * Parses options of the form -I cutoff, -S states,
     * -C FULL|DIAGONAL|SPHERICAL, and the flags -D, -L, -R.
     *
     * Options that are absent leave the current settings unchanged. An
     * unrecognised -C value is ignored. Any leftover token raises an exception.
     *
     * @param strArr the option tokens; consumed entries are blanked by Utils
     * @throws Exception if a value is malformed or unknown options remain
     */
    public void setOptions(String[] strArr) throws Exception {
        String cutoffText = Utils.getOption('I', strArr);
        if (cutoffText.length() > 0) {
            setIterationCutoff(Double.parseDouble(cutoffText));
        }

        String statesText = Utils.getOption('S', strArr);
        if (statesText.length() > 0) {
            setNumStates(Integer.parseInt(statesText));
        }

        String covarianceText = Utils.getOption('C', strArr);
        if ("FULL".equals(covarianceText)) {
            setCovarianceType(new SelectedTag(0, TAGS_COVARIANCE_TYPE));
        } else if ("DIAGONAL".equals(covarianceText)) {
            setCovarianceType(new SelectedTag(1, TAGS_COVARIANCE_TYPE));
        } else if ("SPHERICAL".equals(covarianceText)) {
            setCovarianceType(new SelectedTag(2, TAGS_COVARIANCE_TYPE));
        }

        // Flags only ever switch their setting on.
        boolean tied = Utils.getFlag('D', strArr);
        if (tied) {
            setTied(true);
        }
        boolean leftRight = Utils.getFlag('L', strArr);
        if (leftRight) {
            setLeftRight(true);
        }
        boolean randomInit = Utils.getFlag('R', strArr);
        if (randomInit) {
            setRandomStateInitializers(true);
        }

        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Declares the accepted data: multi-instance data with relational, nominal
     * or numeric attributes and a nominal class.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);

        // class
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        return result;
    }

    /**
     * Declares what the data inside each bag (the sequence) may contain:
     * nominal or numeric attributes and no class.
     *
     * @return the capabilities for the relational (bag) data
     */
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = new Capabilities(this);
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Performs one EM re-estimation step over the training data.
     *
     * Each training instance with a present sequence and class is run through
     * {@link #forwardBackward} using the current estimator of its class. The
     * resulting normalised initial-state weights and per-step normalised
     * transition weights are added to a fresh estimator for that class.
     * Sequences that raise ProbabilityTooSmallException are skipped. Finally the
     * fresh estimators replace {@link #estimators} and their parameters are
     * recalculated.
     *
     * @param instances the training data
     * @return the summed log-likelihood of the processed sequences divided by the
     *         total number of instances (skipped ones included)
     * @throws Exception if no sequence could be processed, or if the
     *         forward-backward output contains NaN or infinite values
     */
    protected double EMStep(Instances instances) throws Exception {
        final int numClasses = instances.numClasses();
        double totalLogLikelihood = 0.0d;
        boolean anyUpdated = false;

        // Fresh accumulators, one per class. Numeric models start from a copy of
        // the current output parameters.
        HMMEstimator[] updated = new HMMEstimator[numClasses];
        for (int c = 0; c < numClasses; c++) {
            if (isNumeric()) {
                MultivariateNormalHMMEstimator fresh = new MultivariateNormalHMMEstimator(getNumStates(), false);
                fresh.copyOutputParameters((MultivariateNormalHMMEstimator) this.estimators[c]);
                updated[c] = fresh;
            } else {
                updated[c] = new DiscreteHMMEstimator(getNumStates(), getNumOutputs(), false);
            }
        }

        for (int n = 0; n < instances.numInstances(); n++) {
            Instance instance = instances.instance(n);
            if (instance.isMissing(this.m_SeqAttr) || instance.classIsMissing()) {
                continue;
            }
            Instances sequence = instance.relationalValue(this.m_SeqAttr);
            int length = sequence.numInstances();
            double[][] alpha = new double[length][this.m_NumStates];
            double[][] beta = new double[length][this.m_NumStates];
            int classValue = (int) instance.value(instances.classIndex());
            HMMEstimator current = this.estimators[classValue];
            HMMEstimator target = updated[classValue];
            try {
                double[] scales = forwardBackward(current, sequence, alpha, beta);
                totalLogLikelihood += likelihoodFromScales(scales);
                int numStates = getNumStates();

                // Initial-state weights: alpha_0(s) * beta_0(s), normalised over states.
                Instance firstRow = sequence.instance(0);
                DoubleVector firstObservation = new DoubleVector(firstRow.numAttributes());
                for (int a = 0; a < firstRow.numAttributes(); a++) {
                    firstObservation.set(a, firstRow.value(a));
                }
                double[] gamma0 = new double[numStates];
                double gammaSum = 0.0d;
                for (int s = 0; s < numStates; s++) {
                    gamma0[s] = alpha[0][s] * beta[0][s];
                    gammaSum += gamma0[s];
                }
                for (int s = 0; s < numStates; s++) {
                    if (gammaSum > this.minScale) {
                        target.addValue0(s, firstObservation, gamma0[s] / gammaSum);
                    }
                    if (Double.isInfinite(gamma0[s]) || Double.isNaN(gamma0[s])) {
                        throw new Exception("Output of the forward backward algorithm gives a NaN");
                    }
                }

                // Transition weights for each step t >= 1, normalised over all (from, to)
                // pairs. The factor scales[t] is common to every pair and cancels out.
                double[][] xi = new double[numStates][numStates];
                for (int t = 1; t < length; t++) {
                    Instance row = sequence.instance(t);
                    DoubleVector observation = new DoubleVector(row.numAttributes());
                    for (int a = 0; a < row.numAttributes(); a++) {
                        observation.set(a, row.value(a));
                    }
                    double xiSum = 0.0d;
                    for (int to = 0; to < numStates; to++) {
                        for (int from = 0; from < numStates; from++) {
                            xi[from][to] = alpha[t - 1][from] * current.getProbability(from, to, observation) * beta[t][to] * scales[t];
                            xiSum += xi[from][to];
                        }
                    }
                    for (int to = 0; to < numStates; to++) {
                        for (int from = 0; from < numStates; from++) {
                            if (xiSum > this.minScale) {
                                target.addValue(from, to, observation, xi[from][to] / xiSum);
                            }
                            if (Double.isInfinite(xi[from][to]) || Double.isNaN(xi[from][to])) {
                                throw new Exception("Output of the forward backward algorithm gives a NaN");
                            }
                        }
                    }
                }
                anyUpdated = true;
            } catch (ProbabilityTooSmallException e) {
                // Deliberate best effort: a sequence that underflows under the current
                // model contributes nothing to this step.
            }
        }

        if (!anyUpdated) {
            throw new Exception("Failed to update on EM step");
        }
        this.estimators = updated;
        for (int c = 0; c < this.estimators.length; c++) {
            this.estimators[c].calculateParameters();
        }
        return totalLogLikelihood / instances.numInstances();
    }

    /**
     * Initialises {@code i} per-class estimators with default starting values.
     * Numeric models use the multivariate-normal initialiser; the data is passed
     * through to it. Discrete models use the univariate-discrete initialiser.
     *
     * @param i number of classes
     * @param instances training data (only used by the numeric initialiser)
     * @throws Exception if initialisation fails
     */
    public void initEstimators(int i, Instances instances) throws Exception {
        if (!isNumeric()) {
            initEstimatorsUnivariateDiscrete(i, (double[][]) null, (double[][][]) null, (double[][][]) null);
            return;
        }
        initEstimatorsMultivariateNormal(i, (double[][]) null, (double[][][]) null, (DoubleVector[][]) null, (Matrix[][]) null, instances);
    }

    /**
     * Builds uniform (unnormalised) initial-state weights: 1.0 for every state
     * of every class.
     *
     * @param i number of classes
     * @return an i x numStates array of ones
     */
    protected double[][] initState0ProbsUniform(int i) {
        double[][] weights = new double[i][getNumStates()];
        for (double[] classWeights : weights) {
            for (int s = 0; s < getNumStates(); s++) {
                classWeights[s] = 1.0d;
            }
        }
        return weights;
    }

    /**
     * Builds random (unnormalised) initial-state weights, each an integer in
     * [0, 100) drawn from {@code random} in class-then-state order.
     *
     * @param i number of classes
     * @param random source of randomness
     * @return an i x numStates array of weights
     */
    protected double[][] initState0ProbsRandom(int i, Random random) {
        double[][] weights = new double[i][getNumStates()];
        for (double[] classWeights : weights) {
            for (int s = 0; s < getNumStates(); s++) {
                classWeights[s] = random.nextInt(100);
            }
        }
        return weights;
    }

    /**
     * Builds left-right initial-state weights: all mass on state 0.
     *
     * @param i number of classes
     * @return an i x numStates array with 1.0 in column 0 and 0.0 elsewhere
     */
    protected double[][] initState0ProbsLeftRight(int i) {
        // New arrays are zero-filled, so only the starting state needs setting.
        double[][] weights = new double[i][getNumStates()];
        for (double[] classWeights : weights) {
            classWeights[0] = 1.0d;
        }
        return weights;
    }

    /**
     * Builds near-uniform (unnormalised) transition weights that favour staying
     * in the same state: 10.0 on the diagonal and 1.0 elsewhere.
     *
     * @param i number of classes
     * @return an i x numStates x numStates array of weights
     */
    protected double[][][] initStateProbsUniform(int i) {
        double[][][] weights = new double[i][getNumStates()][getNumStates()];
        for (double[][] classWeights : weights) {
            for (int from = 0; from < getNumStates(); from++) {
                for (int to = 0; to < getNumStates(); to++) {
                    classWeights[from][to] = (from == to) ? 10.0d : 1.0d;
                }
            }
        }
        return weights;
    }

    /**
     * Builds random (unnormalised) transition weights, each an integer in
     * [0, 100) drawn in class, source-state, destination-state order.
     *
     * @param i number of classes
     * @param random source of randomness
     * @return an i x numStates x numStates array of weights
     */
    protected double[][][] initStateProbsRandom(int i, Random random) {
        double[][][] weights = new double[i][getNumStates()][getNumStates()];
        for (double[][] classWeights : weights) {
            for (int from = 0; from < getNumStates(); from++) {
                for (int to = 0; to < getNumStates(); to++) {
                    classWeights[from][to] = random.nextInt(100);
                }
            }
        }
        return weights;
    }

    /**
     * Builds left-right (unnormalised) transition weights. Each state except the
     * last keeps weight 90 on itself and 10 on the next state. The last state
     * has weight 100 on itself, and every other entry is zero.
     *
     * @param i number of classes
     * @return an i x numStates x numStates array of weights
     */
    protected double[][][] initStateProbsLeftRight(int i) {
        // New arrays are zero-filled; only the non-zero entries are written.
        double[][][] weights = new double[i][getNumStates()][getNumStates()];
        for (double[][] classWeights : weights) {
            int lastState = getNumStates() - 1;
            for (int s = 0; s < lastState; s++) {
                classWeights[s][s] = 90.0d;
                classWeights[s][s + 1] = 10.0d;
            }
            classWeights[lastState][lastState] = 100.0d;
        }
        return weights;
    }

    /**
     * Builds random (unnormalised) discrete output weights, each an integer in
     * [0, 100) drawn in class, state, output-symbol order.
     *
     * @param i number of classes
     * @param random source of randomness
     * @return an i x numStates x numOutputs array of weights
     */
    protected double[][][] initDiscreteOutputProbsRandom(int i, Random random) {
        double[][][] weights = new double[i][getNumStates()][getNumOutputs()];
        for (double[][] classWeights : weights) {
            for (int s = 0; s < getNumStates(); s++) {
                for (int symbol = 0; symbol < getNumOutputs(); symbol++) {
                    classWeights[s][symbol] = random.nextInt(100);
                }
            }
        }
        return weights;
    }

    /**
     * Sets the Gaussian output parameters of every state of every class.
     * Missing means are drawn with DoubleVector.random, and missing covariances
     * default to 10 times the identity. Both use getOutputDimension().
     *
     * @param i number of classes
     * @param doubleVectorArr per-class, per-state means, or null to draw random ones
     * @param matrixArr per-class, per-state covariances, or null for 10*I
     */
    protected void initGaussianOutputProbsRandom(int i, DoubleVector[][] doubleVectorArr, Matrix[][] matrixArr) {
        DoubleVector[][] means = doubleVectorArr;
        if (means == null) {
            means = new DoubleVector[i][getNumStates()];
            for (DoubleVector[] classMeans : means) {
                for (int s = 0; s < getNumStates(); s++) {
                    classMeans[s] = DoubleVector.random(getOutputDimension());
                }
            }
        }

        Matrix[][] covariances = matrixArr;
        if (covariances == null) {
            covariances = new Matrix[i][getNumStates()];
            for (Matrix[] classCovariances : covariances) {
                for (int s = 0; s < getNumStates(); s++) {
                    Matrix scaledIdentity = Matrix.identity(getOutputDimension(), getOutputDimension());
                    scaledIdentity.timesEquals(10.0d);
                    classCovariances[s] = scaledIdentity;
                }
            }
        }

        for (int c = 0; c < i; c++) {
            for (int s = 0; s < getNumStates(); s++) {
                MultivariateNormalHMMEstimator estimator = (MultivariateNormalHMMEstimator) this.estimators[c];
                estimator.setOutputMean(s, means[c][s]);
                estimator.setOutputVariance(s, covariances[c][s]);
            }
        }
    }

    /**
     * Initialises the Gaussian outputs from the pooled data of each class.
     *
     * The method first locates the first relation-valued attribute, which holds
     * the sequence, and stores its index in m_SeqAttr. If the sequence's first
     * attribute is numeric, the model is switched to numeric mode. It then fits
     * one multivariate normal per class over every time step of that class's
     * sequences. Every state of the class gets that mean and variance unless
     * explicit values are supplied.
     *
     * @param i number of classes
     * @param instances training data
     * @param doubleVectorArr per-class, per-state means, or null to use the pooled mean
     * @param matrixArr per-class, per-state covariances, or null to use the pooled variance
     * @throws Exception if no usable sequence attribute exists or estimation fails
     */
    protected void initGaussianOutputProbsAllData(int i, Instances instances, DoubleVector[][] doubleVectorArr, Matrix[][] matrixArr) throws Exception {
        MultivariateNormalEstimator[] classEstimators = new MultivariateNormalEstimator[i];
        for (int c = 0; c < i; c++) {
            classEstimators[c] = new MultivariateNormalEstimator();
        }

        this.m_SeqAttr = -1;
        this.m_NumOutputs = 0;
        for (int a = 0; a < instances.numAttributes(); a++) {
            Attribute attribute = instances.attribute(a);
            if (!attribute.isRelationValued()) {
                continue;
            }
            if (attribute.relation().attribute(0).isNominal()) {
                this.m_SeqAttr = attribute.index();
                if (!$assertionsDisabled && this.m_SeqAttr != a) {
                    throw new AssertionError();
                }
            }
            if (attribute.relation().attribute(0).isNumeric()) {
                this.m_SeqAttr = attribute.index();
                if (!$assertionsDisabled && this.m_SeqAttr != a) {
                    throw new AssertionError();
                }
                setNumeric(true);
            }
            // Stop at the first relation-valued attribute. The previous loop never
            // advanced past it and therefore never terminated.
            break;
        }
        if (this.m_SeqAttr < 0) {
            throw new Exception("No relation-valued sequence attribute with a nominal or numeric first attribute found");
        }

        for (int n = 0; n < instances.numInstances(); n++) {
            Instance instance = instances.instance(n);
            if (instance.isMissing(this.m_SeqAttr) || instance.classIsMissing()) {
                continue;
            }
            Instances sequence = instance.relationalValue(this.m_SeqAttr);
            int classValue = (int) instance.value(instances.classIndex());
            for (int t = 0; t < sequence.numInstances(); t++) {
                Instance row = sequence.instance(t);
                DoubleVector observation = new DoubleVector(row.numAttributes());
                for (int k = 0; k < row.numAttributes(); k++) {
                    observation.set(k, row.value(k));
                }
                classEstimators[classValue].addValue(observation, 1.0d);
            }
        }
        for (int c = 0; c < i; c++) {
            classEstimators[c].calculateParameters();
        }

        for (int c = 0; c < i; c++) {
            for (int s = 0; s < getNumStates(); s++) {
                MultivariateNormalHMMEstimator estimator = (MultivariateNormalHMMEstimator) this.estimators[c];
                if (doubleVectorArr == null) {
                    estimator.setOutputMean(s, classEstimators[c].getMean());
                } else {
                    estimator.setOutputMean(s, doubleVectorArr[c][s]);
                }
                if (matrixArr == null) {
                    estimator.setOutputVariance(s, classEstimators[c].getVariance());
                } else {
                    estimator.setOutputVariance(s, matrixArr[c][s]);
                }
            }
        }
    }

    /**
     * Initialises the Gaussian outputs by k-means clustering each class's data.
     *
     * The method locates the first relation-valued attribute, which holds the
     * sequence, and stores its index in m_SeqAttr. It pools every time step of
     * each class's sequences and clusters them into getNumStates() clusters with
     * SimpleKMeans. Unless explicit values are supplied, state s receives cluster
     * s's centroid as its mean and a diagonal covariance built from the squared
     * per-attribute standard deviations of that cluster.
     *
     * @param i number of classes
     * @param instances training data
     * @param doubleVectorArr per-class, per-state means, or null to use cluster centroids
     * @param matrixArr per-class, per-state covariances, or null to use cluster variances
     * @throws Exception if no usable sequence attribute exists, a class has no
     *         sequence data, or clustering fails
     */
    protected void initGaussianOutputProbsCluster(int i, Instances instances, DoubleVector[][] doubleVectorArr, Matrix[][] matrixArr) throws Exception {
        this.m_SeqAttr = -1;
        this.m_NumOutputs = 0;
        for (int a = 0; a < instances.numAttributes(); a++) {
            Attribute attribute = instances.attribute(a);
            if (!attribute.isRelationValued()) {
                continue;
            }
            if (attribute.relation().attribute(0).isNominal()) {
                this.m_SeqAttr = attribute.index();
                if (!$assertionsDisabled && this.m_SeqAttr != a) {
                    throw new AssertionError();
                }
            }
            if (attribute.relation().attribute(0).isNumeric()) {
                this.m_SeqAttr = attribute.index();
                if (!$assertionsDisabled && this.m_SeqAttr != a) {
                    throw new AssertionError();
                }
            }
            // Stop at the first relation-valued attribute. The previous loop never
            // advanced past it and therefore never terminated.
            break;
        }
        if (this.m_SeqAttr < 0) {
            throw new Exception("No relation-valued sequence attribute with a nominal or numeric first attribute found");
        }

        // Pool every time step of each class's sequences.
        Instances[] classData = new Instances[i];
        for (int n = 0; n < instances.numInstances(); n++) {
            Instance instance = instances.instance(n);
            if (instance.isMissing(this.m_SeqAttr) || instance.classIsMissing()) {
                continue;
            }
            Instances sequence = instance.relationalValue(this.m_SeqAttr);
            int classValue = (int) instance.value(instances.classIndex());
            if (classData[classValue] == null) {
                classData[classValue] = new Instances(sequence, sequence.numInstances());
            }
            for (int t = 0; t < sequence.numInstances(); t++) {
                classData[classValue].add(sequence.instance(t));
            }
        }

        SimpleKMeans[] clusterers = new SimpleKMeans[i];
        for (int c = 0; c < i; c++) {
            if (classData[c] == null) {
                throw new Exception("No sequence data available for class index " + c);
            }
            clusterers[c] = new SimpleKMeans();
            clusterers[c].setNumClusters(getNumStates());
            clusterers[c].setDisplayStdDevs(true);
            clusterers[c].buildClusterer(classData[c]);
        }

        for (int c = 0; c < i; c++) {
            Instances centroids = clusterers[c].getClusterCentroids();
            Instances stdDevs = clusterers[c].getClusterStandardDevs();
            for (int s = 0; s < getNumStates(); s++) {
                MultivariateNormalHMMEstimator estimator = (MultivariateNormalHMMEstimator) this.estimators[c];
                if (doubleVectorArr == null) {
                    Instance centroid = centroids.instance(s);
                    DoubleVector mean = new DoubleVector(centroid.numAttributes());
                    for (int k = 0; k < centroid.numAttributes(); k++) {
                        mean.set(k, centroid.value(k));
                    }
                    estimator.setOutputMean(s, mean);
                } else {
                    estimator.setOutputMean(s, doubleVectorArr[c][s]);
                }
                if (matrixArr == null) {
                    int dim = stdDevs.instance(s).numAttributes();
                    Matrix variance = new Matrix(dim, dim, 0.0d);
                    for (int k = 0; k < dim; k++) {
                        // Variance is the squared standard deviation; the previous code
                        // squared the centroid value instead.
                        double sd = stdDevs.instance(s).value(k);
                        variance.set(k, k, sd * sd);
                    }
                    estimator.setOutputVariance(s, variance);
                } else {
                    estimator.setOutputVariance(s, matrixArr[c][s]);
                }
            }
        }
    }

    /**
     * Creates one discrete-output HMM estimator per class and seeds each with
     * weighted pseudo-counts built from initial-state, transition and output
     * probabilities.
     *
     * <p>Any probability table passed as {@code null} is generated here. Both
     * state tables use left-right initialisation when {@link #isLeftRight()}
     * is set, otherwise random initialisation when
     * {@link #isRandomStateInitializers()} is set, otherwise uniform. The output
     * table is always generated randomly. All random draws come from one
     * {@link Random} seeded with {@link #getSeed()}, in the order initial-state,
     * transition, output.
     *
     * @param i      number of estimators to create (one per class)
     * @param dArr   initial-state probabilities indexed [class][state], or null
     * @param dArr2  transition probabilities indexed [class][from][to], or null
     * @param dArr3  output probabilities indexed [class][state][output], or null
     * @throws Exception if estimator construction or seeding fails
     */
    public void initEstimatorsUnivariateDiscrete(int i, double[][] dArr, double[][][] dArr2, double[][][] dArr3) throws Exception {
        this.estimators = new HMMEstimator[i];
        Random rng = new Random(getSeed());

        double[][] startProbs = dArr;
        if (startProbs == null) {
            if (isLeftRight()) {
                startProbs = initState0ProbsLeftRight(i);
            } else if (isRandomStateInitializers()) {
                startProbs = initState0ProbsRandom(i, rng);
            } else {
                startProbs = initState0ProbsUniform(i);
            }
        }

        double[][][] transitionProbs = dArr2;
        if (transitionProbs == null) {
            if (isLeftRight()) {
                transitionProbs = initStateProbsLeftRight(i);
            } else if (isRandomStateInitializers()) {
                transitionProbs = initStateProbsRandom(i, rng);
            } else {
                transitionProbs = initStateProbsUniform(i);
            }
        }

        double[][][] outputProbs = dArr3;
        if (outputProbs == null) {
            outputProbs = initDiscreteOutputProbsRandom(i, rng);
        }

        int numStates = getNumStates();
        int numOutputs = getNumOutputs();
        for (int cls = 0; cls < i; cls++) {
            HMMEstimator estimator = new DiscreteHMMEstimator(numStates, numOutputs, false);
            this.estimators[cls] = estimator;
            for (int state = 0; state < numStates; state++) {
                for (int out = 0; out < numOutputs; out++) {
                    // Scale by 100 so the probabilities act as pseudo-counts.
                    estimator.addValue0(state, out, 100.0d * startProbs[cls][state] * outputProbs[cls][state][out]);
                    for (int prev = 0; prev < numStates; prev++) {
                        estimator.addValue(prev, state, out, 100.0d * transitionProbs[cls][prev][state] * outputProbs[cls][state][out]);
                    }
                }
            }
        }
    }

    /**
     * Creates one multivariate-normal HMM estimator per class, sets its state
     * probabilities, and then initialises the Gaussian output parameters.
     *
     * <p>Null state tables are generated exactly as in the discrete case:
     * left-right if {@link #isLeftRight()}, else random if
     * {@link #isRandomStateInitializers()}, else uniform. The output parameters
     * are chosen by three routes. With no {@code instances}, they are
     * initialised randomly. With left-right models, they are derived from all
     * the data. Otherwise they are derived by clustering. The output means and
     * covariances passed in (possibly null) are forwarded to whichever
     * initialiser runs.
     *
     * @param i               number of estimators to create (one per class)
     * @param dArr            initial-state probabilities [class][state], or null
     * @param dArr2           transition probabilities [class][from][to], or null
     * @param doubleVectorArr per-class, per-state output means, or null
     * @param matrixArr       per-class, per-state output covariances, or null
     * @param instances       training data used for data-driven initialisation, or null
     * @throws Exception if any initialisation step fails
     */
    public void initEstimatorsMultivariateNormal(int i, double[][] dArr, double[][][] dArr2, DoubleVector[][] doubleVectorArr, Matrix[][] matrixArr, Instances instances) throws Exception {
        this.estimators = new HMMEstimator[i];
        Random rng = new Random(getSeed());

        double[][] startProbs = dArr;
        if (startProbs == null) {
            if (isLeftRight()) {
                startProbs = initState0ProbsLeftRight(i);
            } else if (isRandomStateInitializers()) {
                startProbs = initState0ProbsRandom(i, rng);
            } else {
                startProbs = initState0ProbsUniform(i);
            }
        }

        double[][][] transitionProbs = dArr2;
        if (transitionProbs == null) {
            if (isLeftRight()) {
                transitionProbs = initStateProbsLeftRight(i);
            } else if (isRandomStateInitializers()) {
                transitionProbs = initStateProbsRandom(i, rng);
            } else {
                transitionProbs = initStateProbsUniform(i);
            }
        }

        for (int cls = 0; cls < i; cls++) {
            MultivariateNormalHMMEstimator estimator = new MultivariateNormalHMMEstimator(getNumStates(), false);
            this.estimators[cls] = estimator;
            estimator.setCovarianceType(this.m_CovarianceType);
            estimator.setTied(isTied());
            estimator.setState0Probabilities(startProbs[cls]);
            estimator.setStateProbabilities(transitionProbs[cls]);
        }

        if (instances != null) {
            if (isLeftRight()) {
                initGaussianOutputProbsAllData(i, instances, doubleVectorArr, matrixArr);
            } else {
                initGaussianOutputProbsCluster(i, instances, doubleVectorArr, matrixArr);
            }
        } else {
            initGaussianOutputProbsRandom(i, doubleVectorArr, matrixArr);
        }
    }

    /**
     * Trains one hidden Markov model per class value.
     *
     * <p>The sequence is read from the first relation-valued attribute whose
     * nested relation starts with a nominal attribute (discrete outputs; the
     * number of outputs becomes the count of distinct values) or a numeric
     * attribute (numeric outputs; the output dimension becomes the number of
     * nested attributes).
     *
     * <p>The method reports these problems on {@code System.err} and returns
     * without training:
     * <ul>
     *   <li>no class index</li>
     *   <li>a non-nominal class</li>
     *   <li>no usable sequence attribute</li>
     *   <li>an empty dataset</li>
     * </ul>
     * These checks run before estimator initialisation.
     *
     * <p>If no estimators are present, they are created with
     * {@code initEstimators(numClasses, instances)}. The models are then refined
     * with at most 100 calls to {@code EMStep}. Refinement stops early once the
     * relative change between successive returned values drops below
     * {@link #getIterationCutoff()}.
     *
     * @param instances the training data
     * @throws Exception if estimator initialisation or an EM step fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        System.out.println("starting build classifier");
        if (instances.classIndex() < 0) {
            System.err.println("could not find class index");
            return;
        }
        if (!instances.classAttribute().isNominal()) {
            System.err.println("class attribute is not nominal");
            return;
        }

        this.m_SeqAttr = -1;
        this.m_NumOutputs = 0;
        // Bounded scan that stops at the first usable sequence attribute. Every
        // iteration advances, so the scan always terminates.
        for (int i = 0; i < instances.numAttributes(); i++) {
            Attribute attribute = instances.attribute(i);
            if (!attribute.isRelationValued() || attribute.relation().numAttributes() == 0) {
                continue;
            }
            Attribute firstNested = attribute.relation().attribute(0);
            if (firstNested.isNominal()) {
                this.m_SeqAttr = attribute.index();
                if (!$assertionsDisabled && this.m_SeqAttr != i) {
                    throw new AssertionError();
                }
                this.m_NumOutputs = attribute.relation().numDistinctValues(0);
                break;
            }
            if (firstNested.isNumeric()) {
                this.m_SeqAttr = attribute.index();
                if (!$assertionsDisabled && this.m_SeqAttr != i) {
                    throw new AssertionError();
                }
                setNumeric(true);
                this.m_NumOutputs = -1;
                this.m_OutputDimension = attribute.relation().numAttributes();
                break;
            }
        }

        // Reject unusable data before any estimator set-up is attempted.
        if (this.m_SeqAttr < 0) {
            System.err.println("Could not find a relational attribute corresponding to the sequence");
            return;
        }
        if (instances.numInstances() == 0) {
            System.err.println("No instances found");
            return;
        }

        if (this.estimators == null) {
            initEstimators(instances.numClasses(), instances);
        }
        for (int i2 = 0; i2 < this.estimators.length; i2++) {
            System.out.println(i2 + " " + this.estimators[i2]);
        }

        final int maxIterations = 100;
        double previous = -1.0E7d;
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            double current = EMStep(instances);
            // Converged once the relative change of the EM score is small.
            if (Math.abs((current - previous) / current) < getIterationCutoff()) {
                break;
            }
            previous = current;
        }

        for (int i4 = 0; i4 < this.estimators.length; i4++) {
            System.out.println(i4 + " " + this.estimators[i4]);
        }
    }

    /**
     * Samples {@code i} sequences of length {@code i2} using this classifier's
     * own random generator.
     *
     * <p>The generator is created from {@link #getSeed()} on first use and
     * reused afterwards.
     *
     * @param i  number of sequences to generate
     * @param i2 length of each sequence
     * @return the generated dataset
     */
    public Instances sample(int i, int i2) {
        Random generator = this.m_rand;
        if (generator == null) {
            generator = new Random(getSeed());
            this.m_rand = generator;
        }
        return sample(i, i2, generator);
    }

    /**
     * Generates a multi-instance dataset of {@code i} sequences sampled from the
     * trained per-class models.
     *
     * <p>Each row has three attributes:
     * <ul>
     *   <li>a nominal sequence id ({@code seq_0}, {@code seq_1}, ...)</li>
     *   <li>a nominal class ({@code class_0}, ...), one value per estimator,
     *       which is also the class index</li>
     *   <li>a relation-valued sequence</li>
     * </ul>
     * The nested sequence holds numeric attributes {@code output_k} when
     * {@link #isNumeric()} is set, otherwise a single nominal {@code output}
     * attribute.
     *
     * <p>For each row, a class is picked uniformly at random. Its estimator
     * then emits an initial observation via {@code Sample0} and {@code i2 - 1}
     * further observations via {@code Sample}. All randomness comes from
     * {@code random}.
     *
     * @param i      number of sequences to generate
     * @param i2     length of each sequence
     * @param random source of randomness for class choice and emission
     * @return the generated dataset, with class index 1
     */
    public Instances sample(int i, int i2, Random random) {
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();

        ArrayList<String> seqIds = new ArrayList<String>(i);
        for (int i3 = 0; i3 < i; i3++) {
            seqIds.add("seq_" + i3);
        }
        attributes.add(new Attribute("seq-id", seqIds));

        ArrayList<String> classValues = new ArrayList<String>(this.estimators.length);
        for (int i4 = 0; i4 < this.estimators.length; i4++) {
            classValues.add("class_" + i4);
        }
        attributes.add(new Attribute("class", classValues));

        ArrayList<Attribute> outputAttributes = new ArrayList<Attribute>();
        if (isNumeric()) {
            for (int i5 = 0; i5 < getOutputDimension(); i5++) {
                outputAttributes.add(new Attribute("output_" + i5));
            }
        } else {
            ArrayList<String> outputValues = new ArrayList<String>();
            for (int i6 = 0; i6 < getNumOutputs(); i6++) {
                outputValues.add("output_" + i6);
            }
            outputAttributes.add(new Attribute("output", outputValues));
        }
        attributes.add(new Attribute("sequence", new Instances("seq", outputAttributes, 0)));

        Instances instances = new Instances("test", attributes, i);
        instances.setClassIndex(1);
        Attribute sequenceAttribute = instances.attribute(2);
        for (int i7 = 0; i7 < i; i7++) {
            instances.add(new DenseInstance(3));
            Instance lastInstance = instances.lastInstance();
            lastInstance.setValue(0, seqIds.get(i7));
            // Use the caller-supplied generator, not m_rand (which may be null).
            int classIdx = random.nextInt(classValues.size());
            lastInstance.setValue(1, classValues.get(classIdx));

            HMMEstimator estimator = this.estimators[classIdx];
            Instances sequence = new Instances(seqIds.get(i7), outputAttributes, i2);
            int state = estimator.Sample0(sequence, random);
            for (int i8 = 1; i8 < i2; i8++) {
                state = estimator.Sample(sequence, state, random);
            }
            // Register the nested relation with the attribute and store its index.
            lastInstance.setValue(sequenceAttribute, sequenceAttribute.addRelation(sequence));
        }
        return instances;
    }

    /**
     * Command-line entry point.
     *
     * <p>Passes a fresh {@code HMM} and the arguments to
     * {@code runClassifier}.
     *
     * @param strArr command-line options
     */
    public static void main(String[] strArr) {
        HMM classifier = new HMM();
        runClassifier(classifier, strArr);
    }

    static {
        // Selectable covariance structures for the multivariate-normal outputs.
        // NOTE(review): "unconstrianed" is a typo, but the label is runtime text
        // and is kept byte-identical here.
        TAGS_COVARIANCE_TYPE = new Tag[] {
            new Tag(0, "Full matrix (unconstrianed)"),
            new Tag(1, "Diagonal matrix (no correlation between data attributes)"),
            new Tag(2, "Spherical matrix (all attributes have the same variance)")
        };
        // Compiler-style assertion switch consulted by the explicit checks in this class.
        $assertionsDisabled = !HMM.class.desiredAssertionStatus();
    }
}
