Skip to content

Commit

Permalink
Introduce pagnator in describe_images boto3 calls (#6418)
Browse files Browse the repository at this point in the history
* Introduce pagnator in describe_images boto3 calls.
  • Loading branch information
hehe7318 committed Aug 28, 2024
1 parent c84baa2 commit 8edfdfa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ CHANGELOG
- Fix cluster deletion failure when placement group is enabled.
- Fix issue with login nodes being marked unhealthy when restricting SSH access.
- Fix `retrieve_supported_regions` so that it can get the correct S3 url.
- Fix `describe_images` so that it uses pagination.

3.10.1
------
Expand Down
28 changes: 20 additions & 8 deletions cli/src/pcluster/aws/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,33 @@ def get_subnet_cidr(self, subnet_id):
return subnets[0].get("CidrBlock")
raise AWSClientError(function_name="describe_subnets", message=f"Subnet {subnet_id} not found")

@AWSExceptionHandler.handle_client_exception
@Cache.cached
def _describe_images_with_pagination(self, **kwargs):
"""Use paginator to describe images and handle pagination."""
paginator = self._client.get_paginator("describe_images")
page_iterator = paginator.paginate(**kwargs)
images = []
for page in page_iterator:
images.extend(page["Images"])
return images

@AWSExceptionHandler.handle_client_exception
@Cache.cached
def describe_image(self, ami_id):
"""Describe image by image id, return an object of ImageInfo."""
result = self._client.describe_images(ImageIds=[ami_id])
if result.get("Images"):
return ImageInfo(result.get("Images")[0])
images = self._describe_images_with_pagination(ImageIds=[ami_id])
if images:
return ImageInfo(images[0])
raise AWSClientError(function_name="describe_images", message=f"Image {ami_id} not found")

@AWSExceptionHandler.handle_client_exception
@Cache.cached
def describe_images(self, ami_ids, filters, owners):
"""Return a list of objects of ImageInfo."""
result = self._client.describe_images(ImageIds=ami_ids, Filters=filters, Owners=owners)
if result.get("Images"):
return [ImageInfo(image) for image in result.get("Images")]
images = self._describe_images_with_pagination(ImageIds=ami_ids, Filters=filters, Owners=owners)
if images:
return [ImageInfo(image) for image in images]
raise ImageNotFoundError(function_name="describe_images")

def image_exists(self, image_id: str):
Expand Down Expand Up @@ -301,7 +312,7 @@ def get_official_image_id(self, os, architecture, filters=None):

filters = [{"Name": "name", "Values": ["{0}*".format(self._get_official_image_name_prefix(os, architecture))]}]
filters.extend([{"Name": f"tag:{tag.key}", "Values": [tag.value]} for tag in tags])
images = self._client.describe_images(Owners=[owner], Filters=filters, IncludeDeprecated=True).get("Images")
images = self._describe_images_with_pagination(Owners=[owner], Filters=filters, IncludeDeprecated=True)
if not images:
raise AWSClientError(function_name="describe_images", message="Cannot find official ParallelCluster AMI")
return self._find_valid_official_image(images).get("ImageId")
Expand All @@ -313,10 +324,11 @@ def get_official_images(self, os=None, architecture=None):
owners = ["amazon"]
name = f"{self._get_official_image_name_prefix(os, architecture)}*"
filters = [{"Name": "name", "Values": [name]}]
images = self._describe_images_with_pagination(Owners=owners, Filters=filters, IncludeDeprecated=True)
return [
ImageInfo(self._find_valid_official_image(images_os_arch))
for _, images_os_arch in itertools.groupby(
self._client.describe_images(Owners=owners, Filters=filters, IncludeDeprecated=True).get("Images"),
images,
key=lambda image: f'{self.extract_os_from_official_image_name(image["Name"])}-{image["Architecture"]}',
)
]
Expand Down

0 comments on commit 8edfdfa

Please sign in to comment.