package weka.classifiers.functions;

import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.coordinate_descent.CoordinateDescent;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:weka/classifiers/functions/ElasticNet.class */
/**
 * Elastic-net linear regression fitted by pathwise coordinate descent.
 *
 * <p>A decreasing sequence of regularisation strengths (lambda) is either generated on a log
 * scale from {@code lambdaZero} (obtained from the full-data {@link CoordinateDescent} model)
 * down to {@code epsilon * lambdaZero}, or supplied by the user as a comma-separated list. An
 * inner cross-validation scores every lambda by mean squared error. The best lambda, optionally
 * relaxed by the one-standard-error rule, is then used for the final fit on all training data.
 */
public class ElasticNet extends AbstractClassifier implements OptionHandler, WeightedInstancesHandler {
    private static final long serialVersionUID = -257033407031904867L;

    /** Number of non-class attributes in the training data. */
    protected int m_numPredictors;
    /** Number of instances seen by the last call to {@link #buildClassifier(Instances)}. */
    protected int m_numInstances;
    /** Index of the class attribute in the training data. */
    protected int m_classIndex;
    /** Name of the class attribute, printed by {@link #toString()}. */
    protected String m_class_name;
    /** Predictor attribute names in attribute order, with the class attribute skipped. */
    protected String[] m_predictor_names;
    /** Formatted wall-clock training time, e.g. {@code "12.345 ms"}. */
    protected String m_train_time;
    /** Lambda path ordered from largest to smallest; null until derived. */
    protected double[] m_lambda_values;
    /** Coordinate-descent model fitted on the full training data. */
    protected CoordinateDescent m_modelUsed;
    /** User-supplied comma-separated lambda sequence; empty means "generate one". */
    protected String m_lambda_seq = "";
    /** Elastic-net alpha; {@link #validate()} clamps it to [0.001, 1]. */
    protected double m_alpha = 0.001d;
    /** Length of the generated lambda sequence. */
    protected int m_numModels = 100;
    /** Number of inner cross-validation folds (0 when a single lambda is supplied). */
    protected int m_numInnerFolds = 10;
    /** Ratio of the smallest to the largest generated lambda. */
    protected double m_epsilon = 1.0E-4d;
    /** Index into {@link #m_lambda_values} of the lambda used by the final model. */
    protected int m_bestModel_index = 0;
    /** Convergence threshold passed to {@link CoordinateDescent}. */
    protected double m_threshold = 1.0E-7d;
    /** Iteration cap passed to {@link CoordinateDescent}. */
    protected int m_maxIt = 10000000;
    /** Whether {@link CoordinateDescent} uses covariance updates ("method 2"). */
    protected boolean m_covarianceMode = true;
    /** Whether {@link CoordinateDescent} uses sparse updates. */
    protected boolean m_sparse = false;
    /** Whether the one-standard-error rule chooses the final lambda. */
    protected boolean m_stderr_rule = false;
    /** Whether {@link #toString()} also prints the lambda sequence and best index. */
    protected boolean m_additionalStats = false;

    public String getRevision() {
        return RevisionUtils.extract("$Revision: 11970 $");
    }

