Skip to content

Commit

Permalink
[Feature/agent_framework] Add Delete Connector Step (#211)
Browse files Browse the repository at this point in the history
* Add Delete Connector Step

Signed-off-by: Daniel Widdis <[email protected]>

* Add eclipse core runtime version resolution

Signed-off-by: Daniel Widdis <[email protected]>

* Use JDK17 for spotless

Signed-off-by: Daniel Widdis <[email protected]>

* Add Delete Connector Step

Signed-off-by: Daniel Widdis <[email protected]>

* Add eclipse core runtime version resolution

Signed-off-by: Daniel Widdis <[email protected]>

* Use JDK17 for spotless

Signed-off-by: Daniel Widdis <[email protected]>

* Fetch connector ID from appropriate previous node output

Signed-off-by: Daniel Widdis <[email protected]>

* Fix tests

Signed-off-by: Daniel Widdis <[email protected]>

* Test that actual ID is properly passed

Signed-off-by: Daniel Widdis <[email protected]>

* Update to current setup-java version

Signed-off-by: Daniel Widdis <[email protected]>

* Remove unneeded argument captors

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 18, 2023
1 parent 069c907 commit de8f151
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID;

/**
* Step to delete a connector for a remote model
*/
public class DeleteConnectorStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(DeleteConnectorStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "delete_connector";

/**
* Instantiate this class
* @param mlClient Machine Learning client to perform the deletion
*/
public DeleteConnectorStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) throws IOException {
CompletableFuture<WorkflowData> deleteConnectorFuture = new CompletableFuture<>();

ActionListener<DeleteResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteConnectorFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry("connector_id", deleteResponse.getId())),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to delete connector");
deleteConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String connectorId = null;

// Previous Node inputs defines which step the connector ID came from
Optional<String> previousNode = previousNodeInputs.entrySet()
.stream()
.filter(e -> CONNECTOR_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();
if (previousNode.isPresent()) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(CONNECTOR_ID)) {
connectorId = previousNodeOutput.getContent().get(CONNECTOR_ID).toString();
}
}

if (connectorId != null) {
mlClient.deleteConnector(connectorId, actionListener);
} else {
deleteConnectorFuture.completeExceptionally(
new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST)
);
}

return deleteConnectorFuture;
}

@Override
public String getName() {
return NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
public class WorkflowStepFactory {

private final Map<String, WorkflowStep> stepMap = new HashMap<>();
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;

/**
* Instantiate this class.
Expand All @@ -42,17 +41,6 @@ public WorkflowStepFactory(
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler);
}

private void populateMap(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
stepMap.put(NoOpStep.NAME, new NoOpStep());
stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client));
Expand All @@ -61,6 +49,7 @@ private void populateMap(
stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
stepMap.put(ToolStep.NAME, new ToolStep());
stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient));
Expand All @@ -80,7 +69,7 @@ public WorkflowStep createStep(String type) {

/**
* Gets the step map
* @return the step map
* @return a read-only copy of the step map
*/
public Map<String, WorkflowStep> getStepMap() {
return Map.copyOf(this.stepMap);
Expand Down
8 changes: 8 additions & 0 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@
"connector_id"
]
},
"delete_connector": {
"inputs": [
"connector_id"
],
"outputs":[
"connector_id"
]
},
"register_local_model": {
"inputs":[
"name",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

Expand Down Expand Up @@ -81,15 +80,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr
String connectorId = "connect";
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());

CompletableFuture<WorkflowData> future = createConnectorStep.execute(
inputData.getNodeId(),
Expand All @@ -98,8 +94,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr
Collections.emptyMap()
);

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());
assertTrue(future.isDone());
assertEquals(connectorId, future.get().getContent().get("connector_id"));

Expand All @@ -108,14 +103,11 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr
public void testCreateConnectorFailure() throws IOException {
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());

CompletableFuture<WorkflowData> future = createConnectorStep.execute(
inputData.getNodeId(),
Expand All @@ -124,7 +116,7 @@ public void testCreateConnectorFailure() throws IOException {
Collections.emptyMap()
);

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());
verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

public class DeleteConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id", "test-node-id");
}

public void testDeleteConnector() throws IOException, ExecutionException, InterruptedException {

String connectorId = randomAlphaOfLength(5);
DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient);

doAnswer(invocation -> {
String connectorIdArg = invocation.getArgument(0);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, connectorIdArg, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deleteConnector(any(String.class), any());

CompletableFuture<WorkflowData> future = deleteConnectorStep.execute(
inputData.getNodeId(),
inputData,
Map.of("step_1", new WorkflowData(Map.of("connector_id", connectorId), "workflowId", "nodeId")),
Map.of("step_1", "connector_id")
);
verify(machineLearningNodeClient).deleteConnector(any(String.class), any());

assertTrue(future.isDone());
assertEquals(connectorId, future.get().getContent().get("connector_id"));
}

public void testNoConnectorIdInOutput() throws IOException {
DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient);

CompletableFuture<WorkflowData> future = deleteConnectorStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap()
);

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Required field connector_id is not provided", ex.getCause().getMessage());
}

public void testDeleteConnectorFailure() throws IOException {
DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to delete connector", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deleteConnector(any(String.class), any());

CompletableFuture<WorkflowData> future = deleteConnectorStep.execute(
inputData.getNodeId(),
inputData,
Map.of("step_1", new WorkflowData(Map.of("connector_id", "test"), "workflowId", "nodeId")),
Map.of("step_1", "connector_id")
);

verify(machineLearningNodeClient).deleteConnector(any(String.class), any());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to delete connector", ex.getCause().getMessage());
}
}

0 comments on commit de8f151

Please sign in to comment.