package rseslib.processing.classification.parameterised.knn;

import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Properties;
import rseslib.processing.classification.ClassifierWithDistributedDecision;
import rseslib.processing.classification.parameterised.AbstractParameterisedClassifier;
import rseslib.processing.classification.parameterised.ParameterisedTestResult;
import rseslib.processing.indexing.metric.TreeIndexer;
import rseslib.processing.metrics.MetricFactory;
import rseslib.processing.searching.metric.ArrayVicinityProvider;
import rseslib.processing.searching.metric.IndexingTreeVicinityProvider;
import rseslib.processing.searching.metric.VicinityProvider;
import rseslib.processing.transformation.AttributeTransformer;
import rseslib.processing.transformation.TableTransformer;
import rseslib.structure.attribute.NominalAttribute;
import rseslib.structure.data.DoubleData;
import rseslib.structure.data.DoubleDataWithDecision;
import rseslib.structure.metric.AbstractWeightedMetric;
import rseslib.structure.metric.Metric;
import rseslib.structure.metric.Neighbour;
import rseslib.structure.table.ArrayListDoubleDataTable;
import rseslib.structure.table.DoubleDataTable;
import rseslib.system.PropertyConfigurationException;
import rseslib.system.progress.EmptyProgress;
import rseslib.system.progress.MultiProgress;
import rseslib.system.progress.Progress;

/* loaded from: input_file:rseslib/processing/classification/parameterised/knn/KnnClassifier.class */
public class KnnClassifier extends AbstractParameterisedClassifier implements ClassifierWithDistributedDecision, Serializable {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 1;
    /** Name of the property read when adjusting the weights of a weighted metric. */
    public static final String WEIGHTING_METHOD_PROPERTY_NAME = "weightingMethod";
    /** Name of the boolean property selecting the indexing-tree vicinity provider over the array-based one. */
    public static final String INDEXING_PROPERTY_NAME = "indexing";
    /** Name of the boolean property that switches on learning of the optimal k in the table-based constructor. */
    public static final String LEARN_OPTIMAL_K_PROPERTY_NAME = "learnOptimalK";
    /** Name of the integer property bounding the number of neighbours considered. */
    public static final String MAXIMAL_K_PROPERTY_NAME = "maxK";
    /** Name of the integer property holding the number of neighbours used for classification. */
    public static final String K_PROPERTY_NAME = "k";
    /** Name of the boolean property that enables neighbour filtering with {@link CubeBasedNeighboursFilter}. */
    public static final String FILTER_NEIGHBOURS_PROPERTY_NAME = "filterNeighboursUsingRules";
    /** Name of the property holding the name of a {@link Voting} constant. */
    public static final String VOTING_PROPERTY_NAME = "voting";
    /** Training objects as obtained from the training table; written out by writeObject. Set only by the table-based constructor and readObject. */
    private ArrayList<DoubleData> m_OriginalData;
    /** Transformation reported by the metric via transformationOutside(); null means objects are used as they are. */
    AttributeTransformer m_Transformer;
    /** Training table after applying m_Transformer, or the original table when there is no transformer. */
    DoubleDataTable m_TransformedTrainTable;
    /** Metric used to build the vicinity provider; written out by writeObject. */
    Metric m_Metric;
    /** Source of the neighbours of a classified object. */
    VicinityProvider m_VicinityProvider;
    /** Optional neighbour filter; stays null when the metric is not an AbstractWeightedMetric. */
    private CubeBasedNeighboursFilter m_NeighboursFilter;
    /** When true, classifyWithParameter(DoubleData) fetches one extra neighbour and skips the first one equal to the classified object. */
    private boolean m_bSelfLearning;
    /** Value of the maxK property read at construction time. */
    private int m_nMaxK;
    /** Nominal decision attribute used to convert between local and global decision codes. */
    private NominalAttribute m_DecisionAttribute;
    /** Local index of the decision with the highest count in the training decision distribution. */
    private int m_nDefaultDec;
    /** Lazily built cache used by the compiler-generated switch-table method for {@link Voting}. */
    private static /* synthetic */ int[] $SWITCH_TABLE$rseslib$processing$classification$parameterised$knn$KnnClassifier$Voting;

    /* loaded from: input_file:rseslib/processing/classification/parameterised/knn/KnnClassifier$Voting.class */
    /**
     * Neighbour voting schemes selectable through the "voting" property.
     * The classifier adds 1 per neighbour for {@code Equal}, 1/dist for
     * {@code InverseDistance} and 1/dist^2 for {@code InverseSquareDistance}.
     */
    public enum Voting {
        Equal,
        InverseDistance,
        InverseSquareDistance;