    /**
     * Builds the elastic-net model.
     *
     * <p>Validates the options and derives the lambda path if the user supplied none. Every
     * lambda is scored with an inner cross-validation and the best one is selected, optionally
     * via the one-standard-error rule. Finally the path is run on the full data up to that lambda.
     *
     * @param instances training data; must satisfy {@link #getCapabilities()}
     * @throws Exception if the data is unsuitable or model fitting fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        Instances data = new Instances(instances);

        // Normalise the options and re-derive m_lambda_values from m_lambda_seq. This used to be
        // a side effect of calling getOptions() and discarding its result.
        validate();

        m_numPredictors = data.numAttributes() - 1;
        m_numInstances = data.numInstances();
        m_classIndex = data.classIndex();

        double startTime = System.nanoTime();
        // Fixed seed so the inner CV folds, and therefore the chosen lambda, are reproducible.
        data.randomize(new Random(1L));

        m_modelUsed = new CoordinateDescent(data, m_alpha, m_threshold, m_maxIt, m_covarianceMode, m_sparse);
        if (m_lambda_values == null) {
            double lambdaZero = m_modelUsed.getLambdaZero();
            m_lambda_values = logspace(m_epsilon * lambdaZero, lambdaZero, m_numModels);
        }

        double[] meanError = new double[m_numModels];
        double[] stdError = new double[m_numModels];
        crossValidateLambdas(data, meanError, stdError);

        m_bestModel_index = minIndex(meanError);
        if (m_stderr_rule) {
            m_bestModel_index = oneStandardErrorIndex(meanError, stdError, m_bestModel_index);
        }

        // Run the path from the largest lambda down to the chosen one. NOTE(review): this
        // presumably relies on CoordinateDescent keeping its solution between run() calls
        // (warm start) — confirm.
        for (int k = 0; k <= m_bestModel_index; k++) {
            m_modelUsed.setLambda(m_lambda_values[k]);
            m_modelUsed.run();
        }
        double endTime = System.nanoTime();

        m_class_name = data.attribute(m_classIndex).name();
        m_predictor_names = new String[m_numPredictors];
        int slot = 0;
        for (int att = 0; att < data.numAttributes(); att++) {
            if (att != m_classIndex) {
                m_predictor_names[slot++] = data.attribute(att).name();
            }
        }
        m_train_time = String.format("%2.3f", Double.valueOf((endTime - startTime) / 1000000.0d)) + " ms";
    }

    /**
     * Scores every lambda in {@link #m_lambda_values} by {@link #m_numInnerFolds}-fold
     * cross-validation.
     *
     * <p>On return, {@code meanError[k]} is the mean over folds of the squared
     * {@code errorRate()} for lambda k. {@code stdError[k]} is the standard error of that mean.
     * Both arrays stay all-zero when {@code m_numInnerFolds} is 0.
     */
    private void crossValidateLambdas(Instances data, double[] meanError, double[] stdError) throws Exception {
        int folds = m_numInnerFolds;
        for (int fold = 0; fold < folds; fold++) {
            Instances train = data.trainCV(folds, fold);
            Instances test = data.testCV(folds, fold);
            CoordinateDescent path = new CoordinateDescent(train, m_alpha, m_threshold, m_maxIt, m_covarianceMode, m_sparse);
            for (int k = 0; k < m_numModels; k++) {
                path.setLambda(m_lambda_values[k]);
                path.run();
                Evaluation evaluation = new Evaluation(data);
                evaluation.evaluateModel(path, test);
                // For a numeric class Weka documents errorRate() as the RMSE, so this is the MSE.
                double rmse = evaluation.errorRate();
                double mse = rmse * rmse;
                meanError[k] += mse;
                stdError[k] += mse * mse; // sum of squares; turned into a standard error below
            }
        }
        if (folds > 0) {
            for (int k = 0; k < m_numModels; k++) {
                meanError[k] /= folds;
                double variance = (stdError[k] - folds * meanError[k] * meanError[k]) / (folds * (folds - 1));
                // Cancellation can make the variance slightly negative. Without the clamp, sqrt
                // returns NaN and the 1-SE rule would collapse to index 0.
                stdError[k] = Math.sqrt(Math.max(0.0, variance));
            }
        }
    }

    /**
     * Applies the one-standard-error rule.
     *
     * <p>Returns the smallest index whose mean error is within one standard error of the best
     * model's. Lambdas are ordered largest first, so this is the most regularised such model.
     * The result never exceeds {@code best}.
     */
    private static int oneStandardErrorIndex(double[] meanError, double[] stdError, int best) {
        double limit = meanError[best] + stdError[best];
        int k = 0;
        while (k < best && meanError[k] > limit) {
            k++;
        }
        return k;
    }

    public double classifyInstance(Instance instance) throws Exception {
        return this.m_modelUsed.classifyInstance(instance);
    }

    /**
     * Returns {@code n} values evenly spaced on a log scale.
     *
     * <p>The sequence starts at {@code max} and decreases geometrically to {@code min}.
     *
     * @param min last (smallest) value
     * @param max first (largest) value
     * @param n number of values; n &le; 0 yields an empty array and n == 1 yields {@code {max}}
     * @return the sequence, largest first
     */
    public double[] logspace(double min, double max, int n) {
        if (n <= 0) {
            return new double[0];
        }
        double[] values = new double[n];
        values[0] = max;
        if (n == 1) {
            return values;
        }
        double ratio = Math.pow(min / max, 1.0d / (n - 1));
        for (int k = 1; k < n; k++) {
            values[k] = values[k - 1] * ratio;
        }
        return values;
    }

    /**
     * Returns the index of the first minimum of {@code values}.
     *
     * <p>Returns 0 for an empty array. NaN entries never win a comparison.
     */
    public int minIndex(double[] values) {
        int best = 0;
        for (int k = 0; k < values.length; k++) {
            if (values[k] < values[best]) {
                best = k;
            }
        }
        return best;
    }

