package weka.classifiers.functions;

import java.util.Arrays;
import java.util.Random;
import weka.classifiers.functions.activation.ActivationFunction;
import weka.classifiers.functions.activation.ApproximateSigmoid;
import weka.classifiers.functions.activation.Sigmoid;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Standardize;

/* loaded from: input_file:weka/classifiers/functions/MLPClassifier.class */
/**
 * Multilayer perceptron for nominal class attributes, built on {@link MLPModel}.
 *
 * <p>Each class value has its own output unit. The raw value of every output unit is passed
 * through a dedicated output activation function ({@link Sigmoid} or {@link ApproximateSigmoid})
 * during training as well as during prediction. Training targets are 0.99 for the unit that
 * matches the instance's class value and 0.01 for all other units.
 */
public class MLPClassifier extends MLPModel implements WeightedInstancesHandler {
    private static final long serialVersionUID = -3297474276438394644L;

    /** Target used for the output unit that corresponds to the instance's class value. */
    private static final double TARGET_FOR_ACTUAL_CLASS = 0.99d;

    /** Target used for every output unit that does not correspond to the instance's class value. */
    private static final double TARGET_FOR_OTHER_CLASS = 0.01d;

    /** Activation function applied to the output units; set in {@link #initializeClassifier}. */
    protected ActivationFunction m_OutputActivationFunction = null;

    /**
     * Returns the capabilities of the superclass with nominal class attributes and missing class
     * values additionally enabled.
     *
     * @return the capabilities of this classifier
     */
    @Override // weka.classifiers.functions.MLPModel
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return capabilities;
    }

    /**
     * Runs the superclass initialization, selects the output activation function and, when the
     * superclass returned data, standardizes that data.
     *
     * <p>The output activation is an {@link ApproximateSigmoid} if the configured activation
     * function is one, and an exact {@link Sigmoid} otherwise. The {@link Standardize} filter is
     * stored in {@code m_Filter}.
     *
     * @param instances the training data
     * @param random the random number generator handed on to the superclass
     * @return the standardized data, or {@code null} if the superclass initialization returned
     *     {@code null}
     * @throws Exception if the superclass initialization or the filtering fails
     */
    @Override // weka.classifiers.functions.MLPModel
    public Instances initializeClassifier(Instances instances, Random random) throws Exception {
        Instances data = super.initializeClassifier(instances, random);

        if (this.m_ActivationFunction instanceof ApproximateSigmoid) {
            this.m_OutputActivationFunction = new ApproximateSigmoid();
        } else {
            this.m_OutputActivationFunction = new Sigmoid();
        }

        if (data == null) {
            return null;
        }
        this.m_Filter = new Standardize();
        this.m_Filter.setInputFormat(data);
        return Filter.useFilter(data, this.m_Filter);
    }

    /**
     * Computes the weighted loss of one instance, summed over all output units.
     *
     * @param dArr the network outputs, read through {@code getOutput(i, dArr)}
     * @param instance the instance supplying the class value and the weight
     * @return the instance weight multiplied by the sum of the per-unit losses
     */
    @Override // weka.classifiers.functions.MLPModel
    protected double calculateErrorForOneInstance(double[] dArr, Instance instance) {
        final int actualClass = (int) instance.value(this.m_classIndex);
        double lossSum = 0.0d;
        for (int i = 0; i < this.m_numClasses; i++) {
            double target = (actualClass == i) ? TARGET_FOR_ACTUAL_CLASS : TARGET_FOR_OTHER_CLASS;
            double prediction = this.m_OutputActivationFunction.activation(getOutput(i, dArr), null, 0);
            lossSum += this.m_Loss.loss(prediction, target);
        }
        return instance.weight() * lossSum;
    }

    /**
     * Computes one delta per output unit for the given instance.
     *
     * <p>Each delta is the instance weight, multiplied by the derivative of the loss at the
     * activated output, multiplied by the value the activation function stored in its output
     * array (presumably the derivative of the activation function; confirm against
     * {@code ActivationFunction}).
     *
     * @param instance the instance supplying the class value, the weight and the number of classes
     * @param dArr the network outputs, read through {@code getOutput(i, dArr)}
     * @return a new array holding one delta per class
     */
    @Override // weka.classifiers.functions.MLPModel
    protected double[] computeDeltas(Instance instance, double[] dArr) {
        // One-element buffer that the activation function fills on every call.
        double[] activationDerivative = new double[1];
        double[] deltas = new double[instance.numClasses()];
        Arrays.fill(deltas, instance.weight());

        final int actualClass = (int) instance.value(this.m_classIndex);
        for (int i = 0; i < deltas.length; i++) {
            double target = (actualClass == i) ? TARGET_FOR_ACTUAL_CLASS : TARGET_FOR_OTHER_CLASS;
            double prediction =
                    this.m_OutputActivationFunction.activation(getOutput(i, dArr), activationDerivative, 0);
            deltas[i] = deltas[i] * this.m_Loss.derivative(prediction, target) * activationDerivative[0];
        }
        return deltas;
    }

    /**
     * Turns the raw output values into a normalized class distribution, in place.
     *
     * <p>Every value is passed through the output activation function, which is the same function
     * that the loss and the deltas are computed with during training, and is then clamped to the
     * interval [0, 1]. The clamped values are normalized by their sum.
     *
     * @param dArr the raw output values; modified in place
     * @return {@code dArr} after normalization, or {@code null} if the clamped values do not have
     *     a positive sum
     */
    @Override // weka.classifiers.functions.MLPModel
    protected double[] postProcessDistribution(double[] dArr) {
        for (int i = 0; i < this.m_numClasses; i++) {
            // Must be the output activation, not the hidden-layer one: training uses it too.
            double activated = this.m_OutputActivationFunction.activation(dArr[i], null, 0);
            if (activated < 0.0d) {
                activated = 0.0d;
            } else if (activated > 1.0d) {
                activated = 1.0d;
            }
            dArr[i] = activated;
        }

        double sum = 0.0d;
        for (double value : dArr) {
            sum += value;
        }
        if (sum <= 0.0d) {
            return null;
        }
        Utils.normalize(dArr, sum);
        return dArr;
    }

    /**
     * Returns the name of this model type.
     *
     * @return the string {@code "MLPClassifier"}
     */
    @Override // weka.classifiers.functions.MLPModel
    public String modelType() {
        return "MLPClassifier";
    }

    /**
     * Passes a new {@code MLPClassifier} and the given arguments to {@code runClassifier}.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        runClassifier(new MLPClassifier(), strArr);
    }
}
