diff --git a/server/src/main/java/org/opensearch/gateway/GatewayAllocator.java b/server/src/main/java/org/opensearch/gateway/GatewayAllocator.java index 0d2549f43eeaf..be3ae441e2b06 100644 --- a/server/src/main/java/org/opensearch/gateway/GatewayAllocator.java +++ b/server/src/main/java/org/opensearch/gateway/GatewayAllocator.java @@ -64,6 +64,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentMap; import java.util.stream.Collectors; @@ -98,9 +99,10 @@ public class GatewayAllocator implements ExistingShardsAllocator { private final ConcurrentMap> asyncFetchStore = ConcurrentCollections.newConcurrentMap(); private Set lastSeenEphemeralIds = Collections.emptySet(); - - private final ConcurrentMap> asyncBatchFetchStarted = ConcurrentCollections.newConcurrentMap(); - private final ConcurrentMap> asyncBatchFetchStore = ConcurrentCollections.newConcurrentMap(); + private final ConcurrentMap startedShardBatchLookup = ConcurrentCollections.newConcurrentMap(); + private final ConcurrentMap batchIdToStartedShardBatch = ConcurrentCollections.newConcurrentMap(); + private final ConcurrentMap storeShardBatchLookup = ConcurrentCollections.newConcurrentMap(); + private final ConcurrentMap batchIdToStoreShardBatch = ConcurrentCollections.newConcurrentMap(); @Inject public GatewayAllocator( @@ -125,6 +127,10 @@ public void cleanCaches() { asyncFetchStarted.clear(); Releasables.close(asyncFetchStore.values()); asyncFetchStore.clear(); + batchIdToStartedShardBatch.clear(); + batchIdToStoreShardBatch.clear(); + startedShardBatchLookup.clear(); + storeShardBatchLookup.clear(); } // for tests @@ -132,8 +138,8 @@ protected GatewayAllocator() { this.rerouteService = null; this.primaryShardAllocator = null; this.replicaShardAllocator = null; - this.batchStartedAction=null; - this.primaryBatchShardAllocator =null; + this.batchStartedAction = null; + this.primaryBatchShardAllocator = null; this.batchStoreAction = null; this.replicaBatchShardAllocator = null; } @@ -155,6 +161,7 @@ public void applyStartedShards(final List startedShards, final Rou for (ShardRouting startedShard : startedShards) { Releasables.close(asyncFetchStarted.remove(startedShard.shardId())); Releasables.close(asyncFetchStore.remove(startedShard.shardId())); + safelyRemoveShardFromBatch(startedShard); } } @@ -163,6 +170,7 @@ public void applyFailedShards(final List failedShards, final Routin for (FailedShard failedShard : failedShards) { Releasables.close(asyncFetchStarted.remove(failedShard.getRoutingEntry().shardId())); Releasables.close(asyncFetchStore.remove(failedShard.getRoutingEntry().shardId())); + safelyRemoveShardFromBatch(failedShard.getRoutingEntry()); } } @@ -194,81 +202,97 @@ public void allocateUnassigned( } @Override - public void allocateUnassignedBatch(final RoutingAllocation allocation, boolean primary){ + public void allocateUnassignedBatch(final RoutingAllocation allocation, boolean primary) { // create batches for unassigned shards createBatches(allocation, primary); + assert primaryBatchShardAllocator != null; + assert replicaBatchShardAllocator != null; if (primary) { - asyncBatchFetchStarted.keySet().forEach(batch -> primaryBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShards(), allocation)); - } - else { - asyncBatchFetchStore.keySet().forEach(batch -> replicaBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShards(), allocation)); + batchIdToStartedShardBatch.values().forEach(shardsBatch -> primaryBatchShardAllocator.allocateUnassignedBatch(shardsBatch.getBatchedShardRoutings(), allocation)); + } else { + batchIdToStoreShardBatch.values().forEach(batch -> replicaBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShardRoutings(), allocation)); } } private void createBatches(RoutingAllocation allocation, boolean primary) { RoutingNodes.UnassignedShards unassigned = allocation.routingNodes().unassigned(); // fetch all current batched shards - Set currentBatchedShards; - if (primary) { - currentBatchedShards = asyncBatchFetchStarted.keySet().stream().flatMap(shardsBatch -> shardsBatch.getBatchedShards().stream()).collect(Collectors.toSet()); - } else { - currentBatchedShards = asyncBatchFetchStore.keySet().stream().flatMap(shardsBatch -> shardsBatch.getBatchedShards().stream()).collect(Collectors.toSet()); - } + Set currentBatchedShards = primary? startedShardBatchLookup.keySet() : storeShardBatchLookup.keySet(); Set shardsToBatch = Sets.newHashSet(); // add all unassigned shards to the batch if they are not already in a batch unassigned.forEach(shardRouting -> { - if ((currentBatchedShards.contains(shardRouting) == false) && (shardRouting.primary() == primary)) { + if ((currentBatchedShards.contains(shardRouting.shardId()) == false) && (shardRouting.primary() == primary)) { assert shardRouting.unassigned(); shardsToBatch.add(shardRouting); } }); Iterator iterator = shardsToBatch.iterator(); long batchSize = MAX_BATCH_SIZE; - Map addToCurrentBatch = new HashMap<>(); + Map addToCurrentBatch = new HashMap<>(); while (iterator.hasNext()) { ShardRouting currentShard = iterator.next(); if (batchSize > 0) { - addToCurrentBatch.put(currentShard, IndexMetadata.INDEX_DATA_PATH_SETTING.get(allocation.metadata().index(currentShard.index()).getSettings())); + ShardBatchEntry shardBatchEntry = new ShardBatchEntry(IndexMetadata.INDEX_DATA_PATH_SETTING.get(allocation.metadata().index(currentShard.index()).getSettings()) + , currentShard); + addToCurrentBatch.put(currentShard.shardId(), shardBatchEntry); batchSize--; iterator.remove(); } // add to batch if batch size full or last shard in unassigned list if (batchSize == 0 || iterator.hasNext() == false) { String batchUUId = UUIDs.base64UUID(); - ShardsBatch shardsBatch = new ShardsBatch(batchUUId, addToCurrentBatch); - Map shardIdsMap = addToCurrentBatch.entrySet().stream().collect(Collectors.toMap( - entry -> entry.getKey().shardId(), - Map.Entry::getValue - )); - if(primary) { - asyncBatchFetchStarted.computeIfAbsent( - shardsBatch, - batch -> new InternalBatchAsyncFetch<>( - logger, - "batch_shards_started", - shardIdsMap, - this.batchStartedAction, - batch.getBatchId() - )); - } - else { - asyncBatchFetchStore.computeIfAbsent( - shardsBatch, - batch -> new InternalBatchAsyncFetch<>( - logger, - "batch_shards_store", - shardIdsMap, - this.batchStoreAction, - batch.getBatchId() - )); - } + ShardsBatch shardsBatch = new ShardsBatch(batchUUId, addToCurrentBatch, primary); + // add the batch to list of current batches + addBatch(shardsBatch, primary); + addShardsIdsToLookup(addToCurrentBatch.keySet(), batchUUId, primary); addToCurrentBatch.clear(); batchSize = MAX_BATCH_SIZE; } } } + private void addBatch(ShardsBatch shardsBatch, boolean primary) { + ConcurrentMap batches = primary ? batchIdToStartedShardBatch : batchIdToStoreShardBatch; + if (batches.containsKey(shardsBatch.getBatchId())) { + throw new IllegalStateException("Batch already exists. BatchId = " + shardsBatch.getBatchId()); + } + batches.put(shardsBatch.getBatchId(), shardsBatch); + } + + private void addShardsIdsToLookup(Set shards, String batchId, boolean primary) { + ConcurrentMap lookupMap = primary ? startedShardBatchLookup : storeShardBatchLookup; + shards.forEach(shardId -> { + if(lookupMap.containsKey(shardId)){ + throw new IllegalStateException("Shard is already Batched. ShardId = " + shardId + "Batch Id="+ lookupMap.get(shardId)); + } + lookupMap.put(shardId, batchId); + }); + } + + /** + * Safely remove a shard from the appropriate batch. + * If the shard is not in a batch, this is a no-op. + * Cleans the batch if it is empty after removing the shard. + * This method should be called when removing the shard from the batch instead {@link ShardsBatch#removeFromBatch(ShardRouting)} + * so that we can clean up the batch if it is empty and release the fetching resources + * @param shardRouting + */ + private void safelyRemoveShardFromBatch(ShardRouting shardRouting) { + String batchId = shardRouting.primary() ? startedShardBatchLookup.get(shardRouting.shardId()) : storeShardBatchLookup.get(shardRouting.shardId()); + if (batchId == null) { + return; + } + ConcurrentMap batches = shardRouting.primary() ? batchIdToStartedShardBatch : batchIdToStoreShardBatch; + ShardsBatch batch = batches.get(batchId); + batch.removeFromBatch(shardRouting); + // remove the batch if it is empty + if (batch.getBatchedShards().isEmpty()) { + Releasables.close(batch.getAsyncFetcher()); + batches.remove(batchId); + } + } + // allow for testing infra to change shard allocators implementation protected static void innerAllocatedUnassigned( RoutingAllocation allocation, @@ -502,31 +526,120 @@ protected boolean hasInitiatedFetching(ShardRouting shard) { } } - + /** + * Holds information about a batch of shards to be allocated. + * Async fetcher is used to fetch the data for the batch. + */ private class ShardsBatch { - private final String uuid; + private final String batchId; + boolean primary; + + private final AsyncBatchShardFetch asyncBatch; + + private final Map batchInfo; + + public ShardsBatch(String batchId, Map shardsWithInfo, boolean primary) { + this.batchId = batchId; + this.batchInfo = new HashMap<>(shardsWithInfo); + // create a ShardId -> customDataPath map for async fetch + Map shardIdsMap = batchInfo.entrySet().stream().collect(Collectors.toMap( + Map.Entry::getKey, + entry -> entry.getValue().getCustomDataPath() + )); + this.primary = primary; + if (primary) { + asyncBatch = new InternalBatchAsyncFetch<>( + logger, + "batch_shards_started", + shardIdsMap, + batchStartedAction, + batchId); + } else { + asyncBatch = new InternalBatchAsyncFetch<>( + logger, + "batch_shards_started", + shardIdsMap, + batchStoreAction, + batchId); - public Map getShardsToCustomDataPathMap() { - return shardsToCustomDataPathMap; + } } - private Map shardsToCustomDataPathMap; - private ShardsBatch(String uuid, Map shardsToCustomDataPathMap) { - this.uuid = uuid; - this.shardsToCustomDataPathMap = shardsToCustomDataPathMap; + private void removeFromBatch(ShardRouting shard) { + + batchInfo.remove(shard.shardId()); + asyncBatch.shardsToCustomDataPathMap.remove(shard.shardId()); + assert shard.primary() == primary : "Illegal call to delete shard from batch"; + // remove from lookup + if (this.primary) { + startedShardBatchLookup.remove(shard.shardId()); + } else { + storeShardBatchLookup.remove(shard.shardId()); + } + // assert that fetcher and shards are the same as batched shards + assert batchInfo.size() == asyncBatch.shardsToCustomDataPathMap.size() : "Shards size is not equal to fetcher size"; } - void removeFromBatch(ShardRouting shard) { - shardsToCustomDataPathMap.remove(shard); + + Set getBatchedShardRoutings() { + return batchInfo.values().stream().map(ShardBatchEntry::getShardRouting).collect(Collectors.toSet()); } - Set getBatchedShards() { - return shardsToCustomDataPathMap.keySet(); + Set getBatchedShards() { + return batchInfo.keySet(); } public String getBatchId() { - return uuid; + return batchId; + } + + AsyncBatchShardFetch getAsyncFetcher() { + return asyncBatch; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || o instanceof ShardsBatch == false) { + return false; + } + ShardsBatch shardsBatch = (ShardsBatch) o; + return batchId.equals(shardsBatch.getBatchId()) && batchInfo.keySet().equals(shardsBatch.getBatchedShards()); + } + + @Override + public int hashCode() { + return Objects.hash(batchId); + } + + @Override + public String toString() { + return "batchId: " + batchId; } + } + /** + * Holds information about a shard to be allocated in a batch. + */ + private class ShardBatchEntry { + + private final String customDataPath; + private final ShardRouting shardRouting; + + public ShardBatchEntry(String customDataPath, ShardRouting shardRouting) { + this.customDataPath = customDataPath; + this.shardRouting = shardRouting; + } + + public ShardRouting getShardRouting() { + return shardRouting; + } + + public String getCustomDataPath() { + return customDataPath; + } + } }