Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Added register a group model step #120

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ private CommonValue() {}
public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json";
/** Global Context index mapping version */
public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1;
/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
/** The template field name for template version */
public static final String TEMPLATE_FIELD = "template";
/** The template field name for template compatibility with OpenSearch versions */
public static final String COMPATIBILITY_FIELD = "compatibility";
/** The template field name for template workflows */
public static final String WORKFLOWS_FIELD = "workflows";

/** The transport action name prefix */
public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/";
Expand Down Expand Up @@ -55,7 +63,7 @@ private CommonValue() {}
/** Model Group Id field */
public static final String MODEL_GROUP_ID = "model_group_id";
/** Description field */
public static final String DESCRIPTION = "description";
public static final String DESCRIPTION_FIELD = "description";
/** Connector Id field */
public static final String CONNECTOR_ID = "connector_id";
/** Model format field */
Expand All @@ -72,4 +80,10 @@ private CommonValue() {}
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
/** Backend roles for the model */
public static final String BACKEND_ROLES_FIELD = "backend_roles";
/** Access mode for the model */
public static final String MODEL_ACCESS_MODE = "access_mode";
/** Add all backend roles */
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles";
}
22 changes: 7 additions & 15 deletions src/main/java/org/opensearch/flowframework/model/Template.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,19 @@
import java.util.Map.Entry;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.COMPATIBILITY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TEMPLATE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.USE_CASE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOWS_FIELD;

/**
* The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API.
*/
public class Template implements ToXContentObject {

/** The template field name for template name */
public static final String NAME_FIELD = "name";
/** The template field name for template description */
public static final String DESCRIPTION_FIELD = "description";
/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
/** The template field name for template version information */
public static final String VERSION_FIELD = "version";
/** The template field name for template version */
public static final String TEMPLATE_FIELD = "template";
/** The template field name for template compatibility with OpenSearch versions */
public static final String COMPATIBILITY_FIELD = "compatibility";
/** The template field name for template workflows */
public static final String WORKFLOWS_FIELD = "workflows";

private final String name;
private final String description;
private final String useCase; // probably an ENUM actually
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -32,7 +33,7 @@

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
Expand Down Expand Up @@ -85,7 +86,7 @@ public void onFailure(Exception e) {
String protocol = null;
Map<String, String> parameters = new HashMap<>();
Map<String, String> credentials = new HashMap<>();
List<ConnectorAction> actions = null;
List<ConnectorAction> actions = new ArrayList<>();

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();
Expand All @@ -95,8 +96,8 @@ public void onFailure(Exception e) {
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case VERSION_FIELD:
version = (String) content.get(VERSION_FIELD);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput.MLRegisterModelGroupInputBuilder;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES;
import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ACCESS_MODE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;

/**
* Step to register a model group
*/
public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "model_group";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public ModelGroupStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {

CompletableFuture<WorkflowData> registerModelGroupFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelGroupResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) {
logger.info("Model group registration successful");
registerModelGroupFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()),
Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus())
)
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register model group");
registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String modelGroupName = null;
String description = null;
List<String> backendRoles = new ArrayList<>();
AccessMode modelAccessMode = null;
Boolean isAddAllBackendRoles = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelGroupName = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case BACKEND_ROLES_FIELD:
backendRoles = getBackendRoles(content);
break;
case MODEL_ACCESS_MODE:
modelAccessMode = (AccessMode) content.get(MODEL_ACCESS_MODE);
break;
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = (Boolean) content.get(ADD_ALL_BACKEND_ROLES);
break;
default:
break;
}
}
}

if (modelGroupName == null) {
registerModelGroupFuture.completeExceptionally(
new FlowFrameworkException("Model group name is not provided", RestStatus.BAD_REQUEST)
);
} else {
MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder();
builder.name(modelGroupName);
if (description != null) {
builder.description(description);
}
if (!backendRoles.isEmpty()) {
builder.backendRoles(backendRoles);
}
if (modelAccessMode != null) {
builder.modelAccessMode(modelAccessMode);
}
if (isAddAllBackendRoles != null) {
builder.isAddAllBackendRoles(isAddAllBackendRoles);
}
MLRegisterModelGroupInput mlInput = builder.build();

mlClient.registerModelGroup(mlInput, actionListener);
}

return registerModelGroupFuture;
}

@Override
public String getName() {
return NAME;
}

@SuppressWarnings("unchecked")
private List<String> getBackendRoles(Map<String, Object> content) {
return (List<String>) content.get(BACKEND_ROLES_FIELD);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
Expand Down Expand Up @@ -114,8 +114,8 @@ public void onFailure(Exception e) {
case MODEL_CONFIG:
modelConfig = (MLModelConfig) content.get(MODEL_CONFIG);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case CONNECTOR_ID:
connectorId = (String) content.get(CONNECTOR_ID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ private void populateMap(ClusterService clusterService, Client client, MachineLe
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));

// TODO: These are from the demo class as placeholders, remove when demos are deleted
stepMap.put("demo_delay_3", new DemoWorkflowStep(3000));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -34,9 +35,6 @@
public class CreateConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock
ActionListener<MLCreateConnectorResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

Expand All @@ -49,6 +47,10 @@ public void setUp() throws Exception {

MockitoAnnotations.openMocks(this);

ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = "foot.test";

inputData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "test"),
Expand All @@ -57,7 +59,20 @@ public void setUp() throws Exception {
Map.entry("protocol", "test"),
Map.entry("params", params),
Map.entry("credentials", credentials),
Map.entry("actions", List.of("actions"))
Map.entry(
"actions",
List.of(
new ConnectorAction(
actionType,
method,
url,
null,
"{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }",
null,
null
)
)
)
)
);

Expand Down
Loading
Loading