package weka.classifiers.functions.coordinate_descent;

import java.util.ArrayList;
import java.util.Iterator;
import weka.classifiers.AbstractClassifier;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:weka/classifiers/functions/coordinate_descent/CoordinateDescent.class */
/**
 * Elastic-net linear regression fitted by cyclic coordinate descent.
 *
 * <p>The class values are standardized to {@code (y - mean) / stdDev} and every predictor is
 * centred on its weighted mean (predictors are not rescaled). For the lambda set through
 * {@link #setLambda(double)}, a sweep decreases
 *
 * <pre>
 *   sum_i w_i r_i^2 / (2 W) + lambda * sum_j ( alpha * |b_j| + (1 - alpha) / 2 * b_j^2 )
 * </pre>
 *
 * where {@code W} is the total instance weight and {@code r_i} the residual of instance {@code i}.
 *
 * <p>Two update strategies are available:
 * <ul>
 *   <li>covariance mode: keeps the predictor/residual covariances up to date, using a lazily
 *       filled predictor covariance matrix;</li>
 *   <li>naive mode: keeps the per-instance residuals up to date.</li>
 * </ul>
 * In sparse mode, only instances with a non-zero value for a predictor are visited when that
 * predictor is touched.
 *
 * <p>Fitting is driven by the constructor, {@link #setLambda(double)} and {@link #run()};
 * {@link #buildClassifier(Instances)} does nothing. Coefficients are not reset between calls to
 * {@link #run()}, so a subsequent run (e.g. with a different lambda) starts from the previous
 * solution.
 */
public class CoordinateDescent extends AbstractClassifier {
    private static final long serialVersionUID = 3058063989927918091L;

    /** Training data. */
    protected Instances m_dataset;
    /** Number of predictors: every attribute except the class attribute. */
    protected int m_numPredictors;
    protected int m_numInstances;
    /** Penalty strength; set through {@link #setLambda(double)}. */
    protected double m_lambda;
    /** Elastic-net mixing parameter: weight of the L1 part of the penalty. */
    protected double m_alpha;
    protected int m_classIndex;
    /** Centred weighted predictor covariances; row j is valid once m_covariances_filled[j] is set. */
    protected double[][] m_covariance_matrix;
    protected boolean[] m_covariances_filled;
    /** Weighted standard deviation of the class values. */
    protected double m_class_stdDev = 0.0d;
    /** Current coefficients, indexed by predictor (the class attribute is skipped). */
    protected double[] m_coefficients;
    protected boolean m_covarianceMode = true;
    protected boolean m_sparse = false;
    /** Relative convergence threshold (scaled by the initial squared-error term). */
    protected double m_threshold;
    /** Per predictor: indices of the instances whose value for that predictor is non-zero. */
    protected ArrayList<ArrayList<Integer>> m_sparseIndices;
    /** Smallest lambda at which every coefficient update is soft-thresholded to zero. */
    protected double m_lambdaZero;
    /** lambda * alpha * total weight: soft-thresholding level of a coefficient update. */
    protected double m_softThreshold;
    /** lambda * (1 - alpha) * total weight: ridge contribution to the update denominator. */
    protected double m_denominator_param;
    /** (1 - alpha) / 2: weight of the squared-coefficient part of the penalty. */
    protected double m_ridgeCoeff;
    /** Weighted predictor means. */
    protected double[] m_weighted_means;
    /** Centred weighted sums of squares of the predictors. */
    protected double[] m_weighted_sumSquares;
    /**
     * Weighted covariance of each centred predictor with the standardized residual. Kept current
     * by covariance mode only.
     */
    protected double[] m_residual_covariances;
    /**
     * Per-instance residuals of the standardized class against the uncentred predictors. The
     * centred residual equals this value plus {@link #m_coeffsMeans_product}. Kept current by
     * naive mode only.
     */
    protected double[] m_partResiduals;
    /** Weighted squared error divided by twice the total weight. */
    protected double m_squaredError_term;
    /** Absolute objective change above which a sweep does not count as converged. */
    protected double m_significance_checkVal;
    /** Squared-error term of the all-zero model; base value for m_significance_checkVal. */
    protected double m_initial_squaredError_term;
    /** Maximum number of full sweeps per call to {@link #run()}. */
    protected int m_maxIt;
    protected double m_classMean = 0.0d;
    protected double m_partResidual_sum = 0.0d;
    protected double m_partResidual_sumSquared = 0.0d;
    protected double m_sumOfWeights = 0.0d;
    /** Elastic-net penalty of the current coefficients, before multiplication by lambda. */
    protected double m_unscaled_penalty = 0.0d;
    /** Sum over predictors of coefficient * weighted mean (the intercept correction). */
    protected double m_coeffsMeans_product = 0.0d;