    /**
     * Describes the model as a linear equation in the predictors.
     *
     * <p>Only non-zero coefficients are printed. When {@link #m_additionalStats} is set, the
     * equation is preceded by the lambda sequence and the selected index.
     */
    public String toString() {
        if (this.m_modelUsed == null) {
            return "Elastic net";
        }
        double[] coefficients = m_modelUsed.getCoefficients();
        double classMean = m_modelUsed.get_classMean();
        double classStdDev = m_modelUsed.get_class_stdDev();
        double intercept = classMean - (m_modelUsed.get_coeffsMeans_product() * classStdDev);

        StringBuilder sb = new StringBuilder();
        if (m_additionalStats) {
            appendLambdaSequence(sb, classStdDev);
            sb.append("\nBest model index: " + m_bestModel_index + "\n\n");
        }
        sb.append(m_class_name + " = \n");
        for (int i = 0; i < m_numPredictors; i++) {
            if (coefficients[i] != 0.0d) {
                sb.append(String.format("%2.3f", Double.valueOf(coefficients[i] * classStdDev)) + " * " + m_predictor_names[i] + " + \n");
            }
        }
        sb.append(String.format("%2.3f", Double.valueOf(intercept)));
        return sb.toString();
    }

    /**
     * Appends the comma-separated lambda sequence, each value multiplied by {@code scale}.
     *
     * <p>A line break is inserted after a separator once a line exceeds 120 characters. The
     * values are skipped if {@link #m_lambda_values} is null, e.g. after {@link #validate()}
     * reset it.
     */
    private void appendLambdaSequence(StringBuilder sb, double scale) {
        sb.append("Lambda sequence:\n");
        if (m_lambda_values == null) {
            return;
        }
        int lineStart = sb.length();
        for (int k = 0; k < m_lambda_values.length; k++) {
            if (k > 0) {
                sb.append(',');
                if (sb.length() - lineStart > 120) {
                    lineStart = sb.length();
                    sb.append("\n");
                }
            }
            sb.append(String.format("%2.3f", Double.valueOf(m_lambda_values[k] * scale)));
        }
    }

    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tWhether to use Method 2 (covariance updates), y or n", "m2", 1, "-m2 <mode>"));
        options.addElement(new Option("\tSet alpha value", "alpha", 1, "-alpha <alpha>"));
        options.addElement(new Option("\tProvide custom lambda sequence of comma seperated non-negative floating point values OR leave blank to let classifier build own sequence", "lambda_seq", 1, "-lambda_seq <lambda_seq>"));
        options.addElement(new Option("\tSet convergence threshold", "thr", 1, "-thr <thr>"));
        options.addElement(new Option("\tSet maximum iterations for convergence", "mxit", 1, "-mxit <mxit>"));
        options.addElement(new Option("\tSet number of models for pathwise descent", "numModels", 1, "-numModels <numModels>"));
        options.addElement(new Option("\tSet number of folds for inner CV", "infolds", 1, "-infolds <infolds>"));
        options.addElement(new Option("\tSet epsilon value for pathwise descent", "eps", 1, "-eps <eps>"));
        options.addElement(new Option("\tSet whether to turn on sparse updates, y or n", "sparse", 1, "-sparse <sparse>"));
        options.addElement(new Option("\tSelect model based on 1 S.E rule, y or n", "stderr_rule", 1, "-stderr_rule <stderr_rule>"));
        options.addElement(new Option("\tWhether to print additional statistics, y or n", "addStats", 1, "-addStats <addStats>"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Clamps or resets out-of-range options to usable values and parses the custom lambda sequence.
     *
     * <p>A valid sequence (finite, non-negative numbers) replaces {@link #m_lambda_values},
     * sorted largest first. It also sets {@link #m_numModels} to its length, and disables the
     * inner CV when it holds a single value. An empty, null or invalid sequence clears both
     * {@link #m_lambda_seq} and {@link #m_lambda_values}, so a sequence is generated at build time.
     */
    public void validate() {
        if (m_alpha > 1.0d) {
            m_alpha = 1.0d;
        } else if (m_alpha < 0.001d) {
            m_alpha = 0.001d;
        }
        if (m_epsilon > 1.0d || m_epsilon < 0.0d) {
            m_epsilon = 1.0E-4d;
        }
        if (m_numModels < 2) {
            m_numModels = 100;
        }
        if (m_maxIt <= 0) {
            m_maxIt = 10000000;
        }
        if (m_threshold <= 0.0d) {
            m_threshold = 1.0E-7d;
        }
        if (m_numInnerFolds < 2) {
            m_numInnerFolds = 10;
        }

        double[] parsed = parseLambdaSequence(m_lambda_seq);
        if (parsed == null) {
            m_lambda_seq = "";
            m_lambda_values = null;
            return;
        }
        Arrays.sort(parsed);
        int count = parsed.length;
        m_lambda_values = new double[count];
        for (int k = 0; k < count; k++) {
            m_lambda_values[k] = parsed[count - 1 - k]; // largest lambda first
        }
        m_numModels = count;
        if (count == 1) {
            // Nothing to choose between, so skip the inner cross-validation.
            m_numInnerFolds = 0;
        }
    }

    /**
     * Parses a comma-separated list of finite, non-negative numbers.
     *
     * <p>Returns null when the text is null or any entry is missing, unparsable, negative,
     * NaN or infinite. The original code looped forever on an unparsable entry, including
     * the default empty string.
     */
    private static double[] parseLambdaSequence(String spec) {
        if (spec == null) {
            return null;
        }
        String[] tokens = spec.split(",");
        double[] values = new double[tokens.length];
        for (int k = 0; k < tokens.length; k++) {
            double value;
            try {
                value = Double.parseDouble(tokens[k]);
            } catch (NumberFormatException e) {
                return null;
            }
            // !(value >= 0) also rejects NaN.
            if (!(value >= 0.0d) || Double.isInfinite(value)) {
                return null;
            }
            values[k] = value;
        }
        return values;
    }

    public String globalInfo() {
        return "Class for solving the 'elastic net' problem for linear regression using coordinate descent. This is a Java implementation of a component of the R package glmnet. Can perform attribute selection based on the tuning parameters alpha and lambda. Model can deal with weighted instances. For more information, refer to the report PDF included in the package distribution.";
    }

    /**
     * Parses the options.
     *
     * <p>Unparsable numeric values are ignored on purpose, keeping the current value.
     * {@link #validate()} runs afterwards.
     */
    public void setOptions(String[] strArr) throws Exception {
        String m2 = Utils.getOption("m2", strArr);
        String alpha = Utils.getOption("alpha", strArr);
        String lambdaSeq = Utils.getOption("lambda_seq", strArr);
        String thr = Utils.getOption("thr", strArr);
        String mxit = Utils.getOption("mxit", strArr);
        String numModels = Utils.getOption("numModels", strArr);
        String infolds = Utils.getOption("infolds", strArr);
        String eps = Utils.getOption("eps", strArr);
        String sparse = Utils.getOption("sparse", strArr);
        String stderrRule = Utils.getOption("stderr_rule", strArr);
        String addStats = Utils.getOption("addStats", strArr);

        if (!"".equals(lambdaSeq)) {
            m_lambda_seq = lambdaSeq;
        }
        m_alpha = parseDoubleOr(alpha, m_alpha);
        m_threshold = parseDoubleOr(thr, m_threshold);
        m_epsilon = parseDoubleOr(eps, m_epsilon);
        m_numModels = parseIntOr(numModels, m_numModels);
        m_numInnerFolds = parseIntOr(infolds, m_numInnerFolds);
        m_maxIt = parseIntOr(mxit, m_maxIt);
        if (!"".equals(m2)) {
            m_covarianceMode = !"n".equalsIgnoreCase(m2);
        }
        if (!"".equals(sparse)) {
            m_sparse = "y".equalsIgnoreCase(sparse);
        }
        if (!"".equals(stderrRule)) {
            m_stderr_rule = "y".equalsIgnoreCase(stderrRule);
        }
        if (!"".equals(addStats)) {
            m_additionalStats = "y".equalsIgnoreCase(addStats);
        }
        super.setOptions(strArr);
        validate();
    }

    /** Parses {@code text} as a double, returning {@code fallback} if it is empty or malformed. */
    private static double parseDoubleOr(String text, double fallback) {
        if ("".equals(text)) {
            return fallback;
        }
        try {
            return Double.parseDouble(text);
        } catch (NumberFormatException ignored) {
            // Best effort: keep the previous value, as the original option parsing did.
            return fallback;
        }
    }

    /** Parses {@code text} as an int, returning {@code fallback} if it is empty or malformed. */
    private static int parseIntOr(String text, int fallback) {
        if ("".equals(text)) {
            return fallback;
        }
        try {
            return Integer.parseInt(text);
        } catch (NumberFormatException ignored) {
            // Best effort: keep the previous value, as the original option parsing did.
            return fallback;
        }
    }

    /**
     * Returns the current options.
     *
     * <p>{@link #validate()} runs first so the returned values match the state the classifier
     * actually uses. Previously validation ran after the values were captured.
     */
    public String[] getOptions() {
        validate();
        Vector<String> options = new Vector<String>();
        options.add("-m2");
        options.add(m_covarianceMode ? "y" : "n");
        options.add("-alpha");
        options.add(String.valueOf(m_alpha));
        options.add("-lambda_seq");
        options.add(String.valueOf(m_lambda_seq));
        options.add("-thr");
        options.add(String.valueOf(m_threshold));
        options.add("-mxit");
        options.add(String.valueOf(m_maxIt));
        options.add("-numModels");
        options.add(String.valueOf(m_numModels));
        options.add("-infolds");
        options.add(String.valueOf(m_numInnerFolds));
        options.add("-eps");
        options.add(String.valueOf(m_epsilon));
        options.add("-sparse");
        options.add(m_sparse ? "y" : "n");
        options.add("-stderr_rule");
        options.add(m_stderr_rule ? "y" : "n");
        options.add("-addStats");
        options.add(m_additionalStats ? "y" : "n");
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[options.size()]);
    }

    /** Accepts numeric attributes and a numeric class only. */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        return capabilities;
    }

    public static void main(String[] strArr) {
        runClassifier(new ElasticNet(), strArr);
    }

    public String getCustom_lambda_sequence() {
        return this.m_lambda_seq;
    }

    public void setCustom_lambda_sequence(String str) {
        this.m_lambda_seq = str;
    }

    public String custom_lambda_sequenceTipText() {
        return "Provide custom lambda sequence of comma seperated non-negative floating point values OR leave blank to let classifier build own sequence";
    }

    public double getAlpha() {
        return this.m_alpha;
    }

    public void setAlpha(double d) {
        this.m_alpha = d;
    }

    public String alphaTipText() {
        return "Set the alpha value";
    }

    public boolean getUse_stderr_rule() {
        return this.m_stderr_rule;
    }

    public void setUse_stderr_rule(boolean z) {
        this.m_stderr_rule = z;
    }

    public String use_stderr_ruleTipText() {
        return "If true, the one standard error rule is applied for choosing lambda";
    }

    public boolean getUse_method2() {
        return this.m_covarianceMode;
    }

    public void setUse_method2(boolean z) {
        this.m_covarianceMode = z;
    }

    public String use_method2TipText() {
        return "If true, the covariance update method is used";
    }

    public int getNumModels() {
        return this.m_numModels;
    }

    public void setNumModels(int i) {
        this.m_numModels = i;
    }

    public String numModelsTipText() {
        return "Set the length of the lambda sequence to be generated";
    }

    public double getEpsilon() {
        return this.m_epsilon;
    }

    public void setEpsilon(double d) {
        this.m_epsilon = d;
    }

    public String epsilonTipText() {
        return "Set the epsilon value for generating the lambda sequence";
    }

    public int getNumInnerFolds() {
        return this.m_numInnerFolds;
    }

    public void setNumInnerFolds(int i) {
        this.m_numInnerFolds = i;
    }

    public String numInnerFoldsTipText() {
        return "Set the number of folds for internal cross validation";
    }

    public double getThreshold() {
        return this.m_threshold;
    }

    public void setThreshold(double d) {
        this.m_threshold = d;
    }

    public String thresholdTipText() {
        return "Set the convergence threshold";
    }

    public int getMaxIt() {
        return this.m_maxIt;
    }

    public void setMaxIt(int i) {
        this.m_maxIt = i;
    }

    public String maxItTipText() {
        return "Set the maximum number of iterations for a coordinate descent run";
    }

    public boolean getSparse() {
        return this.m_sparse;
    }

    public void setSparse(boolean z) {
        this.m_sparse = z;
    }

    public String sparseTipText() {
        return "If true, data is treated as sparse matrix";
    }

    public boolean getAdditionalStats() {
        return this.m_additionalStats;
    }

    public void setAdditionalStats(boolean z) {
        this.m_additionalStats = z;
    }

    public String additionalStatsTipText() {
        return "If true, additional statistics are printed";
    }
}