        /* renamed from: values, reason: to resolve conflict with enum method */
        /**
         * Returns a fresh array holding all constants in declaration order.
         *
         * @return a new array the caller may modify freely
         */
        public static Voting[] valuesCustom() {
            return values().clone();
        }
    }

    /**
     * Builds the classifier from a training table: creates the metric, applies the
     * metric's outside transformation (if any), adjusts metric weights for weighted
     * metrics, builds the vicinity provider and optionally learns the optimal k.
     *
     * @param properties      configuration of the classifier
     * @param doubleDataTable training table
     * @param progress        progress reporter for the whole learning process
     * @throws PropertyConfigurationException when a required property is missing or invalid
     * @throws InterruptedException           when learning is interrupted
     */
    public KnnClassifier(Properties properties, DoubleDataTable doubleDataTable, Progress progress) throws PropertyConfigurationException, InterruptedException {
        super(properties, "k");
        m_bSelfLearning = false;

        // Three progress stages when the optimal k is learnt, two otherwise.
        int[] stageWeights;
        if (getBoolProperty("learnOptimalK")) {
            stageWeights = new int[]{40, 10, 50};
        } else {
            stageWeights = new int[]{80, 20};
        }
        MultiProgress stagedProgress = new MultiProgress("Learning the k-nn classifier", progress, stageWeights);

        m_OriginalData = doubleDataTable.getDataObjects();
        m_Metric = MetricFactory.getMetric(getProperties(), doubleDataTable);
        m_Transformer = m_Metric.transformationOutside();
        m_TransformedTrainTable = (m_Transformer == null)
                ? doubleDataTable
                : TableTransformer.transform(doubleDataTable, m_Transformer);

        if (m_Metric instanceof AbstractWeightedMetric) {
            MetricFactory.adjustWeights(getProperty("weightingMethod"), (AbstractWeightedMetric) m_Metric, m_TransformedTrainTable, stagedProgress);
        }

        if (!getBoolProperty(INDEXING_PROPERTY_NAME)) {
            // The array-based provider needs no indexing; report it as a single step.
            stagedProgress.set("Constructing simple vicinity provider", 1);
            m_VicinityProvider = new ArrayVicinityProvider(m_Metric, m_TransformedTrainTable.getDataObjects());
            stagedProgress.step();
        } else {
            m_VicinityProvider = new IndexingTreeVicinityProvider(null, m_Metric, new TreeIndexer(null).indexing(m_TransformedTrainTable.getDataObjects(), m_Metric, stagedProgress));
        }

        if (m_Metric instanceof AbstractWeightedMetric) {
            m_NeighboursFilter = new CubeBasedNeighboursFilter((AbstractWeightedMetric) m_Metric);
        }
        m_nMaxK = getIntProperty(MAXIMAL_K_PROPERTY_NAME);
        m_DecisionAttribute = doubleDataTable.attributes().nominalDecisionAttribute();

        // Default decision: first index with the largest count in the decision distribution.
        m_nDefaultDec = 0;
        int dec = 1;
        while (dec < doubleDataTable.getDecisionDistribution().length) {
            if (doubleDataTable.getDecisionDistribution()[dec] > doubleDataTable.getDecisionDistribution()[m_nDefaultDec]) {
                m_nDefaultDec = dec;
            }
            dec++;
        }

        if (getBoolProperty("learnOptimalK")) {
            // Self-learning mode is on only for the duration of the optimisation.
            m_bSelfLearning = true;
            learnOptimalParameterValue(doubleDataTable, stagedProgress);
            m_bSelfLearning = false;
        }

        makePropertyModifiable("k");
        makePropertyModifiable(FILTER_NEIGHBOURS_PROPERTY_NAME);
        makePropertyModifiable("voting");
    }