    /**
     * Precomputes the weighted statistics the coordinate updates need and initialises all
     * coefficients to zero.
     *
     * <p>NOTE(review): a constant class attribute gives a standard deviation of 0. That makes the
     * standardized residuals and residual covariances non-finite. The same happens when alpha is
     * 0, which makes {@link #getLambdaZero()} infinite.
     *
     * @param instances training data; its class index must be set
     * @param alpha elastic-net mixing parameter (1 means a pure L1 penalty)
     * @param threshold relative convergence threshold
     * @param maxIt maximum number of full sweeps per {@link #run()}
     * @param covarianceMode {@code true} for covariance updates, {@code false} for naive updates
     * @param sparse whether to index and visit only non-zero predictor values
     */
    public CoordinateDescent(Instances instances, double alpha, double threshold, int maxIt,
            boolean covarianceMode, boolean sparse) {
        m_dataset = instances;
        m_numPredictors = instances.numAttributes() - 1;
        m_numInstances = instances.numInstances();
        m_alpha = alpha;
        m_classIndex = instances.classIndex();
        m_covariance_matrix = new double[m_numPredictors][m_numPredictors];
        m_covariances_filled = new boolean[m_numPredictors];
        m_coefficients = new double[m_numPredictors];
        m_covarianceMode = covarianceMode;
        m_threshold = threshold;
        m_maxIt = maxIt;
        m_sparse = sparse;
        m_ridgeCoeff = (1.0d - m_alpha) / 2.0d;
        m_weighted_means = new double[m_numPredictors];
        m_weighted_sumSquares = new double[m_numPredictors];
        m_residual_covariances = new double[m_numPredictors];
        m_partResiduals = new double[m_numInstances];
        if (sparse) {
            m_sparseIndices = new ArrayList<>();
            for (int j = 0; j < m_numPredictors; j++) {
                m_sparseIndices.add(new ArrayList<>());
            }
        }

        // First pass: raw weighted sums (class moments, predictor sums, cross products).
        for (int i = 0; i < m_numInstances; i++) {
            Instance inst = instances.instance(i);
            double w = inst.weight();
            double y = inst.value(m_classIndex);
            m_sumOfWeights += w;
            m_classMean += y * w;
            m_class_stdDev += y * y * w;
            for (int attr = 0; attr <= m_numPredictors; attr++) {
                if (attr == m_classIndex) {
                    continue;
                }
                int j = predictorIndex(attr);
                double x = inst.value(attr);
                m_weighted_means[j] += x * w;
                m_weighted_sumSquares[j] += x * x * w;
                m_residual_covariances[j] += x * y * w;
                if (m_sparse && x != 0.0d) {
                    m_sparseIndices.get(j).add(i);
                }
            }
        }

        // Turn the raw sums into the class mean and standard deviation.
        m_classMean /= m_sumOfWeights;
        m_class_stdDev /= m_sumOfWeights;
        m_class_stdDev -= m_classMean * m_classMean;
        m_class_stdDev = Math.sqrt(m_class_stdDev);

        // Centre the predictor statistics and standardize the covariances with the class.
        double maxAbsCovariance = 0.0d;
        for (int j = 0; j < m_numPredictors; j++) {
            m_weighted_means[j] /= m_sumOfWeights;
            m_weighted_sumSquares[j] -= m_weighted_means[j] * m_weighted_means[j] * m_sumOfWeights;
            m_residual_covariances[j] -= m_weighted_means[j] * m_classMean * m_sumOfWeights;
            m_residual_covariances[j] /= m_class_stdDev;
            double absCovariance = Math.abs(m_residual_covariances[j]);
            if (absCovariance > maxAbsCovariance) {
                maxAbsCovariance = absCovariance;
            }
        }
        // With all coefficients zero, an update is zeroed iff lambda*alpha*W >= |covariance|.
        m_lambdaZero = maxAbsCovariance / (m_alpha * m_sumOfWeights);

        // Residuals of the all-zero model are the standardized class values.
        for (int i = 0; i < m_numInstances; i++) {
            Instance inst = instances.instance(i);
            double r = (inst.value(m_classIndex) - m_classMean) / m_class_stdDev;
            m_partResiduals[i] = r;
            m_partResidual_sum += inst.weight() * r;
            m_partResidual_sumSquared += inst.weight() * r * r;
        }
        m_squaredError_term = m_partResidual_sumSquared / (2.0d * m_sumOfWeights);
        m_initial_squaredError_term = m_squaredError_term;
        m_significance_checkVal = m_threshold * m_squaredError_term;
    }

