/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.translate;

import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import java.util.Arrays;
import java.util.stream.LongStream;

/**
 * A {@link Batchifier} that builds a mini-batch by stacking the arrays found at the same position
 * of every sample with {@link NDArrays#stack(NDList)}, and takes batches apart again by splitting
 * along the first axis (axis 0).
 */
public class StackBatchifier implements Batchifier {
    private static final long serialVersionUID = 1L;

    /**
     * Stacks, for every position {@code i}, the {@code i}-th array of each input into one array.
     *
     * <p>All inputs must contain the same number of arrays, and arrays at the same position should
     * share shape and data type. Each stacked array takes its name from the first input's array at
     * that position.
     *
     * @param inputs the samples to batch; may be empty
     * @return one stacked array per position, or an empty list when there are no inputs or the
     *     inputs contain no arrays
     * @throws IllegalArgumentException if the inputs contain different numbers of arrays, or if
     *     stacking fails and arrays at the same position differ in shape or data type
     */
    @Override
    public NDList batchify(NDList[] inputs) {
        if (inputs.length == 0) {
            return new NDList();
        }
        int batchSize = inputs.length;
        int numInputKinds = inputs[0].size();
        // Validate counts up front: checking only after a failure let extra arrays in later
        // inputs be dropped silently instead of being reported.
        for (NDList input : inputs) {
            if (input.size() != numInputKinds) {
                throw new IllegalArgumentException(
                        "You cannot batch data with different numbers of inputs");
            }
        }
        if (numInputKinds == 0) {
            return new NDList();
        }
        try {
            NDList result = new NDList(numInputKinds);
            for (int i = 0; i < numInputKinds; ++i) {
                NDList inputsOfKind = new NDList(batchSize);
                for (NDList input : inputs) {
                    inputsOfKind.add(input.get(i));
                }
                NDArray stacked = NDArrays.stack(inputsOfKind);
                stacked.setName(inputs[0].get(i).getName());
                result.add(stacked);
            }
            return result;
        } catch (EngineException | IndexOutOfBoundsException e) {
            // Replace an opaque engine failure with a descriptive error when the cause is a
            // shape or data-type mismatch; otherwise rethrow the original exception.
            checkKindsCompatible(inputs, numInputKinds, e);
            throw e;
        }
    }

    /**
     * Throws a descriptive {@link IllegalArgumentException} if the arrays at some position differ
     * in shape or data type from the first input's array at that position.
     *
     * @param inputs the samples that failed to stack
     * @param numInputKinds the number of arrays in every input
     * @param cause the failure raised while stacking, attached as the cause
     */
    private static void checkKindsCompatible(
            NDList[] inputs, int numInputKinds, RuntimeException cause) {
        for (int i = 0; i < numInputKinds; ++i) {
            NDArray reference = inputs[0].get(i);
            Shape kindDataShape = reference.getShape();
            DataType kindDataType = reference.getDataType();
            for (NDList input : inputs) {
                NDArray currInput = input.get(i);
                if (!currInput.getShape().equals(kindDataShape)) {
                    throw new IllegalArgumentException(
                            "You cannot batch data with different input shapes: "
                                    + currInput.getShape()
                                    + " vs "
                                    + kindDataShape,
                            cause);
                }
                if (currInput.getDataType() != kindDataType) {
                    throw new IllegalArgumentException(
                            "You cannot batch data with different input data types", cause);
                }
            }
        }
    }

    /**
     * Splits a batch back into individual samples along the first axis.
     *
     * @param inputs the batched arrays; each is expected to have the batch size (taken from the
     *     first array) as its first dimension
     * @return one {@link NDList} per sample, holding that sample's slice of every input with axis
     *     0 squeezed away and the original array name kept; an empty array when {@code inputs} is
     *     empty or the batch size is zero
     */
    @Override
    public NDList[] unbatchify(NDList inputs) {
        if (inputs.isEmpty()) {
            return new NDList[0];
        }
        int batchSize = Math.toIntExact(inputs.head().size(0));
        if (batchSize == 0) {
            return new NDList[0];
        }
        NDList[] dataList = new NDList[batchSize];
        Arrays.setAll(dataList, i -> new NDList());
        for (NDArray input : inputs) {
            NDList splitList = input.split(batchSize);
            for (int i = 0; i < batchSize; ++i) {
                NDArray array = splitList.get(i).squeeze(0);
                array.setName(input.getName());
                dataList[i].add(array);
            }
        }
        return dataList;
    }

