package weka.classifiers.functions;

import java.util.Arrays;
import java.util.Random;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:weka/classifiers/functions/RBFClassifier.class */
public class RBFClassifier extends RBFModel implements WeightedInstancesHandler {
    private static final long serialVersionUID = -7847475556438394611L;

    /**
     * Returns the capabilities of this classifier: whatever the superclass
     * reports, with nominal class attributes and missing class values
     * additionally enabled.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        final Capabilities result = super.getCapabilities();

        // Class-attribute capabilities added on top of the inherited set.
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);

        return result;
    }

    /**
     * Initialises all output-layer parameters with small random values.
     * <p>
     * Each class owns a contiguous run of {@code m_numUnits + 1} entries of
     * {@code m_RBFParameters} starting at {@code OFFSET_WEIGHTS}: one entry per
     * unit followed by one extra entry (used as the bias term by
     * {@code getOutput}). Every such entry is set to
     * {@code 0.1 * random.nextGaussian()}.
     *
     * @param random generator supplying the Gaussian draws
     */
    @Override
    protected void initializeOutputLayer(Random random) {
        final int stride = m_numUnits + 1;
        // Slot-major iteration (all classes for slot 0, then slot 1, ...): the
        // order fixes which draw from the generator ends up in which parameter.
        for (int slot = 0; slot < stride; slot++) {
            int index = OFFSET_WEIGHTS + slot;
            for (int cls = 0; cls < m_numClasses; cls++) {
                m_RBFParameters[index] = 0.1d * random.nextGaussian();
                index += stride;
            }
        }
    }

    /**
     * Computes the weighted squared error of the network for one instance.
     * <p>
     * For every class the network output is compared with a target of
     * {@code 0.99} when the class equals the instance's class value and
     * {@code 0.01} otherwise; the squared differences are multiplied by the
     * instance weight and summed.
     *
     * @param dArr values passed through to {@code getOutput}, where they are
     *             combined with the output-layer weights
     * @param instance the instance providing the class value and weight
     * @return the instance-weighted sum of squared differences over all classes
     */
    @Override
    protected double calculateError(double[] dArr, Instance instance) {
        double weightedSquaredError = 0.0d;
        for (int cls = 0; cls < m_numClasses; cls++) {
            final double predicted = getOutput(cls, dArr, null);

            final double target;
            if ((int) instance.value(m_classIndex) == cls) {
                target = 0.99d;
            } else {
                target = 0.01d;
            }

            final double residual = predicted - target;
            weightedSquaredError += instance.weight() * residual * residual;
        }
        return weightedSquaredError;
    }

    /**
     * Adds the ridge penalty to the accumulated error and divides the result
     * by the total weight of the training data.
     * <p>
     * The penalty is {@code m_ridge} times the sum of the squared output
     * weights; the extra (bias) entry of each class is not included.
     *
     * @param d the accumulated error
     * @return {@code (d + m_ridge * sumOfSquaredWeights) / m_data.sumOfWeights()}
     */
    @Override
    protected double postprocessError(double d) {
        final int stride = m_numUnits + 1;

        double sumOfSquaredWeights = 0.0d;
        for (int unit = 0; unit < m_numUnits; unit++) {
            for (int cls = 0; cls < m_numClasses; cls++) {
                final double w = m_RBFParameters[OFFSET_WEIGHTS + cls * stride + unit];
                sumOfSquaredWeights += w * w;
            }
        }

        final double penalised = d + (m_ridge * sumOfSquaredWeights);
        return penalised / m_data.sumOfWeights();
    }

    /**
     * Adds the derivative of the ridge penalty to the gradient and then scales
     * the whole gradient by the reciprocal of the total training-data weight.
     * <p>
     * Only the entries belonging to output weights receive the penalty term
     * {@code 2 * m_ridge * weight}; the extra (bias) entry of each class is
     * left out. The final scaling is applied to every element of the array.
     *
     * @param dArr the gradient, modified in place
     */
    @Override
    protected void postprocessGradient(double[] dArr) {
        final int stride = m_numUnits + 1;
        for (int unit = 0; unit < m_numUnits; unit++) {
            for (int cls = 0; cls < m_numClasses; cls++) {
                final int index = OFFSET_WEIGHTS + cls * stride + unit;
                dArr[index] += m_ridge * 2.0d * m_RBFParameters[index];
            }
        }

        // Multiply by the reciprocal rather than dividing each element.
        final double scale = 1.0d / m_data.sumOfWeights();
        for (int k = 0; k < dArr.length; k++) {
            dArr[k] *= scale;
        }
    }