    /** Maps an attribute index to its predictor index by skipping the class attribute. */
    private int predictorIndex(int attr) {
        return attr > m_classIndex ? attr - 1 : attr;
    }

    /** Runs coordinate descent using the update strategy chosen at construction. */
    public void run() {
        if (m_covarianceMode) {
            covarianceUpdateMethod();
        } else {
            naiveUpdateMethod();
        }
    }

    /**
     * Performs coordinate-descent sweeps using covariance updates.
     *
     * <p>Sweeps stop when no coefficient change in a sweep moves the objective by more than
     * {@link #m_significance_checkVal}, or after {@link #m_maxIt} sweeps. Afterwards
     * {@link #m_coeffsMeans_product} is recomputed from the final coefficients.
     */
    public void covarianceUpdateMethod() {
        boolean converged = false;
        for (int iter = 0; !converged && iter < m_maxIt; iter++) {
            converged = true;
            for (int attr = 0; attr <= m_numPredictors; attr++) {
                if (attr == m_classIndex) {
                    continue;
                }
                int j = predictorIndex(attr);
                double oldCoeff = m_coefficients[j];
                double newCoeff = computeNewCoefficient(
                        m_residual_covariances[j] + oldCoeff * m_weighted_sumSquares[j], j);
                if (newCoeff != oldCoeff) {
                    fillCovarianceRow(attr, j);
                    double delta = newCoeff - oldCoeff;
                    for (int k = 0; k < m_numPredictors; k++) {
                        m_residual_covariances[k] -= delta * m_covariance_matrix[j][k];
                    }
                    // Exact change of the squared-error term, using the updated covariance of j.
                    double errorChange = (((oldCoeff + newCoeff) * delta) * m_weighted_sumSquares[j]
                            - (2.0d * delta) * (m_residual_covariances[j] + newCoeff * m_weighted_sumSquares[j]))
                            / (2.0d * m_sumOfWeights);
                    // Change of the unscaled elastic-net penalty: ridge part uses new^2 - old^2.
                    double penaltyChange = m_ridgeCoeff * (oldCoeff + newCoeff) * delta
                            + m_alpha * (Math.abs(newCoeff) - Math.abs(oldCoeff));
                    m_squaredError_term += errorChange;
                    m_unscaled_penalty += penaltyChange;
                    if (converged && checkSignificance(errorChange + m_lambda * penaltyChange)) {
                        converged = false;
                    }
                }
                m_coefficients[j] = newCoeff;
            }
        }
        m_coeffsMeans_product = 0.0d;
        for (int j = 0; j < m_numPredictors; j++) {
            m_coeffsMeans_product += m_coefficients[j] * m_weighted_means[j];
        }
    }