    public KnnClassifier(Properties properties, Metric metric, DoubleDataTable doubleDataTable, Progress progress) throws PropertyConfigurationException, InterruptedException {
        super(properties, "k");
        this.m_bSelfLearning = false;
        this.m_VicinityProvider = new IndexingTreeVicinityProvider(null, metric, new TreeIndexer(null).indexing(doubleDataTable.getDataObjects(), metric, progress));
        this.m_nMaxK = getIntProperty(MAXIMAL_K_PROPERTY_NAME);
        if (metric instanceof AbstractWeightedMetric) {
            this.m_NeighboursFilter = new CubeBasedNeighboursFilter((AbstractWeightedMetric) metric);
        }
        this.m_DecisionAttribute = doubleDataTable.attributes().nominalDecisionAttribute();
        this.m_nDefaultDec = 0;
        for (int i = 1; i < doubleDataTable.getDecisionDistribution().length; i++) {
            if (doubleDataTable.getDecisionDistribution()[i] > doubleDataTable.getDecisionDistribution()[this.m_nDefaultDec]) {
                this.m_nDefaultDec = i;
            }
        }
        makePropertyModifiable("k");
        makePropertyModifiable(FILTER_NEIGHBOURS_PROPERTY_NAME);
        makePropertyModifiable("voting");
    }

    public KnnClassifier(Properties properties, NominalAttribute nominalAttribute, VicinityProvider vicinityProvider, CubeBasedNeighboursFilter cubeBasedNeighboursFilter, int[] iArr) throws PropertyConfigurationException {
        super(properties, "k");
        this.m_bSelfLearning = false;
        this.m_VicinityProvider = vicinityProvider;
        this.m_nMaxK = getIntProperty(MAXIMAL_K_PROPERTY_NAME);
        this.m_NeighboursFilter = cubeBasedNeighboursFilter;
        this.m_DecisionAttribute = nominalAttribute;
        this.m_nDefaultDec = 0;
        for (int i = 1; i < iArr.length; i++) {
            if (iArr[i] > iArr[this.m_nDefaultDec]) {
                this.m_nDefaultDec = i;
            }
        }
        makePropertyModifiable("k");
        makePropertyModifiable(FILTER_NEIGHBOURS_PROPERTY_NAME);
        makePropertyModifiable("voting");
    }

    /**
     * Custom serialization: writes the superclass state followed by the original
     * training objects, the transformer, the metric, the maximal k, the decision
     * attribute and the default decision index. The vicinity provider and the
     * neighbour filter are not written; readObject rebuilds them.
     *
     * @param out stream to write to
     * @throws IOException when writing fails
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        // The order below must mirror the read order in readObject.
        writeAbstractParameterisedClassifier(out);
        out.writeObject(m_OriginalData);
        out.writeObject(m_Transformer);
        out.writeObject(m_Metric);
        out.writeInt(m_nMaxK);
        out.writeObject(m_DecisionAttribute);
        out.writeInt(m_nDefaultDec);
    }

    /**
     * Custom deserialization: restores the state written by writeObject and rebuilds
     * the derived structures. The transformed training table is recomputed from the
     * original objects, the vicinity provider is always rebuilt as an indexing tree,
     * and the neighbour filter is recreated when the metric is an
     * {@link AbstractWeightedMetric}.
     *
     * @param in stream to read from
     * @throws IOException            when reading fails; a {@link NotSerializableException}
     *                                carrying the original exception as its cause is thrown
     *                                when rebuilding the index is interrupted or misconfigured
     * @throws ClassNotFoundException when a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        readAbstractParameterisedClassifier(in);

        @SuppressWarnings("unchecked") // writeObject always stores an ArrayList<DoubleData> here
        ArrayList<DoubleData> originalData = (ArrayList<DoubleData>) in.readObject();
        m_OriginalData = originalData;

        ArrayList<DoubleData> trainObjects = m_OriginalData;
        m_Transformer = (AttributeTransformer) in.readObject();
        if (m_Transformer != null) {
            // Re-apply the outside transformation instead of storing transformed copies.
            trainObjects = new ArrayList<>(m_OriginalData.size());
            for (DoubleData original : m_OriginalData) {
                trainObjects.add(m_Transformer.transformToNew(original));
            }
        }
        m_TransformedTrainTable = new ArrayListDoubleDataTable(trainObjects);
        m_Metric = (Metric) in.readObject();
        try {
            m_VicinityProvider = new IndexingTreeVicinityProvider(null, m_Metric, new TreeIndexer(null).indexing(m_TransformedTrainTable.getDataObjects(), m_Metric, new EmptyProgress()));
            m_bSelfLearning = false;
            m_nMaxK = in.readInt();
            if (m_Metric instanceof AbstractWeightedMetric) {
                m_NeighboursFilter = new CubeBasedNeighboursFilter((AbstractWeightedMetric) m_Metric);
            }
            m_DecisionAttribute = (NominalAttribute) in.readObject();
            m_nDefaultDec = in.readInt();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            NotSerializableException failure = new NotSerializableException(e.getMessage());
            failure.initCause(e);
            throw failure;
        } catch (PropertyConfigurationException e) {
            NotSerializableException failure = new NotSerializableException(e.getMessage());
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Switches self-learning mode on or off. While it is on,
     * classifyWithParameter(DoubleData) requests one extra neighbour and skips
     * the first neighbour equal to the classified object.
     *
     * @param z true to enable self-learning mode
     */
    public void setSelfLearning(boolean z) {
        m_bSelfLearning = z;
    }