    /**
     * Accumulates the output-layer contribution of one instance into the
     * gradient and computes the values to be propagated back to the units.
     * <p>
     * For each class the error signal is
     * {@code weight * (output - target) * sigmoidDerivative}, with a target of
     * {@code 0.99} for the instance's class and {@code 0.01} otherwise. Classes
     * whose signal lies within {@code [-m_tolerance, m_tolerance]} are skipped.
     *
     * @param dArr gradient array; the entries of the output weights and bias
     *             of each processed class are incremented
     * @param instance the instance providing the class value and weight
     * @param dArr2 values combined with the output weights by {@code getOutput}
     * @param dArr3 scratch array; element 0 receives the sigmoid derivative of
     *              the class currently being processed
     * @param dArr4 reset to zero, then accumulates, per unit, the error signal
     *              multiplied by the corresponding output weight
     */
    @Override
    protected void updateGradient(double[] dArr, Instance instance, double[] dArr2, double[] dArr3, double[] dArr4) {
        Arrays.fill(dArr4, 0.0d);

        final int stride = m_numUnits + 1;
        for (int cls = 0; cls < m_numClasses; cls++) {
            final double instanceWeight = instance.weight();
            // getOutput also stores the derivative in dArr3[0], which is read below.
            final double predicted = getOutput(cls, dArr2, dArr3);
            final double target = ((int) instance.value(m_classIndex) == cls) ? 0.99d : 0.01d;
            final double delta = instanceWeight * (predicted - target) * dArr3[0];

            // Written as a negated "outside the band" test so that a NaN signal
            // is skipped as well (both comparisons are false for NaN).
            if (!(delta > m_tolerance || delta < -m_tolerance)) {
                continue;
            }

            final int base = OFFSET_WEIGHTS + cls * stride;

            // Signal to propagate back to each unit.
            for (int unit = 0; unit < m_numUnits; unit++) {
                dArr4[unit] += delta * m_RBFParameters[base + unit];
            }
            // Gradient with respect to the output weights of this class.
            for (int unit = 0; unit < m_numUnits; unit++) {
                dArr[base + unit] += delta * dArr2[unit];
            }
            // Gradient with respect to the bias of this class.
            dArr[base + m_numUnits] += delta;
        }
    }

    /**
     * Computes the output of one output unit.
     * <p>
     * The first {@code m_numUnits} values of {@code dArr} are combined with the
     * unit's weights, the unit's bias (the entry following its weights) is
     * added, and the negated sum is passed to {@link #sigmoid}.
     *
     * @param i index of the output unit (class)
     * @param dArr values to be weighted, one per unit
     * @param dArr2 if not {@code null}, element 0 receives the derivative
     *              computed by {@code sigmoid}
     * @return the value returned by {@code sigmoid} for the negated weighted sum
     */
    protected double getOutput(int i, double[] dArr, double[] dArr2) {
        final int base = OFFSET_WEIGHTS + i * (m_numUnits + 1);

        double weightedSum = 0.0d;
        for (int unit = 0; unit < m_numUnits; unit++) {
            weightedSum += m_RBFParameters[base + unit] * dArr[unit];
        }

        final double bias = m_RBFParameters[base + m_numUnits];
        return sigmoid(-(weightedSum + bias), dArr2, 0);
    }

    /**
     * Approximates {@code exp(d)} by {@code (1 + d/4096)^4096}, evaluated with
     * twelve successive squarings ({@code 4096 = 2^12}).
     * <p>
     * When {@code 1 + d/4096} is not positive (that is, {@code d <= -4096}) the
     * formula is no longer meaningful: the even exponent turns a negative base
     * into a positive, possibly huge, result and the derivative becomes
     * {@code 0/0} at {@code d == -4096}. In that range the method returns the
     * limiting value {@code 0} with a derivative of {@code 0} instead.
     *
     * @param d the exponent
     * @param dArr if not {@code null}, receives the derivative of the
     *             approximation with respect to {@code d}
     * @param i index in {@code dArr} at which the derivative is stored
     * @return the approximate value of {@code exp(d)}
     */
    protected double approxExp(double d, double[] dArr, int i) {
        final double base = 1.0d + (d / 4096.0d);

        // Outside the valid range of the approximation: exp(d) is effectively 0.
        if (base <= 0.0d) {
            if (dArr != null) {
                dArr[i] = 0.0d;
            }
            return 0.0d;
        }

        // base^(2^12) by repeated squaring.
        double power = base;
        for (int squaring = 0; squaring < 12; squaring++) {
            power *= power;
        }

        if (dArr != null) {
            // d/dx (1 + x/n)^n = (1 + x/n)^(n - 1)
            dArr[i] = power / base;
        }
        return power;
    }

