/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.cluster;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Semaphore;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
import org.opensearch.remote.metadata.client.WriteDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.transport.client.Client;

public class MLSyncUpCron
implements Runnable {
    @Generated
    private static final Logger log = LogManager.getLogger(MLSyncUpCron.class);
    public static final int DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS = 20000;
    private Client client;
    private final SdkClient sdkClient;
    private ClusterService clusterService;
    private DiscoveryNodeHelper nodeHelper;
    private MLIndicesHandler mlIndicesHandler;
    private Encryptor encryptor;
    private volatile Boolean mlConfigInited;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    @VisibleForTesting
    Semaphore updateModelStateSemaphore;

    public MLSyncUpCron(Client client, SdkClient sdkClient, ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler, Encryptor encryptor, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        this.client = client;
        this.sdkClient = sdkClient;
        this.clusterService = clusterService;
        this.nodeHelper = nodeHelper;
        this.mlIndicesHandler = mlIndicesHandler;
        this.updateModelStateSemaphore = new Semaphore(1);
        this.mlConfigInited = false;
        this.encryptor = encryptor;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    @Override
    public void run() {
        this.initMLConfig();
        if (!this.clusterService.state().metadata().indices().containsKey(".plugins-ml-model")) {
            log.debug("Skipping sync up job - ML model index not found");
            return;
        }
        log.debug("ML sync job starts");
        DiscoveryNode[] allNodes = this.nodeHelper.getAllNodes();
        MLSyncUpInput gatherInfoInput = MLSyncUpInput.builder().getDeployedModels(true).build();
        MLSyncUpNodesRequest gatherInfoRequest = new MLSyncUpNodesRequest(allNodes, gatherInfoInput);
        this.client.execute((ActionType)MLSyncUpAction.INSTANCE, (ActionRequest)gatherInfoRequest, ActionListener.wrap(r -> {
            log.debug("Received sync up responses from nodes");
            List responses = r.getNodes();
            if (r.failures() != null && !r.failures().isEmpty()) {
                log.debug("Received {} failures in the sync up response on nodes. Error messages are {}", (Object)r.failures().size(), (Object)r.failures().stream().map(Throwable::getMessage).collect(Collectors.joining(", ")));
            }
            HashMap<String, Set> modelWorkerNodes = new HashMap<String, Set>();
            HashMap<String, Set> runningDeployModelTasks = new HashMap<String, Set>();
            HashMap<String, Set> deployingModels = new HashMap<String, Set>();
            HashMap expiredModelToNodes = new HashMap();
            for (Object response : responses) {
                String[] runningDeployModelTaskIds;
                String[] runningModelIds;
                String[] deployedModelIds;
                String string = response.getNode().getId();
                log.debug("Processing sync response from node: {}", (Object)string);
                String[] expiredModelIds = response.getExpiredModelIds();
                if (expiredModelIds != null && expiredModelIds.length > 0) {
                    Arrays.stream(expiredModelIds).forEach(modelId -> expiredModelToNodes.computeIfAbsent(modelId, it -> new HashSet()).add(nodeId));
                }
                if ((deployedModelIds = response.getDeployedModelIds()) != null) {
                    for (String modelId2 : deployedModelIds) {
                        Set workerNodes = modelWorkerNodes.computeIfAbsent(modelId2, it -> new HashSet());
                        workerNodes.add(string);
                    }
                }
                if ((runningModelIds = response.getRunningDeployModelIds()) != null) {
                    for (String modelId3 : runningModelIds) {
                        Set workerNodes = deployingModels.computeIfAbsent(modelId3, it -> new HashSet());
                        workerNodes.add(string);
                    }
                }
                if ((runningDeployModelTaskIds = response.getRunningDeployModelTaskIds()) == null) continue;
                for (String taskId : runningDeployModelTaskIds) {
                    Set workerNodes = runningDeployModelTasks.computeIfAbsent(taskId, it -> new HashSet());
                    workerNodes.add(string);
                }
            }
            HashSet<String> modelsToUndeploy = new HashSet<String>();
            for (String string : expiredModelToNodes.keySet()) {
                if (!modelWorkerNodes.containsKey(string) || ((Set)expiredModelToNodes.get(string)).size() != ((Set)modelWorkerNodes.get(string)).size()) continue;
                modelsToUndeploy.add(string);
            }
            for (Map.Entry entry : modelWorkerNodes.entrySet()) {
                String modelId5 = (String)entry.getKey();
                log.debug("will sync model worker nodes for model: {}: {}", (Object)modelId5, (Object)((Set)entry.getValue()).toArray(new String[0]));
            }
            for (Map.Entry entry : runningDeployModelTasks.entrySet()) {
                log.debug("will sync running task: {}: {}", entry.getKey(), (Object)((Set)entry.getValue()).toArray(new String[0]));
            }
            MLSyncUpInput.MLSyncUpInputBuilder inputBuilder = MLSyncUpInput.builder().syncRunningDeployModelTasks(true).runningDeployModelTasks(runningDeployModelTasks);
            if (modelWorkerNodes.isEmpty()) {
                log.debug("No deployed model found. Will clear model routing on all nodes");
                inputBuilder.clearRoutingTable(true);
            } else {
                inputBuilder.modelRoutingTable(modelWorkerNodes);
            }
            MLSyncUpInput mLSyncUpInput = inputBuilder.build();
            MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, mLSyncUpInput);
            this.client.execute((ActionType)MLSyncUpAction.INSTANCE, (ActionRequest)syncUpRequest, ActionListener.wrap(re -> {
                log.debug("sync model routing job finished");
                if (!modelsToUndeploy.isEmpty()) {
                    this.undeployExpiredModels(modelsToUndeploy, modelWorkerNodes, deployingModels);
                    return;
                }
                this.mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> {
                    if (!res.booleanValue()) {
                        log.error("No response to create ML model index");
                        return;
                    }
                    this.refreshModelState(modelWorkerNodes, deployingModels);
                }, e -> log.error("Failed to init model index", (Throwable)e)));
            }, ex -> log.error("Failed to sync model routing", (Throwable)ex)));
        }, e -> log.error("Failed to sync model routing", (Throwable)e)));
    }

    private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes, Map<String, Set<String>> deployingModels) {
        String[] targetNodeIds = RestActionUtils.getAllNodes(this.clusterService);
        log.debug("Sending requests to undeploy expired models: {}", expiredModels);
        MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest(expiredModels.toArray(new String[expiredModels.size()]), targetNodeIds, null);
        this.client.execute((ActionType)MLUndeployModelsAction.INSTANCE, (ActionRequest)mlUndeployModelsRequest, ActionListener.wrap(r -> {
            MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r.getResponse();
            if (mlUndeployModelNodesResponse.failures() != null && !mlUndeployModelNodesResponse.failures().isEmpty()) {
                log.debug("Received failures in undeploying expired models", (Object)mlUndeployModelNodesResponse.failures());
            }
            this.mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> {
                if (!res.booleanValue()) {
                    log.error("No response to create ML model index");
                    return;
                }
                this.refreshModelState(modelWorkerNodes, deployingModels);
            }, e -> log.error("Failed to init model index", (Throwable)e)));
        }, e -> log.error("Failed to undeploy models {}", (Object)expiredModels, e)));
    }

    @VisibleForTesting
    void initMLConfig() {
        if (this.mlConfigInited.booleanValue() || this.mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
            return;
        }
        this.mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> {
            if (!r.booleanValue()) {
                log.error("Failed to initialize or update ML Config index");
                return;
            }
            GetRequest getRequest = new GetRequest(".plugins-ml-config").id("master_key");
            try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext();){
                this.client.get(getRequest, ActionListener.wrap(getResponse -> {
                    if (!getResponse.isExists()) {
                        IndexRequest indexRequest = new IndexRequest(".plugins-ml-config").id("master_key");
                        String masterKey = this.encryptor.generateMasterKey();
                        indexRequest.source((Map)ImmutableMap.of((Object)"master_key", (Object)masterKey, (Object)"create_time", (Object)Instant.now().toEpochMilli()));
                        indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                        indexRequest.opType(DocWriteRequest.OpType.CREATE);
                        this.client.index(indexRequest, ActionListener.wrap(indexResponse -> {
                            log.info("ML configuration initialized successfully");
                            this.encryptor.setMasterKey(null, masterKey);
                            this.mlConfigInited = true;
                        }, e -> log.debug("Failed to save ML encryption master key", (Throwable)e)));
                    } else {
                        String masterKey = (String)getResponse.getSourceAsMap().get("master_key");
                        this.encryptor.setMasterKey(null, masterKey);
                        this.mlConfigInited = true;
                        log.info("ML configuration already initialized, no action needed");
                    }
                }, e -> log.debug("Failed to get ML encryption master key", (Throwable)e)));
            }
        }, e -> log.debug("Failed to init ML config index", (Throwable)e)));
    }

    @VisibleForTesting
    void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Set<String>> deployingModels) {
        if (!this.updateModelStateSemaphore.tryAcquire()) {
            log.debug("Model state refresh already in progress. Skipping this cycle.");
            return;
        }
        try {
            BoolQueryBuilder queryBuilder = new BoolQueryBuilder();
            queryBuilder.filter((QueryBuilder)new TermsQueryBuilder("model_state", Arrays.asList(MLModelState.LOADING.name(), MLModelState.PARTIALLY_LOADED.name(), MLModelState.LOADED.name(), MLModelState.LOAD_FAILED.name(), MLModelState.DEPLOYING.name(), MLModelState.PARTIALLY_DEPLOYED.name(), MLModelState.DEPLOYED.name(), MLModelState.DEPLOY_FAILED.name())));
            SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
            sourceBuilder.query((QueryBuilder)queryBuilder);
            sourceBuilder.size(10000);
            sourceBuilder.fetchSource(new String[]{"tenant_id", "model_state", "algorithm", "deploy_to_all_nodes", "planning_worker_nodes", "planning_worker_node_count", "last_updated_time", "current_worker_node_count"}, null);
            SearchDataObjectRequest searchRequest = SearchDataObjectRequest.builder().indices(new String[]{".plugins-ml-model"}).searchSourceBuilder(sourceBuilder).build();
            this.sdkClient.searchDataObjectAsync(searchRequest).whenComplete((r, throwable) -> {
                if (throwable == null) {
                    try {
                        SearchResponse res = r.searchResponse();
                        SearchHit[] hits = res.getHits().getHits();
                        HashMap<String, String> tenantIds = new HashMap<String, String>();
                        HashMap<String, MLModelState> newModelStates = new HashMap<String, MLModelState>();
                        HashMap<String, List<String>> newPlanningWorkerNodes = new HashMap<String, List<String>>();
                        for (SearchHit hit : hits) {
                            MLModelState mlModelState;
                            List planningWorkNodes;
                            String modelId = hit.getId();
                            Map sourceAsMap = hit.getSourceAsMap();
                            if (sourceAsMap.containsKey("tenant_id")) {
                                tenantIds.put(modelId, (String)sourceAsMap.get("tenant_id"));
                            }
                            FunctionName functionName = FunctionName.from((String)((String)sourceAsMap.get("algorithm")));
                            MLModelState state = MLModelState.from((String)((String)sourceAsMap.get("model_state")));
                            Long lastUpdateTime = sourceAsMap.containsKey("last_updated_time") ? (Long)sourceAsMap.get("last_updated_time") : null;
                            int planningWorkerNodeCount = sourceAsMap.containsKey("planning_worker_node_count") ? (Integer)sourceAsMap.get("planning_worker_node_count") : 0;
                            int currentWorkerNodeCountInIndex = sourceAsMap.containsKey("current_worker_node_count") ? (Integer)sourceAsMap.get("current_worker_node_count") : 0;
                            boolean deployToAllNodes = sourceAsMap.containsKey("deploy_to_all_nodes") && (Boolean)sourceAsMap.get("deploy_to_all_nodes") != false;
                            List list = planningWorkNodes = sourceAsMap.containsKey("planning_worker_nodes") ? (List)sourceAsMap.get("planning_worker_nodes") : new ArrayList();
                            if (deployToAllNodes) {
                                DiscoveryNode[] eligibleNodes = this.nodeHelper.getEligibleNodes(functionName);
                                planningWorkerNodeCount = eligibleNodes.length;
                                List eligibleNodeIds = Arrays.stream(eligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toList());
                                if (eligibleNodeIds.size() != planningWorkNodes.size() || !eligibleNodeIds.containsAll(planningWorkNodes)) {
                                    newPlanningWorkerNodes.put(modelId, eligibleNodeIds);
                                }
                            }
                            if ((mlModelState = this.getNewModelState(deployingModels, modelWorkerNodes, modelId, state, lastUpdateTime, planningWorkerNodeCount, currentWorkerNodeCountInIndex)) == null) continue;
                            newModelStates.put(modelId, mlModelState);
                        }
                        this.bulkUpdateModelState(modelWorkerNodes, newModelStates, newPlanningWorkerNodes, tenantIds);
                    }
                    catch (Exception e) {
                        log.error("Failed to parse model search response", (Throwable)e);
                        this.updateModelStateSemaphore.release();
                    }
                } else {
                    Exception e = SdkClientUtils.unwrapAndConvertToException((Throwable)throwable, (Class[])new Class[]{OpenSearchStatusException.class});
                    this.updateModelStateSemaphore.release();
                    log.error("Failed to search models", (Throwable)e);
                }
            });
        }
        catch (Exception e) {
            this.updateModelStateSemaphore.release();
            log.error("Failed to refresh model state", (Throwable)e);
        }
    }

    private MLModelState getNewModelState(Map<String, Set<String>> deployingModels, Map<String, Set<String>> modelWorkerNodes, String modelId, MLModelState state, Long lastUpdateTime, int planningWorkerNodeCount, int currentWorkerNodeCountInIndex) {
        int currentWorkerNodeCount;
        Set<String> deployModelTaskNodes = deployingModels.get(modelId);
        if (deployModelTaskNodes != null && !deployModelTaskNodes.isEmpty() && state != MLModelState.DEPLOYING) {
            return MLModelState.DEPLOYING;
        }
        int n = currentWorkerNodeCount = modelWorkerNodes.containsKey(modelId) ? modelWorkerNodes.get(modelId).size() : 0;
        if (currentWorkerNodeCount == 0 && state != MLModelState.DEPLOY_FAILED && (state != MLModelState.DEPLOYING || lastUpdateTime == null || lastUpdateTime + 20000L <= Instant.now().toEpochMilli())) {
            return MLModelState.DEPLOY_FAILED;
        }
        if (currentWorkerNodeCount > 0) {
            if (currentWorkerNodeCount < planningWorkerNodeCount && (state != MLModelState.PARTIALLY_DEPLOYED || currentWorkerNodeCountInIndex != currentWorkerNodeCount)) {
                return MLModelState.PARTIALLY_DEPLOYED;
            }
            if (planningWorkerNodeCount > 0 && currentWorkerNodeCount >= planningWorkerNodeCount && state != MLModelState.DEPLOYED) {
                if (currentWorkerNodeCount > planningWorkerNodeCount) {
                    log.warn("Model {} deployed on more nodes [{}] than planning worker node [{}]", (Object)modelId, (Object)currentWorkerNodeCount, (Object)planningWorkerNodeCount);
                }
                return MLModelState.DEPLOYED;
            }
        }
        return null;
    }

    private void bulkUpdateModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, MLModelState> newModelStates, Map<String, List<String>> newPlanningWorkNodes, Map<String, String> tenantIds) {
        HashSet<String> updatedModelIds = new HashSet<String>();
        updatedModelIds.addAll(newModelStates.keySet());
        updatedModelIds.addAll(newPlanningWorkNodes.keySet());
        if (!updatedModelIds.isEmpty()) {
            BulkDataObjectRequest bulkUpdateRequest = BulkDataObjectRequest.builder().globalIndex(".plugins-ml-model").build();
            for (String modelId : updatedModelIds) {
                Instant now = Instant.now();
                HashMap<String, Object> updateDocument = new HashMap<String, Object>();
                if (newModelStates.containsKey(modelId)) {
                    updateDocument.put("model_state", newModelStates.get(modelId).name());
                }
                if (newPlanningWorkNodes.containsKey(modelId)) {
                    updateDocument.put("planning_worker_nodes", newPlanningWorkNodes.get(modelId));
                    updateDocument.put("planning_worker_node_count", newPlanningWorkNodes.get(modelId).size());
                }
                updateDocument.put("last_updated_time", now.toEpochMilli());
                Set<String> workerNodes = modelWorkerNodes.get(modelId);
                int currentWorkNodeCount = workerNodes == null ? 0 : workerNodes.size();
                updateDocument.put("current_worker_node_count", currentWorkNodeCount);
                UpdateDataObjectRequest updateRequest = ((UpdateDataObjectRequest.Builder)((UpdateDataObjectRequest.Builder)UpdateDataObjectRequest.builder().tenantId(tenantIds.get(modelId))).id(modelId)).dataObject(updateDocument).build();
                bulkUpdateRequest.add((WriteDataObjectRequest)updateRequest);
            }
            bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            log.info("Refresh model state: {}", newModelStates);
            this.sdkClient.bulkDataObjectAsync(bulkUpdateRequest).whenComplete((r, throwable) -> {
                this.updateModelStateSemaphore.release();
                if (throwable != null) {
                    Exception e = SdkClientUtils.unwrapAndConvertToException((Throwable)throwable, (Class[])new Class[]{OpenSearchStatusException.class});
                    log.error("Failed to bulk update model state", (Throwable)e);
                } else {
                    log.debug("Refresh model state successfully");
                }
            });
        } else {
            this.updateModelStateSemaphore.release();
        }
    }
}