    /**
     * Learns the value of k by 10-fold stratified cross-validation and stores the
     * best value in the "k" property.
     * <p>
     * For every fold a vicinity provider is built from the remaining folds (an
     * indexing tree when the "indexing" property is set, an array-based provider
     * otherwise). Each test object is classified for all parameter values at once,
     * and one confusion matrix per parameter value is accumulated. The parameter
     * value with the highest accuracy wins; ties keep the smallest value.
     *
     * @param doubleDataTable table to partition into folds
     * @param progress        progress reporter, stepped once per classified object
     * @throws PropertyConfigurationException when a required property is missing or invalid
     * @throws InterruptedException           when learning is interrupted
     */
    protected void learnOptimalParameterValueCV(DoubleDataTable doubleDataTable, Progress progress) throws PropertyConfigurationException, InterruptedException {
        ArrayList<DoubleData>[] folds = doubleDataTable.randomStratifiedPartition(10);
        // Indexed [parameter value][actual decision][predicted decision]; allocated lazily
        // because the parameter range is only known after the first classification.
        int[][][] confusionMatrices = null;
        progress.set("Learning optimal parameter value using cross-validation", doubleDataTable.noOfObjects());

        for (int testFold = 0; testFold < folds.length; testFold++) {
            ArrayList<DoubleData> trainObjects = new ArrayList<>();
            ArrayList<DoubleData> testObjects = new ArrayList<>();
            for (int fold = 0; fold < folds.length; fold++) {
                if (fold == testFold) {
                    testObjects.addAll(folds[fold]);
                } else {
                    trainObjects.addAll(folds[fold]);
                }
            }

            VicinityProvider foldVicinityProvider;
            if (getBoolProperty(INDEXING_PROPERTY_NAME)) {
                foldVicinityProvider = new IndexingTreeVicinityProvider(null, m_Metric, new TreeIndexer(null).indexing(trainObjects, m_Metric, progress));
            } else {
                foldVicinityProvider = new ArrayVicinityProvider(m_Metric, trainObjects);
            }

            for (DoubleData testObject : testObjects) {
                double[] decisions = classifyWithParameter(testObject, foldVicinityProvider.getVicinity(testObject, m_nMaxK));
                if (confusionMatrices == null) {
                    int noOfDecisions = m_DecisionAttribute.noOfValues();
                    confusionMatrices = new int[decisions.length][noOfDecisions][noOfDecisions];
                }
                int actual = m_DecisionAttribute.localValueCode(((DoubleDataWithDecision) testObject).getDecision());
                // Parameter value 0 is deliberately skipped: its matrix stays empty.
                for (int parVal = 1; parVal < confusionMatrices.length; parVal++) {
                    int predicted = m_DecisionAttribute.localValueCode(decisions[parVal]);
                    confusionMatrices[parVal][actual][predicted]++;
                }
                progress.step();
            }
        }

        ParameterisedTestResult results = new ParameterisedTestResult(getParameterName(), m_DecisionAttribute, doubleDataTable.getDecisionDistribution(), confusionMatrices, new Properties());
        int bestParVal = 0;
        for (int parVal = 1; parVal < results.getParameterRange(); parVal++) {
            if (results.getClassificationResult(parVal).getAccuracy() > results.getClassificationResult(bestParVal).getAccuracy()) {
                bestParVal = parVal;
            }
        }
        makePropertyModifiable("k");
        setProperty("k", Integer.toString(bestParVal));
    }

