/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.engine.faiss;

import com.google.common.collect.ImmutableSet;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.Generated;
import org.opensearch.common.TriFunction;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.DefaultHnswSearchContext;
import org.opensearch.knn.index.engine.Encoder;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.faiss.AbstractFaissMethod;
import org.opensearch.knn.index.engine.faiss.FaissFlatEncoder;
import org.opensearch.knn.index.engine.faiss.FaissHNSWPQEncoder;
import org.opensearch.knn.index.engine.faiss.FaissSQEncoder;
import org.opensearch.knn.index.engine.faiss.MethodAsMapBuilder;
import org.opensearch.knn.index.engine.faiss.QFrameBitEncoder;
import org.opensearch.remoteindexbuild.model.RemoteFaissHNSWIndexParameters;
import org.opensearch.remoteindexbuild.model.RemoteIndexParameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class FaissHNSWMethod
extends AbstractFaissMethod {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(FaissHNSWMethod.class);
    private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of((Object)((Object)VectorDataType.FLOAT), (Object)((Object)VectorDataType.BINARY), (Object)((Object)VectorDataType.BYTE));
    public static final List<SpaceType> SUPPORTED_SPACES = Arrays.asList(SpaceType.UNDEFINED, SpaceType.HAMMING, SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.COSINESIMIL);
    private static final MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext("flat", Collections.emptyMap());
    static final Encoder FLAT_ENCODER = new FaissFlatEncoder();
    static final Encoder SQ_ENCODER = new FaissSQEncoder();
    static final Encoder HNSW_PQ_ENCODER = new FaissHNSWPQEncoder();
    static final Encoder QFRAME_BIT_ENCODER = new QFrameBitEncoder();
    static final Map<String, Encoder> SUPPORTED_ENCODERS = Map.of(FLAT_ENCODER.getName(), FLAT_ENCODER, SQ_ENCODER.getName(), SQ_ENCODER, HNSW_PQ_ENCODER.getName(), HNSW_PQ_ENCODER, QFRAME_BIT_ENCODER.getName(), QFRAME_BIT_ENCODER);
    static final MethodComponent HNSW_COMPONENT = FaissHNSWMethod.initMethodComponent();

    public FaissHNSWMethod() {
        super(HNSW_COMPONENT, Set.copyOf(SUPPORTED_SPACES), new DefaultHnswSearchContext());
    }

    private static MethodComponent initMethodComponent() {
        return MethodComponent.Builder.builder("hnsw").addSupportedDataTypes(SUPPORTED_DATA_TYPES).addParameter("m", new Parameter.IntegerParameter("m", KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, (v, context) -> v > 0)).addParameter("ef_construction", new Parameter.IntegerParameter("ef_construction", KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, (v, context) -> v > 0)).addParameter("ef_search", new Parameter.IntegerParameter("ef_search", KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, (v, context) -> v > 0)).addParameter("encoder", FaissHNSWMethod.initEncoderParameter()).setKnnLibraryIndexingContextGenerator((TriFunction<MethodComponent, MethodComponentContext, KNNMethodConfigContext, KNNLibraryIndexingContext>)((TriFunction)(methodComponent, methodComponentContext, knnMethodConfigContext) -> {
            MethodAsMapBuilder methodAsMapBuilder = MethodAsMapBuilder.builder("HNSW", methodComponent, methodComponentContext, knnMethodConfigContext).addParameter("m", "", "").addParameter("encoder", ",", "");
            return FaissHNSWMethod.adjustIndexDescription(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
        })).build();
    }

    private static Parameter.MethodComponentContextParameter initEncoderParameter() {
        return new Parameter.MethodComponentContextParameter("encoder", DEFAULT_ENCODER_CONTEXT, SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)));
    }

    @Override
    protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
        return trainingConfigValidationInput -> {
            KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
            TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
            if (!this.isEncoderSpecified(knnMethodContext)) {
                return builder.build();
            }
            Encoder encoder = SUPPORTED_ENCODERS.get(this.getEncoderName(knnMethodContext));
            if (encoder == null) {
                return builder.build();
            }
            return encoder.validateEncoderConfig((TrainingConfigValidationInput)trainingConfigValidationInput);
        };
    }

    public RemoteIndexParameters createRemoteIndexingParameters(Map<String, Object> parameters) {
        RemoteFaissHNSWIndexParameters.RemoteFaissHNSWIndexParametersBuilder builder = RemoteFaissHNSWIndexParameters.builder();
        builder.algorithm("hnsw");
        builder.spaceType(FaissHNSWMethod.getStringFromMap(parameters, "spaceType"));
        Map innerParameters = (Map)parameters.get("parameters");
        builder.efConstruction(FaissHNSWMethod.getIntegerFromMap(innerParameters, "ef_construction").intValue());
        builder.efSearch(FaissHNSWMethod.getIntegerFromMap(innerParameters, "ef_search").intValue());
        builder.m(FaissHNSWMethod.getIntegerFromMap(innerParameters, "m").intValue());
        return builder.build();
    }

    static boolean supportsRemoteIndexBuild(Map<String, Object> parameters) {
        try {
            VectorDataType vectorDataType = FaissHNSWMethod.extractVectorDataType(parameters);
            Map<String, Object> encoderMap = FaissHNSWMethod.extractEncoderMap(parameters);
            if (FaissHNSWMethod.isFloat32Index(vectorDataType, encoderMap)) {
                return true;
            }
            if (FaissHNSWMethod.isFloat16Index(vectorDataType, parameters)) {
                return true;
            }
            if (FaissHNSWMethod.isBinaryIndex(vectorDataType, encoderMap)) {
                return true;
            }
            if (FaissHNSWMethod.isQuantizedIndex(vectorDataType, encoderMap)) {
                return true;
            }
            return FaissHNSWMethod.isByteIndex(vectorDataType, encoderMap);
        }
        catch (Exception e) {
            log.warn(e.getMessage());
            return false;
        }
    }

    private static boolean isFloat32Index(VectorDataType vectorDataType, Map<String, Object> encoderMap) {
        try {
            if (vectorDataType != VectorDataType.FLOAT) {
                return false;
            }
            String encoder = FaissHNSWMethod.getStringFromMap(encoderMap, "name");
            return encoder.equals("flat");
        }
        catch (Exception e) {
            log.debug(e.getMessage());
            return false;
        }
    }

    public static boolean isFloat16Index(VectorDataType vectorDataType, Map<String, Object> parameters) {
        try {
            if (vectorDataType != VectorDataType.FLOAT) {
                return false;
            }
            Map<String, Object> encoderMap = FaissHNSWMethod.extractEncoderMap(parameters);
            String encoder = FaissHNSWMethod.getStringFromMap(encoderMap, "name");
            return encoder.equals("sq");
        }
        catch (Exception e) {
            log.debug(e.getMessage());
            return false;
        }
    }

    private static boolean isBinaryIndex(VectorDataType vectorDataType, Map<String, Object> encoderMap) {
        try {
            return vectorDataType == VectorDataType.BINARY && FaissHNSWMethod.getStringFromMap(encoderMap, "name").equals("flat");
        }
        catch (Exception e) {
            log.warn(e.getMessage());
            return false;
        }
    }

    private static boolean isQuantizedIndex(VectorDataType vectorDataType, Map<String, Object> encoderMap) {
        try {
            if (vectorDataType != VectorDataType.FLOAT) {
                return false;
            }
            return encoderMap.isEmpty();
        }
        catch (Exception e) {
            log.debug(e.getMessage());
            return false;
        }
    }

    private static boolean isByteIndex(VectorDataType vectorDataType, Map<String, Object> encoderMap) {
        try {
            if (vectorDataType != VectorDataType.BYTE) {
                return false;
            }
            String encoder = FaissHNSWMethod.getStringFromMap(encoderMap, "name");
            return encoder.equals("flat");
        }
        catch (Exception e) {
            log.debug(e.getMessage());
            return false;
        }
    }

    private static VectorDataType extractVectorDataType(Map<String, Object> parameters) {
        String dataType = FaissHNSWMethod.getStringFromMap(parameters, "data_type");
        VectorDataType vectorDataType = VectorDataType.get(dataType);
        return vectorDataType;
    }

    private static Map<String, Object> extractEncoderMap(Map<String, Object> parameters) {
        Map innerMap = (Map)parameters.get("parameters");
        Map encoderMap = (Map)innerMap.get("encoder");
        return encoderMap;
    }

    private static Integer getIntegerFromMap(Map<String, Object> map, String key) {
        Object value = map.get(key);
        if (value instanceof Integer) {
            return (Integer)value;
        }
        if (value instanceof String) {
            return Integer.parseInt((String)value);
        }
        throw new IllegalArgumentException("Could not parse value for key: " + key + " and map: " + String.valueOf(map));
    }

    private static String getStringFromMap(Map<String, Object> map, String key) throws IllegalArgumentException {
        Object value = map.get(key);
        if (value instanceof String) {
            return (String)value;
        }
        throw new IllegalArgumentException("Could not parse value for key: " + key + " and map: " + String.valueOf(map));
    }
}

