diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java index d7536a5a7..7d5750c2b 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java @@ -197,31 +197,17 @@ private List getIndicesUsingModel(ClusterState clusterState, UpdateModel .filter(entry -> entry.getValue() != null) .filter(entry -> { Object properties = entry.getValue().getSourceAsMap().get("properties"); - if (properties == null || properties instanceof Map == false) { + if ((properties instanceof Map) == false) { return false; } - Map propertiesMap = (Map) properties; - return propertiesMapContainsModel(propertiesMap, task.getModelId()); + Map propertiesMap = (Map) properties; + return propertiesMap.values() + .stream() + .filter(obj -> obj instanceof Map) + .anyMatch(obj -> task.getModelId().equals(((Map) obj).get(MODEL_ID))); }) .map(Map.Entry::getKey) .collect(toList()); } - - private boolean propertiesMapContainsModel(Map propertiesMap, String modelId) { - for (Map.Entry fieldsEntry : propertiesMap.entrySet()) { - if (fieldsEntry.getKey() != null && fieldsEntry.getValue() instanceof Map) { - Map innerMap = (Map) fieldsEntry.getValue(); - for (Map.Entry innerEntry : innerMap.entrySet()) { - // If model is in use, fail delete model request - if (innerEntry.getKey().equals(MODEL_ID) - && innerEntry.getValue() instanceof String - && innerEntry.getValue().equals(modelId)) { - return true; - } - } - } - } - return false; - } } } diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index d87a63ad5..f9c0161d6 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -5,6 +5,7 @@ package org.opensearch.knn; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -37,7 +38,11 @@ import org.opensearch.test.hamcrest.OpenSearchAssertions; import java.io.IOException; -import java.util.*; +import java.util.Base64; +import java.util.Collection; +import java.util.Collections; +import java.util.EnumSet; +import java.util.Map; import java.util.concurrent.ExecutionException; import static org.mockito.Mockito.when; @@ -184,7 +189,7 @@ protected void addDoc(String index, String docId, String fieldName, String dummy /** * Index a new model */ - protected void addDoc(Model model) throws IOException, ExecutionException, InterruptedException { + protected void writeModelToModelSystemIndex(Model model) throws IOException, ExecutionException, InterruptedException { ModelMetadata modelMetadata = model.getModelMetadata(); XContentBuilder builder = XContentFactory.jsonBuilder() @@ -213,6 +218,22 @@ protected void addDoc(Model model) throws IOException, ExecutionException, Inter assertTrue(response.status() == RestStatus.CREATED || response.status() == RestStatus.OK); } + // Add a new model to ModelDao + protected void addModel(Model model) throws IOException { + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + modelDao.put(model, new ActionListener() { + @Override + public void onResponse(IndexResponse indexResponse) { + assertTrue(indexResponse.status() == RestStatus.CREATED || indexResponse.status() == RestStatus.OK); + } + + @Override + public void onFailure(Exception e) { + fail("Failed to add model: " + e); + } + }); + } + /** * Run a search against a k-NN index */ diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 3a25c3064..75c523332 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -18,7 +18,6 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.ResourceNotFoundException; -import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.core.action.ActionListener; import org.opensearch.action.DocWriteResponse; @@ -30,12 +29,9 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.knn.KNNSingleNodeTestCase; -import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -65,11 +61,7 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.PROPERTIES; -import static org.opensearch.knn.common.KNNConstants.TYPE; -import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; public class ModelDaoTests extends KNNSingleNodeTestCase { @@ -152,7 +144,7 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti modelBlob, modelId ); - addDoc(model); + writeModelToModelSystemIndex(model); assertEquals(model, modelDao.get(modelId)); assertNotNull(modelDao.getHealthStatus()); @@ -172,7 +164,7 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti modelBlob, modelId ); - addDoc(model); + writeModelToModelSystemIndex(model); assertEquals(model, modelDao.get(modelId)); assertNotNull(modelDao.getHealthStatus()); } @@ -450,7 +442,7 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti modelBlob, modelId ); - addDoc(model); + writeModelToModelSystemIndex(model); assertEquals(model, modelDao.get(modelId)); // Get model during training @@ -469,7 +461,7 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti null, modelId ); - addDoc(model); + writeModelToModelSystemIndex(model); assertEquals(model, modelDao.get(modelId)); } @@ -629,91 +621,6 @@ public void testDelete() throws IOException, InterruptedException { assertTrue(inProgressLatch3.await(100, TimeUnit.SECONDS)); } - // Test Delete Model when the model is in use by an index - public void testDeleteModelInUse() throws IOException, ExecutionException, InterruptedException { - String modelId = "test-model-id-training"; - ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); - byte[] modelBlob = "deleteModel".getBytes(); - int dimension = 2; - createIndex(MODEL_INDEX_NAME); - - Model model = new Model( - new ModelMetadata( - KNNEngine.DEFAULT, - SpaceType.DEFAULT, - dimension, - ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), - "", - "", - "", - MethodComponentContext.EMPTY - ), - modelBlob, - modelId - ); - - // created model and added it to index - addDoc(model); - - String testIndex = "test-index"; - String testField = "test-field"; - - /* - Constructs the following json: - { - "properties": { - "test-field": { - "type": "knn_vector", - "model_id": "test-model-id-training" - } - } - } - */ - XContentBuilder mappings = XContentFactory.jsonBuilder() - .startObject() - .startObject(PROPERTIES) - .startObject(testField) - .field(TYPE, TYPE_KNN_VECTOR) - .field(MODEL_ID, modelId) - .endObject() - .endObject() - .endObject(); - - XContentBuilder settings = XContentFactory.jsonBuilder().startObject().field(TestUtils.INDEX_KNN, "true").endObject(); - - // Create index using model - CreateIndexRequestBuilder createIndexRequestBuilder = client().admin() - .indices() - .prepareCreate(testIndex) - .setMapping(mappings) - .setSettings(settings); - createIndex(testIndex, createIndexRequestBuilder); - - CountDownLatch latch = new CountDownLatch(1); - modelDao.delete(modelId, new ActionListener() { - @Override - public void onResponse(DeleteModelResponse deleteModelResponse) { - fail("Received delete model response when the request should have failed."); - } - - @Override - public void onFailure(Exception e) { - assertTrue(e instanceof DeleteModelException); - assertEquals( - String.format( - "Cannot delete model [%s]. Model is in use by the following indices [%s], which must be deleted first.", - modelId, - testIndex - ), - e.getMessage() - ); - latch.countDown(); - } - }); - assertTrue(latch.await(60, TimeUnit.SECONDS)); - } - // Test Delete Model when modelId is in Model Graveyard (previous delete model request which failed to // remove modelId from model graveyard). But, the model does not exist public void testDeleteModelWithModelInGraveyardModelDoesNotExist() throws InterruptedException { @@ -772,7 +679,7 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe ); // created model and added it to index - addDoc(model); + writeModelToModelSystemIndex(model); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -814,7 +721,7 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti ); // created model and added it to index - addDoc(model); + writeModelToModelSystemIndex(model); final CountDownLatch inProgressLatch = new CountDownLatch(1); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index cd60d566c..5be907ebd 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -189,7 +189,7 @@ public void testCheckBlock() { assertNull(updateModelGraveyardTransportAction.checkBlock(null, null)); } - public void testGetIndicesUsingModel() throws IOException, ExecutionException, InterruptedException { + public void testClusterManagerOperation_GetIndicesUsingModel() throws IOException, ExecutionException, InterruptedException { // Get update transport action UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction = node().injector() .getInstance(UpdateModelGraveyardTransportAction.class); @@ -217,7 +217,7 @@ public void testGetIndicesUsingModel() throws IOException, ExecutionException, I ); // created model and added it to index - addDoc(model); + addModel(model); // Create basic index (not using k-NN) String testIndex1 = "test-index1"; @@ -336,7 +336,7 @@ public void testGetIndicesUsingModel() throws IOException, ExecutionException, I ); } - public void updateModelGraveyardAndAssertNoError( + private void updateModelGraveyardAndAssertNoError( UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction, UpdateModelGraveyardRequest updateModelGraveyardRequest ) throws InterruptedException { @@ -355,7 +355,7 @@ public void updateModelGraveyardAndAssertNoError( assertTrue(countDownLatch.await(60, TimeUnit.SECONDS)); } - public void updateModelGraveyardAndAssertDeleteModelException( + private void updateModelGraveyardAndAssertDeleteModelException( UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction, UpdateModelGraveyardRequest updateModelGraveyardRequest, String indicesPresentInException