    /**
     * Computes the vote collected by every decision value from the k nearest
     * neighbours of the given object. The object is first transformed when a
     * transformer is set. When neighbour filtering is enabled, neighbours not
     * marked as consistent do not vote.
     *
     * @param doubleData object to classify
     * @return votes indexed by local decision code
     * @throws PropertyConfigurationException when a property is missing or the voting method is unknown
     */
    @Override // rseslib.processing.classification.ClassifierWithDistributedDecision
    public double[] classifyWithDistributedDecision(DoubleData doubleData) throws PropertyConfigurationException {
        if (m_Transformer != null) {
            doubleData = m_Transformer.transformToNew(doubleData);
        }
        Neighbour[] neighbours = m_VicinityProvider.getVicinity(doubleData, getIntProperty("k"));
        boolean filtering = getBoolProperty(FILTER_NEIGHBOURS_PROPERTY_NAME);
        if (filtering && m_NeighboursFilter != null) {
            m_NeighboursFilter.markConsistency(doubleData, neighbours);
        }
        double[] votes = new double[m_DecisionAttribute.noOfValues()];
        try {
            Voting voting = Voting.valueOf(getProperty("voting"));
            // Iteration starts at index 1; slot 0 of the vicinity array is not read here.
            for (int pos = 1; pos < neighbours.length; pos++) {
                Neighbour current = neighbours[pos];
                int dec = m_DecisionAttribute.localValueCode(current.neighbour().getDecision());
                if (filtering && !current.m_bConsistent) {
                    continue;
                }
                if (voting == Voting.Equal) {
                    votes[dec] += 1.0d;
                } else if (voting == Voting.InverseDistance) {
                    votes[dec] += 1.0d / current.dist();
                } else if (voting == Voting.InverseSquareDistance) {
                    votes[dec] += 1.0d / (current.dist() * current.dist());
                }
            }
            return votes;
        } catch (IllegalArgumentException e) {
            throw new PropertyConfigurationException("Unknown voting method: " + getProperty("voting"));
        }
    }

    /**
     * Classifies the object with the decision that collected the largest vote;
     * ties are resolved in favour of the smallest local decision code.
     *
     * @param doubleData object to classify
     * @return global code of the winning decision
     * @throws PropertyConfigurationException when a property is missing or invalid
     */
    @Override // rseslib.processing.classification.Classifier
    public double classify(DoubleData doubleData) throws PropertyConfigurationException {
        double[] votes = classifyWithDistributedDecision(doubleData);
        int winner = 0;
        int candidate = 1;
        while (candidate < votes.length) {
            if (votes[candidate] > votes[winner]) {
                winner = candidate;
            }
            candidate++;
        }
        return m_DecisionAttribute.globalValueCode(winner);
    }

    /**
     * Classifies the object for every parameter value from 0 to the maximal k.
     * The object is transformed first when a transformer is set. In self-learning
     * mode one extra neighbour is fetched and the first neighbour equal to the
     * classified object is skipped.
     *
     * @param doubleData object to classify
     * @return decisions (global codes) indexed by parameter value
     * @throws PropertyConfigurationException when a property is missing or invalid
     */
    @Override // rseslib.processing.classification.parameterised.ParameterisedClassifier
    public double[] classifyWithParameter(DoubleData doubleData) throws PropertyConfigurationException {
        if (m_Transformer != null) {
            doubleData = m_Transformer.transformToNew(doubleData);
        }
        Neighbour[] neighbours;
        if (!m_bSelfLearning) {
            neighbours = m_VicinityProvider.getVicinity(doubleData, m_nMaxK);
        } else {
            Neighbour[] withSelf = m_VicinityProvider.getVicinity(doubleData, m_nMaxK + 1);
            neighbours = new Neighbour[withSelf.length - 1];
            int pos = 1;
            // Copy neighbours as they are until the classified object itself is met...
            for (; pos < neighbours.length && !doubleData.equals(withSelf[pos].neighbour()); pos++) {
                neighbours[pos] = withSelf[pos];
            }
            // ...then shift the remaining ones down by one to drop it.
            for (; pos < neighbours.length; pos++) {
                neighbours[pos] = withSelf[pos + 1];
            }
        }
        return classifyWithParameter(doubleData, neighbours);
    }