    /**
     * Splits every array of a batch along the first axis into (at most) {@code numOfSlices} parts.
     *
     * <p>The slice count is capped at the batch size, which is the first dimension of the first
     * array. As a result every slice holds at least one row.
     *
     * @param list the batched arrays
     * @param numOfSlices the requested number of slices; must be positive
     * @param evenSplit whether every slice must have exactly the same size
     * @return {@code min(numOfSlices, batchSize)} lists; the k-th holds the k-th slice of every
     *     array, named like the array it came from
     * @throws IllegalArgumentException if {@code numOfSlices} is not positive, if {@code
     *     evenSplit} is set and the batch size is not a multiple of the slice count, or if some
     *     array is shorter than the slice count
     */
    @Override
    public NDList[] split(NDList list, int numOfSlices, boolean evenSplit) {
        if (numOfSlices <= 0) {
            throw new IllegalArgumentException(
                    "The number of slices must be positive, but got " + numOfSlices + ".");
        }
        int batchSize = Math.toIntExact(list.head().size(0));
        int slices = Math.min(numOfSlices, batchSize);
        NDList[] splitted = new NDList[slices];
        Arrays.setAll(splitted, k -> new NDList());
        if (slices == 0) {
            // Empty batch: nothing to distribute (and a zero divisor below must be avoided).
            return splitted;
        }
        for (NDArray nd : list) {
            String name = nd.getName();
            NDList rows = split(nd, slices, evenSplit);
            for (int k = 0; k < slices; ++k) {
                NDArray array = rows.get(k);
                array.setName(name);
                splitted[k].add(array);
            }
        }
        return splitted;
    }

    /**
     * Splits one array along its first axis into {@code numOfSlices} non-empty parts.
     *
     * <p>With uneven splitting, the first {@code batchSize % numOfSlices} parts receive one extra
     * row. A ceiling-sized step could place split points at or past the end of the array, yielding
     * empty or invalid trailing slices (e.g. 5 rows into 4 slices).
     *
     * @param array the array to split; its first dimension is the batch size
     * @param numOfSlices the number of parts; must be positive
     * @param evenSplit whether all parts must have the same size
     * @return the parts, in order
     * @throws IllegalArgumentException if the array has fewer rows than {@code numOfSlices}, or
     *     {@code evenSplit} is set and the rows cannot be divided evenly
     */
    private NDList split(NDArray array, int numOfSlices, boolean evenSplit) {
        int batchSize = Math.toIntExact(array.size(0));
        if (batchSize < numOfSlices) {
            throw new IllegalArgumentException(
                    "Batch size("
                            + batchSize
                            + ") is less than slice number("
                            + numOfSlices
                            + ").");
        }
        if (evenSplit) {
            if (batchSize % numOfSlices != 0) {
                throw new IllegalArgumentException(
                        "Data with batch size "
                                + batchSize
                                + " cannot be evenly split into "
                                + numOfSlices
                                + " slices. Use a batch size that's a multiple of "
                                + numOfSlices
                                + " or set evenSplit=false to allow uneven partitioning of data.");
            }
            return array.split(numOfSlices);
        }
        int base = batchSize / numOfSlices;
        int extra = batchSize % numOfSlices;
        // Boundary after the first i slices: i * base rows plus one extra row for each of the
        // first min(i, extra) slices.
        long[] indices =
                LongStream.range(1L, numOfSlices).map(i -> i * base + Math.min(i, extra)).toArray();
        return array.split(indices);
    }
}

