/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.YoloV5Translator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import java.util.ArrayList;
import java.util.Map;

/**
 * A translator for YOLOv8 object-detection models.
 *
 * <p>Reuses the configuration and non-maximum-suppression logic of {@link YoloV5Translator} and
 * overrides only the decoding of the raw box output.
 */
public class YoloV8Translator extends YoloV5Translator {

    /** Default value of the {@code maxBox} argument. */
    private static final int DEFAULT_MAX_BOX = 8400;

    /** Upper bound on the number of output rows scanned for detections. */
    private final int maxBoxes;

    /**
     * Creates a translator from the given builder.
     *
     * @param builder the builder holding the configuration
     */
    protected YoloV8Translator(Builder builder) {
        super(builder);
        this.maxBoxes = builder.maxBox;
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from the given arguments.
     *
     * @param arguments the pre- and post-processing arguments (e.g. {@code maxBox})
     * @return a new, configured builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    /**
     * Decodes the first array of {@code list} into detected objects.
     *
     * <p>After transposing, each row is read as {@code [xCenter, yCenter, w, h, score...]}. At most
     * {@code min(maxBoxes, rows)} rows are scanned, starting from the last one. A row whose best
     * class score exceeds the threshold becomes a candidate box. Candidates are then passed to
     * {@code nms}.
     *
     * @param imageWidth the width of the input image
     * @param imageHeight the height of the input image
     * @param list the model output; only the first array is used
     * @return the detected objects after non-maximum suppression
     * @throws IllegalStateException if the number of score columns does not match the classes
     */
    @Override
    protected DetectedObjects processFromBoxOutput(int imageWidth, int imageHeight, NDList list) {
        NDArray rawResult = (NDArray) list.get(0);
        // Transpose so that each row is one candidate: 4 box values followed by class scores.
        NDArray reshapedResult = rawResult.transpose();
        Shape shape = reshapedResult.getShape();
        float[] buf = reshapedResult.toFloatArray();
        int numberRows = Math.toIntExact(shape.get(0));
        int nClasses = Math.toIntExact(shape.get(1));
        // padding == 4: the class list holds only real classes, so column c maps to class c - 4.
        // padding == 0: the class list also covers the 4 box columns, so column c maps to class c.
        int padding = nClasses - this.classes.size();
        if (padding != 0 && padding != 4) {
            throw new IllegalStateException(
                    "Expected classes: " + (nClasses - 4) + ", got " + this.classes.size());
        }
        ArrayList<Rectangle> boxes = new ArrayList<Rectangle>();
        ArrayList<Float> scores = new ArrayList<Float>();
        ArrayList<Integer> classIds = new ArrayList<Integer>();

        // Scan exactly min(maxBoxes, numberRows) rows, from the last one downwards. The lower bound
        // is never negative, so an output with fewer rows than maxBoxes cannot index outside buf.
        // The inclusive bound means row 0 is reachable when maxBoxes >= numberRows.
        int rowsToScan = Math.min(numberRows, Math.max(0, this.maxBoxes));
        int lowestRow = numberRows - rowsToScan;
        for (int i = numberRows - 1; i >= lowestRow; --i) {
            int index = i * nClasses;
            float maxClassProb = Float.NEGATIVE_INFINITY;
            int maxIndex = -1;
            for (int c = 4; c < nClasses; ++c) {
                float classProb = buf[index + c];
                if (classProb > maxClassProb) {
                    maxClassProb = classProb;
                    maxIndex = c;
                }
            }
            if (!(maxClassProb > this.threshold)) {
                continue;
            }
            float xPos = buf[index];
            float yPos = buf[index + 1];
            float w = buf[index + 2];
            float h = buf[index + 3];
            float left = Math.max(0.0f, xPos - w / 2.0f);
            float top = Math.max(0.0f, yPos - h / 2.0f);
            // Clip only the top-left corner and shrink the extent to match. The right and bottom
            // edges stay where the model placed them instead of being shifted by the clamp.
            float width = Math.max(0.0f, xPos + w / 2.0f - left);
            float height = Math.max(0.0f, yPos + h / 2.0f - top);
            boxes.add(new Rectangle(left, top, width, height));
            scores.add(Float.valueOf(maxClassProb));
            classIds.add(maxIndex - padding);
        }
        return this.nms(imageWidth, imageHeight, boxes, classIds, scores);
    }

    /** The builder for {@link YoloV8Translator}. */
    public static class Builder extends YoloV5Translator.Builder {

        private int maxBox = DEFAULT_MAX_BOX;

        /**
         * Sets the maximum number of output rows scanned for detections.
         *
         * @param maxBox the maximum number of rows to scan; values {@code <= 0} yield no candidates
         * @return this builder
         */
        public Builder optMaxBox(int maxBox) {
            this.maxBox = maxBox;
            return this;
        }

        /**
         * Builds the translator. If no pipeline was configured, installs a default transform.
         *
         * @return the new translator
         */
        @Override
        public YoloV8Translator build() {
            if (this.pipeline == null) {
                // Moves the last axis first (presumably HWC -> CHW), casts to float32 and divides
                // by 255 (presumably scaling 8-bit pixel values into [0, 1]).
                this.addTransform(
                        array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255));
            }
            this.validate();
            return new YoloV8Translator(this);
        }

        /**
         * Applies the parent's post-processing arguments, then reads {@code maxBox}.
         *
         * @param arguments the post-processing arguments
         */
        @Override
        protected void configPostProcess(Map<String, ?> arguments) {
            super.configPostProcess(arguments);
            this.maxBox = ArgumentsUtil.intValue(arguments, "maxBox", DEFAULT_MAX_BOX);
        }
    }
}

