Skip to content

Commit

Permalink
Implement retryOnConflict on UpdateDataObject
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Sep 12, 2024
1 parent d502ae7 commit 1f15e60
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class UpdateDataObjectRequest {
private final String tenantId;
private final Long ifSeqNo;
private final Long ifPrimaryTerm;
private final int retryOnConflict;
private final ToXContentObject dataObject;

/**
Expand All @@ -33,14 +34,24 @@ public class UpdateDataObjectRequest {
* @param tenantId the tenant id
* @param ifSeqNo the sequence number to match or null if not required
* @param ifPrimaryTerm the primary term to match or null if not required
* @param retryOnConflict number of times to retry an update if a version conflict exists
* @param dataObject the data object
*/
public UpdateDataObjectRequest(String index, String id, String tenantId, Long ifSeqNo, Long ifPrimaryTerm, ToXContentObject dataObject) {
public UpdateDataObjectRequest(
String index,
String id,
String tenantId,
Long ifSeqNo,
Long ifPrimaryTerm,
int retryOnConflict,
ToXContentObject dataObject
) {
this.index = index;
this.id = id;
this.tenantId = tenantId;
this.ifSeqNo = ifSeqNo;
this.ifPrimaryTerm = ifPrimaryTerm;
this.retryOnConflict = retryOnConflict;
this.dataObject = dataObject;
}

Expand Down Expand Up @@ -84,6 +95,14 @@ public Long ifPrimaryTerm() {
return ifPrimaryTerm;
}

/**
* Returns the number of retries on version conflict
* @return the number of retries
*/
public int retryOnConflict() {
return retryOnConflict;
}

/**
* Returns the data object
* @return the data object
Expand All @@ -109,6 +128,7 @@ public static class Builder {
private String tenantId = null;
private Long ifSeqNo = null;
private Long ifPrimaryTerm = null;
private int retryOnConflict = 0;
private ToXContentObject dataObject = null;

/**
Expand Down Expand Up @@ -174,6 +194,16 @@ public Builder ifPrimaryTerm(long term) {
return this;
}

/**
* Retry the update request on a version conflict exception.
* @param retries Number of times to retry, if positive, otherwise will not retry
* @return the updated builder
*/
public Builder retryOnConflict(int retries) {
this.retryOnConflict = retries;
return this;
}

/**
* Add a data object to this builder
* @param dataObject the data object
Expand All @@ -194,7 +224,8 @@ public Builder dataObject(Map<String, Object> dataObjectMap) {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.map(dataObjectMap);
}};
}
};
return this;
}

Expand All @@ -206,7 +237,15 @@ public UpdateDataObjectRequest build() {
if ((ifSeqNo == null) != (ifPrimaryTerm == null)) {
throw new IllegalArgumentException("Either ifSeqNo and ifPrimaryTerm must both be null or both must be non-null.");
}
return new UpdateDataObjectRequest(this.index, this.id, this.tenantId, this.ifSeqNo, this.ifPrimaryTerm, this.dataObject);
return new UpdateDataObjectRequest(
this.index,
this.id,
this.tenantId,
this.ifSeqNo,
this.ifPrimaryTerm,
this.retryOnConflict,
this.dataObject
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ public CompletionStage<UpdateDataObjectResponse> updateDataObjectAsync(
if (request.ifPrimaryTerm() != null) {
updateRequest.setIfPrimaryTerm(request.ifPrimaryTerm());
}
if (request.retryOnConflict() > 0) {
updateRequest.retryOnConflict(request.retryOnConflict());
}
UpdateResponse updateResponse = client.update(updateRequest).actionGet();
if (updateResponse == null) {
log.info("Null UpdateResponse");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public void testUpdateDataObjectRequest() {
assertEquals(testDataObject, request.dataObject());
assertNull(request.ifSeqNo());
assertNull(request.ifPrimaryTerm());
assertEquals(0, request.retryOnConflict());
}

@Test
Expand Down Expand Up @@ -92,6 +93,7 @@ public void testUpdateDataObjectRequestConcurrency() {
.dataObject(testDataObject)
.ifSeqNo(testSeqNo)
.ifPrimaryTerm(testPrimaryTerm)
.retryOnConflict(3)
.build();

assertEquals(testIndex, request.index());
Expand All @@ -100,6 +102,7 @@ public void testUpdateDataObjectRequestConcurrency() {
assertEquals(testDataObject, request.dataObject());
assertEquals(testSeqNo, request.ifSeqNo());
assertEquals(testPrimaryTerm, request.ifPrimaryTerm());
assertEquals(3, request.retryOnConflict());

final Builder badSeqNoBuilder = UpdateDataObjectRequest.builder();
assertThrows(IllegalArgumentException.class, () -> badSeqNoBuilder.ifSeqNo(-99));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ public void testUpdateDataObject() throws IOException {
.builder()
.index(TEST_INDEX)
.id(TEST_ID)
.retryOnConflict(3)
.dataObject(testDataObject)
.build();

Expand All @@ -327,6 +328,7 @@ public void testUpdateDataObject() throws IOException {
ArgumentCaptor<UpdateRequest> requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class);
verify(mockedClient, times(1)).update(requestCaptor.capture());
assertEquals(TEST_INDEX, requestCaptor.getValue().index());
assertEquals(3, requestCaptor.getValue().retryOnConflict());
assertEquals(TEST_ID, response.id());

UpdateResponse updateActionResponse = UpdateResponse.fromXContent(response.parser());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,7 @@ public CompletionStage<UpdateDataObjectResponse> updateDataObjectAsync(
.expressionAttributeNames(expressionAttributeNames)
.expressionAttributeValues(expressionAttributeValues);
UpdateItemRequest updateItemRequest = updateItemRequestBuilder.build();
UpdateItemResponse updateItemResponse = dynamoDbClient.updateItem(updateItemRequest);
Long sequenceNumber = null;
if (updateItemResponse != null
&& updateItemResponse.attributes() != null
&& updateItemResponse.attributes().containsKey(SEQ_NO_KEY)) {
sequenceNumber = Long.parseLong(updateItemResponse.attributes().get(SEQ_NO_KEY).n());
}
Long sequenceNumber = updateItemWithRetryOnConflict(updateItemRequest, request);
String simulatedUpdateResponse = simulateOpenSearchResponse(
request.index(),
request.id(),
Expand All @@ -270,13 +264,6 @@ public CompletionStage<UpdateDataObjectResponse> updateDataObjectAsync(
Map.of("result", "updated")
);
return UpdateDataObjectResponse.builder().id(request.id()).parser(createParser(simulatedUpdateResponse)).build();
} catch (ConditionalCheckFailedException ccfe) {
log.error("Document version conflict updating {} in {}: {}", request.id(), request.index(), ccfe.getMessage(), ccfe);
// Rethrow
throw new OpenSearchStatusException(
"Document version conflict updating " + request.id() + " in index " + request.index(),
RestStatus.CONFLICT
);
} catch (IOException e) {
log.error("Error updating {} in {}: {}", request.id(), request.index(), e.getMessage(), e);
// Rethrow unchecked exception on update IOException
Expand All @@ -288,6 +275,28 @@ public CompletionStage<UpdateDataObjectResponse> updateDataObjectAsync(
}), executor);
}

private Long updateItemWithRetryOnConflict(UpdateItemRequest updateItemRequest, UpdateDataObjectRequest request) {
int retriesRemaining = request.retryOnConflict();
do {
try {
UpdateItemResponse updateItemResponse = dynamoDbClient.updateItem(updateItemRequest);
if (updateItemResponse != null
&& updateItemResponse.attributes() != null
&& updateItemResponse.attributes().containsKey(SEQ_NO_KEY)) {
return Long.parseLong(updateItemResponse.attributes().get(SEQ_NO_KEY).n());
}
} catch (ConditionalCheckFailedException ccfe) {
if (retriesRemaining < 1) {
// Throw exception if retries exhausted
String message = "Document version conflict updating " + request.id() + " in index " + request.index();
log.error(message + ": {}", ccfe.getMessage(), ccfe);
throw new OpenSearchStatusException(message, RestStatus.CONFLICT);
}
}
} while (retriesRemaining-- > 0);
return null; // Should never get here
}

/**
* Deletes data document from DDB. Default tenant ID will be used if tenant ID is not specified.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ public CompletionStage<UpdateDataObjectResponse> updateDataObjectAsync(
if (request.ifPrimaryTerm() != null) {
updateRequestBuilder.ifPrimaryTerm(request.ifPrimaryTerm());
}
if (request.retryOnConflict() > 0) {
updateRequestBuilder.retryOnConflict(request.retryOnConflict());
}
UpdateRequest<Map<String, Object>, ?> updateRequest = updateRequestBuilder.build();
log.info("Updating {} in {}", request.id(), request.index());
UpdateResponse<Map<String, Object>> updateResponse = openSearchClient.update(updateRequest, MAP_DOCTYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ public void updateDataObjectAsync_HappyCase() {
.id(TEST_ID)
.index(TEST_INDEX)
.tenantId(TENANT_ID)
.retryOnConflict(1)
.dataObject(testDataObject)
.build();
Mockito.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())).thenReturn(UpdateItemResponse.builder().build());
Expand Down Expand Up @@ -525,6 +526,63 @@ public void testUpdateDataObject_VersionCheck() throws IOException {
assertEquals(RestStatus.CONFLICT, ((OpenSearchStatusException) cause).status());
}

@Test
public void updateDataObjectAsync_VersionCheckRetrySuccess() {
UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
.builder()
.id(TEST_ID)
.index(TEST_INDEX)
.tenantId(TENANT_ID)
.retryOnConflict(1)
.dataObject(testDataObject)
.build();
ConditionalCheckFailedException conflictException = ConditionalCheckFailedException.builder().build();
// throw conflict exception on first time, return on second time
Mockito
.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture()))
.thenThrow(conflictException)
.thenReturn(UpdateItemResponse.builder().build());
UpdateDataObjectResponse updateResponse = sdkClient
.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture()
.join();
assertEquals(TEST_ID, updateResponse.id());
UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue();
assertEquals(TEST_ID, updateRequest.id());
assertEquals(TEST_INDEX, updateItemRequest.tableName());
assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s());
assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s());
assertEquals("foo", updateItemRequest.expressionAttributeValues().get(":source").m().get("data").s());
}

@Test
public void updateDataObjectAsync_VersionCheckRetryFailure() {
UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
.builder()
.id(TEST_ID)
.index(TEST_INDEX)
.tenantId(TENANT_ID)
.retryOnConflict(1)
.dataObject(testDataObject)
.build();
ConditionalCheckFailedException conflictException = ConditionalCheckFailedException.builder().build();
// throw conflict exception on first two times, return on third time (that never executes)
Mockito
.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture()))
.thenThrow(conflictException)
.thenThrow(conflictException)
.thenReturn(UpdateItemResponse.builder().build());

CompletableFuture<UpdateDataObjectResponse> future = sdkClient
.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture();

CompletionException ce = assertThrows(CompletionException.class, () -> future.join());
Throwable cause = ce.getCause();
assertEquals(OpenSearchStatusException.class, cause.getClass());
assertEquals(RestStatus.CONFLICT, ((OpenSearchStatusException) cause).status());
}

@Test
public void searchDataObjectAsync_HappyCase() {
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.ml.sdkclient;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -320,6 +321,7 @@ public void testUpdateDataObject() throws IOException {
.join();

assertEquals(TEST_INDEX, updateRequestCaptor.getValue().index());
assertEquals(0, updateRequestCaptor.getValue().retryOnConflict().intValue());
assertEquals(TEST_ID, response.id());

org.opensearch.action.update.UpdateResponse updateActionResponse = org.opensearch.action.update.UpdateResponse
Expand All @@ -336,6 +338,7 @@ public void testUpdateDataObjectWithMap() throws IOException {
.builder()
.index(TEST_INDEX)
.id(TEST_ID)
.retryOnConflict(3)
.dataObject(Map.of("foo", "bar"))
.build();

Expand All @@ -356,6 +359,7 @@ public void testUpdateDataObjectWithMap() throws IOException {
sdkClient.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join();

assertEquals(TEST_INDEX, updateRequestCaptor.getValue().index());
assertEquals(3, updateRequestCaptor.getValue().retryOnConflict().intValue());
assertEquals(TEST_ID, updateRequestCaptor.getValue().id());
assertEquals("bar", ((Map<String, Object>) updateRequestCaptor.getValue().doc()).get("foo"));
}
Expand Down

0 comments on commit 1f15e60

Please sign in to comment.