    /**
     * Fills row {@code j} of the covariance matrix the first time predictor {@code j} changes.
     * Entries whose mirror row is already filled are copied instead of recomputed.
     */
    private void fillCovarianceRow(int attr, int j) {
        if (m_covariances_filled[j]) {
            return;
        }
        for (int other = 0; other <= m_numPredictors; other++) {
            if (other == m_classIndex) {
                continue;
            }
            int k = predictorIndex(other);
            if (m_covariances_filled[k]) {
                m_covariance_matrix[j][k] = m_covariance_matrix[k][j];
            } else {
                m_covariance_matrix[j][k] = computeWeightedCovariance(attr, other, j, k);
            }
        }
        m_covariances_filled[j] = true;
    }

    /**
     * Performs coordinate-descent sweeps using naive (per-instance residual) updates.
     *
     * <p>Sweeps stop when no coefficient change in a sweep moves the objective by more than
     * {@link #m_significance_checkVal}, or after {@link #m_maxIt} sweeps.
     */
    public void naiveUpdateMethod() {
        boolean converged = false;
        for (int iter = 0; !converged && iter < m_maxIt; iter++) {
            converged = true;
            for (int attr = 0; attr <= m_numPredictors; attr++) {
                if (attr == m_classIndex) {
                    continue;
                }
                int j = predictorIndex(attr);
                double oldCoeff = m_coefficients[j];
                // Covariance of centred predictor j with the centred residual, plus old*ss_j.
                double partial = oldCoeff * m_weighted_sumSquares[j] - m_weighted_means[j] * m_partResidual_sum;
                // Instance.value expects the attribute index (attr), not the predictor index (j).
                if (m_sparse) {
                    for (int i : m_sparseIndices.get(j)) {
                        partial += weightedResidualProduct(i, attr);
                    }
                } else {
                    for (int i = 0; i < m_numInstances; i++) {
                        partial += weightedResidualProduct(i, attr);
                    }
                }
                double newCoeff = computeNewCoefficient(partial, j);
                if (newCoeff != oldCoeff) {
                    double delta = newCoeff - oldCoeff;
                    if (m_sparse) {
                        for (int i : m_sparseIndices.get(j)) {
                            shiftResidual(i, attr, delta);
                        }
                    } else {
                        for (int i = 0; i < m_numInstances; i++) {
                            shiftResidual(i, attr, delta);
                        }
                    }
                    m_coeffsMeans_product += delta * m_weighted_means[j];
                    double oldObjective = m_squaredError_term + m_lambda * m_unscaled_penalty;
                    // Re-centre: the true residual is the stored residual plus m_coeffsMeans_product.
                    m_squaredError_term = ((m_partResidual_sumSquared + 2.0d * m_coeffsMeans_product * m_partResidual_sum)
                            / (2.0d * m_sumOfWeights)) + (Math.pow(m_coeffsMeans_product, 2.0d) / 2.0d);
                    m_unscaled_penalty += m_ridgeCoeff * (oldCoeff + newCoeff) * delta
                            + m_alpha * (Math.abs(newCoeff) - Math.abs(oldCoeff));
                    double objectiveChange = (m_squaredError_term + m_lambda * m_unscaled_penalty) - oldObjective;
                    if (converged && checkSignificance(objectiveChange)) {
                        converged = false;
                    }
                }
                m_coefficients[j] = newCoeff;
            }
        }
    }

    /** Returns residual * attribute value * weight for instance {@code i}. */
    private double weightedResidualProduct(int i, int attr) {
        Instance inst = m_dataset.instance(i);
        return m_partResiduals[i] * inst.value(attr) * inst.weight();
    }

    /**
     * Subtracts {@code delta * x_attr} from the residual of instance {@code i}. Also keeps the
     * weighted residual sum and sum of squares consistent.
     */
    private void shiftResidual(int i, int attr, double delta) {
        Instance inst = m_dataset.instance(i);
        double shift = delta * inst.value(attr);
        m_partResiduals[i] -= shift;
        m_partResidual_sum -= inst.weight() * shift;
        // r_new = r_old - s, so r_old^2 - r_new^2 = s * (2 * r_new + s).
        m_partResidual_sumSquared -= inst.weight() * shift * (2.0d * m_partResiduals[i] + shift);
    }

