Skip to content

Commit

Permalink
Filter out remote model auto redeployment
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Sep 23, 2024
1 parent 6a6cac1 commit 91e4c90
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
Expand All @@ -30,6 +31,7 @@
import org.opensearch.core.common.Strings;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
Expand Down Expand Up @@ -257,16 +259,19 @@ private void queryRunningModels(ActionListener<SearchResponse> listener) {
private void triggerModelRedeploy(ModelAutoRedeployArrangement modelAutoRedeployArrangement) {
String modelId = modelAutoRedeployArrangement.getSearchResponse().getId();
List<String> addedNodes = modelAutoRedeployArrangement.getAddedNodes();
List<String> planningWorkerNodes = (List<String>) modelAutoRedeployArrangement
.getSearchResponse()
.getSourceAsMap()
Map<String, Object> sourceAsMap = modelAutoRedeployArrangement.getSearchResponse().getSourceAsMap();
String functionName = (String) Optional.ofNullable(sourceAsMap.get(MLModel.FUNCTION_NAME_FIELD))
.orElse(sourceAsMap.get(MLModel.ALGORITHM_FIELD));
if (FunctionName.REMOTE == FunctionName.from(functionName)) {
log.info("Skipping redeploying remote model {} as remote model deployment can be done at prediction time.", modelId);
return;
}
List<String> planningWorkerNodes = (List<String>) sourceAsMap
.get(MLModel.PLANNING_WORKER_NODES_FIELD);
Integer autoRedeployRetryTimes = (Integer) modelAutoRedeployArrangement
.getSearchResponse()
.getSourceAsMap()
Integer autoRedeployRetryTimes = (Integer) sourceAsMap
.get(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD);
Boolean deployToAllNodes = (Boolean) Optional
.ofNullable(modelAutoRedeployArrangement.getSearchResponse().getSourceAsMap().get(MLModel.DEPLOY_TO_ALL_NODES_FIELD))
.ofNullable(sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD))
.orElse(false);
// calculate node ids.
String[] nodeIds = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,34 @@ public void test_redeployAModel_with_needRedeployArray_isEmpty() {
mlModelAutoReDeployer.redeployAModel();
}

public void test_buildAutoReloadArrangement_skippingRemoteModel_success() throws Exception {
Settings settings = Settings
.builder()
.put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), true)
.put(ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES.getKey(), 3)
.put(ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.getKey(), true)
.put(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.getKey(), false)
.build();

ClusterService clusterService = mock(ClusterService.class);
when(clusterService.localNode()).thenReturn(localNode);
when(clusterService.getClusterSettings()).thenReturn(getClusterSettings(settings));
mockClusterDataNodes(clusterService);

mlModelAutoReDeployer = spy(
new MLModelAutoReDeployer(clusterService, client, settings, mlModelManager, searchRequestBuilderFactory)
);

SearchResponse searchResponse = buildDeployToAllNodesTrueSearchResponse("RemoteModelResult.json");
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(0);
listener.onResponse(searchResponse);
return null;
}).when(searchRequestBuilder).execute(isA(ActionListener.class));
mlModelAutoReDeployer.buildAutoReloadArrangement(addedNodes, clusterManagerNodeId);
verify(client, never()).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), any(ActionListener.class));
}

private SearchResponse buildDeployToAllNodesTrueSearchResponse(String file) throws Exception {
MLModel mlModel = buildModelWithJsonFile(file);
return createResponseWithModel(mlModel);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"last_deployed_time": 1722954415807,
"model_version": "619",
"created_time": 1722954415642,
"deploy_to_all_nodes": true,
"is_hidden": false,
"description": "This is a test model",
"model_state": "DEPLOYED",
"planning_worker_node_count": 1,
"auto_redeploy_retry_times": 0,
"last_updated_time": 1723691017054,
"name": "my sagemaker model",
"connector_id": "z3kVKJEBAfFjoGUT_Ui7",
"current_worker_node_count": 0,
"model_group_id": "MiJPJ5EBQM-QzppeWrTJ",
"planning_worker_nodes": [
"DecGG5pDQYaqelLMLcIV9Q"
],
"algorithm": "REMOTE"
}

0 comments on commit 91e4c90

Please sign in to comment.