/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sequence.viterbi;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.viterbi.LabelFeatureExtractor;
import org.tribuo.classification.sequence.viterbi.ViterbiModel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.sequence.ImmutableSequenceDataset;
import org.tribuo.sequence.MutableSequenceDataset;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;

public final class ViterbiTrainer
implements SequenceTrainer<Label> {
    @Config(mandatory=true, description="Inner trainer for each sequence element.")
    private Trainer<Label> trainer;
    @Config(mandatory=true, description="Feature extractor to pull in surrounding label features.")
    private LabelFeatureExtractor labelFeatureExtractor;
    @Config(mandatory=true, description="Number of candidate paths.")
    private int stackSize;
    @Config(mandatory=true, description="Score aggregation function.")
    private ViterbiModel.ScoreAggregation scoreAggregation;
    private int trainInvocationCounter = 0;

    public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, ViterbiModel.ScoreAggregation scoreAggregation) {
        this(trainer, labelFeatureExtractor, -1, scoreAggregation);
    }

    public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, int stackSize, ViterbiModel.ScoreAggregation scoreAggregation) {
        this.trainer = trainer;
        this.labelFeatureExtractor = labelFeatureExtractor;
        this.stackSize = stackSize;
        this.scoreAggregation = scoreAggregation;
    }

    private ViterbiTrainer() {
    }

    public SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, Provenance> runProvenance) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        if (this.stackSize == -1) {
            this.stackSize = dataset.getOutputIDInfo().size();
        }
        if (dataset instanceof ImmutableSequenceDataset) {
            dataset = new MutableSequenceDataset((ImmutableSequenceDataset)dataset);
        }
        if (!(dataset instanceof MutableSequenceDataset)) {
            throw new IllegalArgumentException("unable to handle sub-type of dataset: " + dataset.getClass().getName());
        }
        for (SequenceExample sequenceExample : dataset) {
            ArrayList<Label> labels = new ArrayList<Label>();
            for (Example example : sequenceExample) {
                List<Feature> labelFeatures = this.extractFeatures(labels, (MutableSequenceDataset<Label>)dataset, 1.0);
                example.addAll(labelFeatures);
                labels.add((Label)example.getOutput());
            }
        }
        TrainerProvenance trainerProvenance = this.getProvenance();
        ModelProvenance provenance = new ModelProvenance(ViterbiModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance)dataset.getProvenance(), trainerProvenance, runProvenance);
        ++this.trainInvocationCounter;
        Dataset flatData = dataset.getFlatDataset();
        Model model = this.trainer.train(flatData);
        return new ViterbiModel("viterbi+" + model.getName(), provenance, (Model<Label>)model, this.labelFeatureExtractor, this.stackSize, this.scoreAggregation);
    }

    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    private List<Feature> extractFeatures(List<Label> labels, MutableSequenceDataset<Label> dataset, double value) {
        ArrayList<Feature> labelFeatures = new ArrayList<Feature>();
        for (Feature labelFeature : this.labelFeatureExtractor.extractFeatures(labels, value)) {
            dataset.getFeatureMap().add(labelFeature.getName(), labelFeature.getValue());
            labelFeatures.add(labelFeature);
        }
        return labelFeatures;
    }

    public String toString() {
        return "ViterbiTrainer(innerTrainer=" + this.trainer.toString() + ",labelFeatureExtractor=" + this.labelFeatureExtractor.toString() + ")";
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((SequenceTrainer)this);
    }
}