    /**
     * Classifies the object for every parameter value from 0 to the maximal k
     * using the supplied neighbours. Entry 0 of the result is the default decision.
     * Neighbours at the same distance are treated as one group: the winning
     * decision is re-evaluated only at the end of a group and assigned to all
     * parameter values covered by that group. Parameter values beyond the
     * supplied neighbours receive the last winning decision.
     *
     * @param doubleData    object being classified (passed to the neighbour filter)
     * @param neighbourArr  neighbours of the object, read from index 1
     * @return decisions (global codes) indexed by parameter value
     * @throws PropertyConfigurationException when a property is missing or the voting method is unknown
     */
    public double[] classifyWithParameter(DoubleData doubleData, Neighbour[] neighbourArr) throws PropertyConfigurationException {
        boolean filtering = getBoolProperty(FILTER_NEIGHBOURS_PROPERTY_NAME);
        if (filtering && m_NeighboursFilter != null) {
            m_NeighboursFilter.markConsistency(doubleData, neighbourArr);
        }
        double[] decisionForK = new double[m_nMaxK + 1];
        double[] votes = new double[m_DecisionAttribute.noOfValues()];
        int best = m_nDefaultDec;
        decisionForK[0] = m_DecisionAttribute.globalValueCode(best);
        try {
            Voting voting = Voting.valueOf(getProperty("voting"));
            int groupStart = 1;
            for (int pos = 1; pos < neighbourArr.length; pos++) {
                Neighbour current = neighbourArr[pos];
                int dec = m_DecisionAttribute.localValueCode(current.neighbour().getDecision());
                if (!filtering || current.m_bConsistent) {
                    if (voting == Voting.Equal) {
                        votes[dec] += 1.0d;
                    } else if (voting == Voting.InverseDistance) {
                        votes[dec] += 1.0d / current.dist();
                    } else if (voting == Voting.InverseSquareDistance) {
                        votes[dec] += 1.0d / (current.dist() * current.dist());
                    }
                }

                // The group goes on while the next neighbour lies at exactly the same distance.
                boolean groupContinues = pos != neighbourArr.length - 1
                        && current.dist() == neighbourArr[pos + 1].dist();
                if (groupContinues) {
                    continue;
                }

                if (groupStart < pos) {
                    // Several neighbours voted at once: any decision may have taken the lead.
                    for (int d = 0; d < votes.length; d++) {
                        if (votes[d] > votes[best]) {
                            best = d;
                        }
                    }
                } else if (votes[dec] > votes[best]) {
                    // Single neighbour: only its decision can have overtaken the leader.
                    best = dec;
                }
                for (int k = groupStart; k <= pos && k < decisionForK.length; k++) {
                    decisionForK[k] = m_DecisionAttribute.globalValueCode(best);
                }
                groupStart = pos + 1;
            }
            for (int k = groupStart; k < decisionForK.length; k++) {
                decisionForK[k] = m_DecisionAttribute.globalValueCode(best);
            }
            return decisionForK;
        } catch (IllegalArgumentException e) {
            throw new PropertyConfigurationException("Unknown voting method: " + getProperty("voting"));
        }
    }

    /**
     * Adds the learnt value of k to the statistics when learning of the optimal k
     * is switched on. Property errors are ignored, in which case nothing is added.
     */
    @Override // rseslib.processing.classification.Classifier
    public void calculateStatistics() {
        try {
            if (!getBoolProperty("learnOptimalK")) {
                return;
            }
            addToStatistics("Optimal k", getProperty("k"));
        } catch (PropertyConfigurationException ignored) {
            // Best effort: the statistic is simply left out when the properties cannot be read.
        }
    }

    /**
     * Does nothing: this method body is intentionally empty.
     */
    @Override // rseslib.processing.classification.Classifier
    public void resetStatistics() {
        // Intentionally empty.
    }

    /**
     * Compiler-generated helper: returns a lazily built table mapping the ordinal of
     * each {@link Voting} constant to a case label (Equal = 1, InverseDistance = 2,
     * InverseSquareDistance = 3). The table is cached in a static field after the
     * first call.
     *
     * @return the cached ordinal-to-case-label table
     */
    static /* synthetic */ int[] $SWITCH_TABLE$rseslib$processing$classification$parameterised$knn$KnnClassifier$Voting() {
        int[] table = $SWITCH_TABLE$rseslib$processing$classification$parameterised$knn$KnnClassifier$Voting;
        if (table == null) {
            table = new int[Voting.valuesCustom().length];
            // Each entry is guarded separately so that a constant missing at run time
            // leaves its slot at 0 instead of failing the whole table.
            try {
                table[Voting.Equal.ordinal()] = 1;
            } catch (NoSuchFieldError unused) {
            }
            try {
                table[Voting.InverseDistance.ordinal()] = 2;
            } catch (NoSuchFieldError unused2) {
            }
            try {
                table[Voting.InverseSquareDistance.ordinal()] = 3;
            } catch (NoSuchFieldError unused3) {
            }
            $SWITCH_TABLE$rseslib$processing$classification$parameterised$knn$KnnClassifier$Voting = table;
        }
        return table;
    }
}