    /**
     * Computes the centred weighted covariance of two predictors.
     *
     * <p>In sparse mode, only the instances with a non-zero value for {@code predA} are summed.
     * Other instances contribute zero to the raw cross product, provided {@code predA} is the
     * predictor index of {@code attrA}.
     *
     * @param attrA attribute index of the first predictor
     * @param attrB attribute index of the second predictor
     * @param predA predictor index of {@code attrA}
     * @param predB predictor index of {@code attrB}
     * @return sum of w * x_A * x_B minus meanA * meanB * total weight
     */
    public double computeWeightedCovariance(int attrA, int attrB, int predA, int predB) {
        double crossProduct = 0.0d;
        if (m_sparse) {
            for (int i : m_sparseIndices.get(predA)) {
                Instance inst = m_dataset.instance(i);
                crossProduct += inst.value(attrA) * inst.value(attrB) * inst.weight();
            }
        } else {
            for (int i = 0; i < m_numInstances; i++) {
                Instance inst = m_dataset.instance(i);
                crossProduct += inst.value(attrA) * inst.value(attrB) * inst.weight();
            }
        }
        return crossProduct - ((m_weighted_means[predA] * m_weighted_means[predB]) * m_sumOfWeights);
    }

    /** Returns whether an objective change is large enough to keep iterating. */
    public boolean checkSignificance(double d) {
        return Math.abs(d) > m_significance_checkVal;
    }

    /**
     * Applies the elastic-net coordinate update.
     *
     * @param d partial covariance of predictor {@code i} with the residual (including the
     *     predictor's own contribution)
     * @param i predictor index
     * @return the soft-thresholded value divided by the predictor's sum of squares plus the ridge
     *     term; 0 when {@code |d|} does not exceed the soft threshold
     */
    public double computeNewCoefficient(double d, int i) {
        if (m_softThreshold >= Math.abs(d)) {
            return 0.0d;
        }
        double denominator = m_weighted_sumSquares[i] + m_denominator_param;
        if (d > 0.0d) {
            return (d - m_softThreshold) / denominator;
        }
        if (d < 0.0d) {
            return (d + m_softThreshold) / denominator;
        }
        return 0.0d;
    }

    /**
     * Predicts the class value of an instance in the original (unstandardized) class scale.
     *
     * @param instance instance with the same attribute layout as the training data
     * @return mean + stdDev * (sum_j b_j * x_j - sum_j b_j * mean_j)
     */
    public double classifyInstance(Instance instance) throws Exception {
        double standardized = -m_coeffsMeans_product;
        for (int attr = 0; attr <= m_numPredictors; attr++) {
            if (attr != m_classIndex) {
                standardized += m_coefficients[predictorIndex(attr)] * instance.value(attr);
            }
        }
        return (standardized * m_class_stdDev) + m_classMean;
    }

    /** Sets the penalty strength and the derived soft-threshold and ridge terms. */
    public void setLambda(double d) {
        m_lambda = d;
        m_softThreshold = m_lambda * m_alpha * m_sumOfWeights;
        m_denominator_param = m_lambda * (1.0d - m_alpha) * m_sumOfWeights;
    }

    /**
     * Sets the relative convergence threshold.
     *
     * <p>The absolute check value is recomputed from the all-zero model's squared-error term. The
     * previous approach rescaled by {@code d / m_threshold}, which yields NaN when the previous
     * threshold was 0.
     */
    public void setThreshold(double d) {
        m_significance_checkVal = d * m_initial_squaredError_term;
        m_threshold = d;
    }

    public double get_classMean() {
        return m_classMean;
    }

    public double get_class_stdDev() {
        return m_class_stdDev;
    }

    /** Returns the live coefficient array (not a copy), indexed by predictor. */
    public double[] getCoefficients() {
        return m_coefficients;
    }

    public double getLambdaZero() {
        return m_lambdaZero;
    }

    public double get_coeffsMeans_product() {
        return m_coeffsMeans_product;
    }

    /**
     * Does nothing. This model is fitted through the constructor, {@link #setLambda(double)} and
     * {@link #run()}; presumably an enclosing learner drives that sequence.
     */
    public void buildClassifier(Instances instances) throws Exception {
    }
}
