/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.anomaly.libsvm;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.libsvm.LibSVMAnomalyModel;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.common.libsvm.SVMParameters;
import org.tribuo.provenance.ModelProvenance;

/**
 * A trainer for anomaly detection SVMs backed by LibSVM.
 * <p>
 * Training data must contain only {@link Event.EventType#EXPECTED} events:
 * {@link #train(Dataset, Map)} rejects any dataset whose output info records a positive
 * count of {@link Event.EventType#ANOMALOUS} events. The configured SVM type must be an
 * anomaly-detection type, which {@link #postConfig()} verifies.
 */
public class LibSVMAnomalyTrainer extends LibSVMTrainer<Event> {
    private static final Logger logger = Logger.getLogger(LibSVMAnomalyTrainer.class.getName());

    /**
     * No-argument constructor, presumably for use by the configuration system
     * (which then invokes {@link #postConfig()}). NOTE(review): confirm it is not meant for direct use.
     */
    protected LibSVMAnomalyTrainer() {
    }

    /**
     * Creates a trainer with the supplied parameters and the fixed seed {@code 12345L}.
     * @param parameters The SVM parameters.
     */
    public LibSVMAnomalyTrainer(SVMParameters<Event> parameters) {
        this(parameters, 12345L);
    }

    /**
     * Creates a trainer with the supplied parameters and RNG seed.
     * @param parameters The SVM parameters.
     * @param seed The RNG seed, passed through to {@link LibSVMTrainer}.
     */
    public LibSVMAnomalyTrainer(SVMParameters<Event> parameters, long seed) {
        super(parameters, seed);
    }

    /**
     * Runs the superclass post-configuration step, then checks that the configured SVM type
     * is an anomaly-detection type.
     * @throws IllegalArgumentException If the SVM type is not an anomaly-detection type.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (!svmType.isAnomaly()) {
            throw new IllegalArgumentException("Supplied classification or regression parameters to an anomaly detection SVM.");
        }
    }

    /**
     * Trains an anomaly detection SVM on the supplied dataset.
     * @param dataset The training data; it must not contain any ANOMALOUS events.
     * @param instanceProvenance Additional provenance, passed through to the superclass.
     * @return The trained model.
     * @throws IllegalArgumentException If the dataset's output info reports a positive count
     *                                  of ANOMALOUS events.
     */
    @Override
    public LibSVMModel<Event> train(Dataset<Event> dataset, Map<String, Provenance> instanceProvenance) {
        String anomalousName = Event.EventType.ANOMALOUS.toString();
        for (Pair<String, Long> count : dataset.getOutputInfo().outputCountsIterable()) {
            if (count.getA().equals(anomalousName) && count.getB() > 0L) {
                throw new IllegalArgumentException("LibSVMAnomalyTrainer only supports EXPECTED events at training time.");
            }
        }
        return super.train(dataset, instanceProvenance);
    }

    /**
     * Wraps the trained libsvm models in a {@link LibSVMAnomalyModel} named
     * {@code "svm-anomaly-detection-model"}.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param models The trained libsvm models.
     * @return The anomaly model.
     */
    @Override
    protected LibSVMModel<Event> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Event> outputIDInfo, List<svm_model> models) {
        return new LibSVMAnomalyModel("svm-anomaly-detection-model", provenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Trains a single libsvm model on the extracted data.
     * <p>
     * If {@code curParams.gamma} is 0 it is replaced with {@code 1.0 / numFeatures} before
     * training. This mutates the supplied parameter object.
     * @param curParams The libsvm parameters; {@code gamma} may be modified.
     * @param numFeatures The number of features, used to compute the fallback gamma.
     * @param features The feature vectors, one per example.
     * @param outputs The targets; only {@code outputs[0]} is used.
     * @param localRNG Supplies the seed for libsvm's static {@code svm.rand} generator.
     * @return A singleton list containing the trained model.
     * @throws IllegalArgumentException If libsvm's parameter check rejects the parameters.
     */
    @Override
    protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs, SplittableRandom localRNG) {
        svm_problem problem = new svm_problem();
        problem.l = outputs[0].length;
        problem.x = features;
        problem.y = outputs[0];
        // A zero gamma falls back to 1/numFeatures (libsvm's documented default for gamma).
        if (curParams.gamma == 0.0) {
            curParams.gamma = 1.0 / numFeatures;
        }
        String checkString = svm.svm_check_parameter(problem, curParams);
        if (checkString != null) {
            throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
        }
        // svm.rand is a static field on the libsvm class, so this seeding is JVM-wide.
        // NOTE(review): assumes the superclass prevents concurrent training — confirm.
        svm.rand.setSeed(localRNG.nextLong());
        return Collections.singletonList(svm.svm_train(problem, curParams));
    }

    /**
     * Converts the dataset into libsvm's representation.
     * @param data The dataset.
     * @param outputInfo The output domain (unused here).
     * @param featureMap The feature domain used to index features.
     * @return A pair of (feature vectors, targets), where the targets array has a single row
     *         holding {@link #extractOutput(Event)} for each example, in iteration order.
     */
    @Override
    protected Pair<svm_node[][], double[][]> extractData(Dataset<Event> data, ImmutableOutputInfo<Event> outputInfo, ImmutableFeatureMap featureMap) {
        double[][] ys = new double[1][data.size()];
        svm_node[][] xs = new svm_node[data.size()][];
        // Shared scratch list passed to every exampleToNodes call.
        List<svm_node> buffer = new ArrayList<>();
        int i = 0;
        for (Example<Event> example : data) {
            ys[0][i] = extractOutput(example.getOutput());
            xs[i] = exampleToNodes(example, featureMap, buffer);
            i++;
        }
        return new Pair<>(xs, ys);
    }

    /**
     * Converts an event into a libsvm target value.
     * @param output The event to convert.
     * @return {@code 1.0} for an EXPECTED event, {@code -1.0} for any other event type.
     */
    protected double extractOutput(Event output) {
        if (output.getType() == Event.EventType.EXPECTED) {
            return 1.0;
        }
        return -1.0;
    }
}