    /**
     * Computes {@code 1 / (1 + e)}, where {@code e} approximates {@code exp(d)}
     * as {@code (1 + d/4096)^4096} using twelve successive squarings
     * ({@code 4096 = 2^12}).
     * <p>
     * When {@code 1 + d/4096} is not positive (that is, {@code d <= -4096}) the
     * approximation of {@code exp(d)} breaks down: the even exponent makes a
     * negative base positive, so the result would fall back towards {@code 0}
     * for very negative {@code d}, and the derivative would be {@code 0/0} at
     * {@code d == -4096}. In that range the method returns the saturated value
     * {@code 1} with a derivative of {@code 0} instead.
     *
     * @param d the argument
     * @param dArr if not {@code null}, receives
     *             {@code result * (1 - result) / (1 + d/4096)}
     * @param i index in {@code dArr} at which that value is stored
     * @return the approximate value of {@code 1 / (1 + exp(d))}
     */
    protected double sigmoid(double d, double[] dArr, int i) {
        final double base = 1.0d + (d / 4096.0d);

        // Outside the valid range of the approximation: exp(d) is effectively 0,
        // so the function is saturated at 1 and flat.
        if (base <= 0.0d) {
            if (dArr != null) {
                dArr[i] = 0.0d;
            }
            return 1.0d;
        }

        // base^(2^12) by repeated squaring.
        double power = base;
        for (int squaring = 0; squaring < 12; squaring++) {
            power *= power;
        }

        final double result = 1.0d / (1.0d + power);
        if (dArr != null) {
            dArr[i] = (result * (1.0d - result)) / base;
        }
        return result;
    }

    /**
     * Builds the class distribution from the outputs of all output units.
     * <p>
     * Each output is limited to the range {@code [0, 1]} and the resulting
     * array is then handed to {@code Utils.normalize} (presumably rescaling it
     * to sum to one - see that method for the exact contract).
     *
     * @param dArr values passed through to {@code getOutput}
     * @return a new array with one entry per class
     */
    @Override
    protected double[] getDistribution(double[] dArr) {
        final double[] distribution = new double[m_numClasses];
        for (int cls = 0; cls < m_numClasses; cls++) {
            final double raw = getOutput(cls, dArr, null);
            if (raw < 0.0d) {
                distribution[cls] = 0.0d;
            } else if (raw > 1.0d) {
                distribution[cls] = 1.0d;
            } else {
                distribution[cls] = raw;
            }
        }
        Utils.normalize(distribution);
        return distribution;
    }

    /**
     * Returns a textual description of the model parameters.
     * <p>
     * For every unit the output weights (one per class) and the unit centre
     * are listed, followed by the unit's scale(s) when
     * {@code m_scaleOptimizationOption} is {@code 2} (a single scale per unit)
     * or {@code 3} (one scale per unit and attribute). A single scale is
     * printed after the units when the option is {@code 1}. Attribute weights
     * are included if {@code m_useAttributeWeights} is set, and the bias
     * weights of all classes come last. Values are separated by tabs and the
     * class attribute is skipped in all per-attribute listings.
     *
     * @return the description, or {@code "Classifier not built yet."} if no
     *         parameters are available
     */
    @Override
    public String toString() {
        if (m_RBFParameters == null) {
            return "Classifier not built yet.";
        }

        final int stride = m_numUnits + 1;
        // A single builder avoids re-copying the growing text on every append.
        final StringBuilder text = new StringBuilder();

        for (int unit = 0; unit < m_numUnits; unit++) {
            if (unit > 0) {
                text.append("\n\n");
            }

            text.append("Output weights for different classes:\n");
            for (int cls = 0; cls < m_numClasses; cls++) {
                text.append(m_RBFParameters[OFFSET_WEIGHTS + cls * stride + unit]).append("\t");
            }

            text.append("\n\nUnit center:\n");
            for (int att = 0; att < m_numAttributes; att++) {
                if (att != m_classIndex) {
                    text.append(m_RBFParameters[OFFSET_CENTERS + unit * m_numAttributes + att]).append("\t");
                }
            }

            if (m_scaleOptimizationOption == 3) {
                text.append("\n\nUnit scales:\n");
                for (int att = 0; att < m_numAttributes; att++) {
                    if (att != m_classIndex) {
                        text.append(m_RBFParameters[OFFSET_SCALES + unit * m_numAttributes + att]).append("\t");
                    }
                }
            } else if (m_scaleOptimizationOption == 2) {
                text.append("\n\nUnit scale:\n");
                text.append(m_RBFParameters[OFFSET_SCALES + unit]).append("\t");
            }
        }

        if (m_scaleOptimizationOption == 1) {
            text.append("\n\nScale:\n");
            text.append(m_RBFParameters[OFFSET_SCALES]).append("\t");
        }

        if (m_useAttributeWeights) {
            text.append("\n\nAttribute weights:\n");
            for (int att = 0; att < m_numAttributes; att++) {
                if (att != m_classIndex) {
                    text.append(m_RBFParameters[OFFSET_ATTRIBUTE_WEIGHTS + att]).append("\t");
                }
            }
        }

        text.append("\n\nBias weights for different classes:\n");
        for (int cls = 0; cls < m_numClasses; cls++) {
            text.append(m_RBFParameters[OFFSET_WEIGHTS + cls * stride + m_numUnits]).append("\t");
        }

        return text.toString();
    }

    /**
     * Command-line entry point: passes a new {@code RBFClassifier} together
     * with the given arguments to {@code runClassifier}.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        final RBFClassifier classifier = new RBFClassifier();
        runClassifier(classifier, strArr);
    }
}
