From 53a07bdad8bd2df703b62769983bfe83a3c43b90 Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Wed, 27 Sep 2023 13:07:32 -0500 Subject: [PATCH 1/9] feature: Add optional CodeArtifact login to FrameworkProcessing job script --- src/sagemaker/processing.py | 99 ++++++++++++++++++++++++++++++++++--- 1 file changed, 91 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 7b16e3cba3..2152bb316e 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -24,6 +24,7 @@ from textwrap import dedent from typing import Dict, List, Optional, Union from copy import copy +import re import attr @@ -1658,6 +1659,7 @@ def run( # type: ignore[override] job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, kms_key: Optional[str] = None, + codeartifact_repo_arn: Optional[str] = None, ): """Runs a processing job. @@ -1758,12 +1760,21 @@ def run( # type: ignore[override] However, the value of `TrialComponentDisplayName` is honored for display in Studio. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). + codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be + logged into before installing dependencies (default: None). Returns: None or pipeline step arguments in case the Processor instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ s3_runproc_sh, inputs, job_name = self._pack_and_upload_code( - code, source_dir, dependencies, git_config, job_name, inputs, kms_key + code, + source_dir, + dependencies, + git_config, + job_name, + inputs, + kms_key, + codeartifact_repo_arn, ) # Submit a processing job. 
@@ -1780,7 +1791,15 @@ def run( # type: ignore[override] ) def _pack_and_upload_code( - self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None + self, + code, + source_dir, + dependencies, + git_config, + job_name, + inputs, + kms_key=None, + codeartifact_repo_arn=None, ): """Pack local code bundle and upload to Amazon S3.""" if code.startswith("s3://"): @@ -1821,12 +1840,65 @@ def _pack_and_upload_code( script = estimator.uploaded_code.script_name evaluated_kms_key = kms_key if kms_key else self.output_kms_key s3_runproc_sh = self._create_and_upload_runproc( - script, evaluated_kms_key, entrypoint_s3_uri + script, evaluated_kms_key, entrypoint_s3_uri, codeartifact_repo_arn ) return s3_runproc_sh, inputs, job_name - def _generate_framework_script(self, user_script: str) -> str: + def _get_codeartifact_index(self, codeartifact_repo_arn: str): + """ + Build the authenticated codeartifact index url based on the arn provided + via codeartifact_repo_arn property following the form + # `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}` + https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html + https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies + :return: authenticated codeartifact index url + """ + + arn_regex = ( + "arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)" + ":repository/(?P<domain>[^/]+)/(?P<repository>.+)" + ) + m = re.match(arn_regex, codeartifact_repo_arn) + if not m: + raise Exception("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn)) + domain = m.group("domain") + owner = m.group("account") + repository = m.group("repository") + region = m.group("region") + + logger.info( + "configuring pip to use codeartifact " + "(domain: %s, domain owner: %s, repository: %s, region: %s)", + domain, + owner, + repository, + region, + ) + try: + client = 
self.sagemaker_session.boto_session.client("codeartifact", region_name=region) + auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner) + token = auth_token_response["authorizationToken"] + endpoint_response = client.get_repository_endpoint( + domain=domain, domainOwner=owner, repository=repository, format="pypi" + ) + unauthenticated_index = endpoint_response["repositoryEndpoint"] + return re.sub( + "https://", + "https://aws:{}@".format(token), + re.sub( + "{}/?$".format(repository), + "{}/simple/".format(repository), + unauthenticated_index, + ), + ) + except Exception: + logger.error("failed to configure pip to use codeartifact") + raise Exception("failed to configure pip to use codeartifact") + + def _generate_framework_script( + self, user_script: str, codeartifact_repo_arn: str = None + ) -> str: """Generate the framework entrypoint file (as text) for a processing job. This script implements the "framework" functionality for setting up your code: @@ -1837,7 +1909,15 @@ def _generate_framework_script(self, user_script: str) -> str: Args: user_script (str): Relative path to ```code``` in the source bundle - e.g. 'process.py'. + codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be + logged into before installing dependencies (default: None). 
""" + if codeartifact_repo_arn: + index = self._get_codeartifact_index(codeartifact_repo_arn) + index_option = "-i {}".format(index) + else: + index_option = "" + return dedent( """\ #!/bin/bash @@ -1852,12 +1932,13 @@ def _generate_framework_script(self, user_script: str) -> str: # Some py3 containers has typing, which may breaks pip install pip uninstall --yes typing - pip install -r requirements.txt + pip install -r requirements.txt {index_option} fi {entry_point_command} {entry_point} "$@" """ ).format( + index_option=index_option, entry_point_command=" ".join(self.command), entry_point=user_script, ) @@ -1933,7 +2014,9 @@ def _set_entrypoint(self, command, user_script_name): ) self.entrypoint = self.framework_entrypoint_command + [user_script_location] - def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): + def _create_and_upload_runproc( + self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None + ): """Create runproc shell script and upload to S3 bucket. 
If leveraging a pipeline session with optimized S3 artifact paths, @@ -1949,7 +2032,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): from sagemaker.workflow.utilities import _pipeline_config, hash_object if _pipeline_config and _pipeline_config.pipeline_name: - runproc_file_str = self._generate_framework_script(user_script) + runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn) runproc_file_hash = hash_object(runproc_file_str) s3_uri = s3.s3_path_join( "s3://", @@ -1968,7 +2051,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): ) else: s3_runproc_sh = S3Uploader.upload_string_as_file_body( - self._generate_framework_script(user_script), + self._generate_framework_script(user_script, codeartifact_repo_arn), desired_s3_uri=entrypoint_s3_uri, kms_key=kms_key, sagemaker_session=self.sagemaker_session, From 4969bfb7db84f7d2160a5014156e8710f642f738 Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Wed, 11 Oct 2023 17:12:39 -0500 Subject: [PATCH 2/9] Add unit test for _get_codeartifact_index --- src/sagemaker/processing.py | 26 +++++--- tests/unit/test_processing.py | 118 ++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 2152bb316e..321abb1c13 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -22,7 +22,7 @@ import pathlib import logging from textwrap import dedent -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from copy import copy import re @@ -1845,14 +1845,18 @@ def _pack_and_upload_code( return s3_runproc_sh, inputs, job_name - def _get_codeartifact_index(self, codeartifact_repo_arn: str): + def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_client: Any = None): """ Build the authenticated codeartifact index url based on the arn provided via 
codeartifact_repo_arn property following the form # `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}` https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies - :return: authenticated codeartifact index url + Args: + codeartifact_repo_arn: arn of the codeartifact repository + codeartifact_client: boto3 client for codeartifact (used for testing) + Returns: + authenticated codeartifact index url """ arn_regex = ( @@ -1861,7 +1865,7 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str): ) m = re.match(arn_regex, codeartifact_repo_arn) if not m: - raise Exception("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn)) + raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn)) domain = m.group("domain") owner = m.group("account") repository = m.group("repository") @@ -1876,10 +1880,12 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str): region, ) try: - client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region) - auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner) + if not codeartifact_client: + codeartifact_client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region) + + auth_token_response = codeartifact_client.get_authorization_token(domain=domain, domainOwner=owner) token = auth_token_response["authorizationToken"] - endpoint_response = client.get_repository_endpoint( + endpoint_response = codeartifact_client.get_repository_endpoint( domain=domain, domainOwner=owner, repository=repository, format="pypi" ) unauthenticated_index = endpoint_response["repositoryEndpoint"] @@ -1892,9 +1898,9 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str): unauthenticated_index, ), ) - except Exception: - 
logger.error("failed to configure pip to use codeartifact") - raise Exception("failed to configure pip to use codeartifact") + except Exception as e: + logger.error("failed to configure pip to use codeartifact: %s", e, exc_info=True) + raise RuntimeError("failed to configure pip to use codeartifact") def _generate_framework_script( self, user_script: str, codeartifact_repo_arn: str = None diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 93e3d91f87..fb55e2fe4c 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -13,7 +13,10 @@ from __future__ import absolute_import import copy +import datetime +import boto3 +from botocore.stub import Stubber import pytest from mock import Mock, patch, MagicMock from packaging import version @@ -1102,6 +1105,121 @@ def test_pyspark_processor_configuration_path_pipeline_config( ) +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_get_codeartifact_index(pipeline_session): + codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" + + client = boto3.client('codeartifact', region_name=REGION) + stubber = Stubber(client) + + get_auth_token_response = { + "authorizationToken": "mocked_token", + "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) + } + auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"} + stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params) + + get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} + repo_endpoint_expected_params = { + "domain": "test-domain", + "domainOwner": "012345678901", + "repository": "test-repository", + "format": "pypi" + } + stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, 
repo_endpoint_expected_params) + + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + with stubber: + codeartifact_index = processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) + + assert codeartifact_index == f"https://aws:mocked_token@{codeartifact_url}" + + +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_get_codeartifact_index_bad_repo_arn(pipeline_session): + codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain" + codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" + + client = boto3.client('codeartifact', region_name=REGION) + stubber = Stubber(client) + + get_auth_token_response = { + "authorizationToken": "mocked_token", + "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) + } + auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"} + stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params) + + get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} + repo_endpoint_expected_params = { + "domain": "test-domain", + "domainOwner": "012345678901", + "repository": "test-repository", + "format": "pypi" + } + stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params) + + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + with stubber: + with pytest.raises(ValueError): + processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) + + +@patch("sagemaker.workflow.utilities._pipeline_config", 
MOCKED_PIPELINE_CONFIG) +def test_get_codeartifact_index_client_error(pipeline_session): + codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" + + client = boto3.client('codeartifact', region_name=REGION) + stubber = Stubber(client) + + get_auth_token_response = { + "authorizationToken": "mocked_token", + "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) + } + auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"} + stubber.add_client_error("get_authorization_token", service_error_code="404", expected_params=auth_token_expected_params) + + get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} + repo_endpoint_expected_params = { + "domain": "test-domain", + "domainOwner": "012345678901", + "repository": "test-repository", + "format": "pypi" + } + stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params) + + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + with stubber: + with pytest.raises(RuntimeError): + processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) + + def _get_script_processor(sagemaker_session): return ScriptProcessor( role=ROLE, From fa37d4ce8b63eb0dfe064ccc5e3935d7e887ca67 Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Wed, 11 Oct 2023 17:25:51 -0500 Subject: [PATCH 3/9] Fixed docstring --- src/sagemaker/processing.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 321abb1c13..b4a063ba2a 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1846,12 +1846,13 @@ def 
_pack_and_upload_code( return s3_runproc_sh, inputs, job_name def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_client: Any = None): - """ - Build the authenticated codeartifact index url based on the arn provided - via codeartifact_repo_arn property following the form + """Build an authenticated codeartifact index url based on the arn provided. + + The codeartifact_repo_arn property must follow the form # `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}` https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies + Args: codeartifact_repo_arn: arn of the codeartifact repository codeartifact_client: boto3 client for codeartifact (used for testing) From fee0a830e942a636d20d8141beb8c8c411dbd9ed Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Fri, 15 Mar 2024 14:46:21 -0500 Subject: [PATCH 4/9] Convert CodeArtifact integration to simply generate an AWS CLI command to log into CodeArtifact --- src/sagemaker/processing.py | 62 ++++++-------- tests/unit/test_processing.py | 156 +++++++++++++++++----------------- 2 files changed, 106 insertions(+), 112 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index b4a063ba2a..740b319a16 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -22,7 +22,7 @@ import pathlib import logging from textwrap import dedent -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union from copy import copy import re @@ -1845,19 +1845,18 @@ def _pack_and_upload_code( return s3_runproc_sh, inputs, job_name - def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_client: Any = None): - """Build an authenticated codeartifact index url based on the arn provided. 
+ def _get_codeartifact_command(self, codeartifact_repo_arn: str) -> str: + """Build an AWS CLI CodeArtifact command to configure pip. The codeartifact_repo_arn property must follow the form # `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}` https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies - + Args: codeartifact_repo_arn: arn of the codeartifact repository - codeartifact_client: boto3 client for codeartifact (used for testing) Returns: - authenticated codeartifact index url + codeartifact command string """ arn_regex = ( @@ -1866,7 +1865,9 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_clien ) m = re.match(arn_regex, codeartifact_repo_arn) if not m: - raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn)) + raise ValueError( + "invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn) + ) domain = m.group("domain") owner = m.group("account") repository = m.group("repository") @@ -1880,28 +1881,8 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_clien repository, region, ) - try: - if not codeartifact_client: - codeartifact_client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region) - - auth_token_response = codeartifact_client.get_authorization_token(domain=domain, domainOwner=owner) - token = auth_token_response["authorizationToken"] - endpoint_response = codeartifact_client.get_repository_endpoint( - domain=domain, domainOwner=owner, repository=repository, format="pypi" - ) - unauthenticated_index = endpoint_response["repositoryEndpoint"] - return re.sub( - "https://", - "https://aws:{}@".format(token), - re.sub( - "{}/?$".format(repository), - "{}/simple/".format(repository), - unauthenticated_index, - ), - ) - except Exception as 
e: - logger.error("failed to configure pip to use codeartifact: %s", e, exc_info=True) - raise RuntimeError("failed to configure pip to use codeartifact") + + return f"aws codeartifact login --tool pip --domain {domain} --domain-owner {owner} --repository {repository} --region {region}" # pylint: disable=line-too-long def _generate_framework_script( self, user_script: str, codeartifact_repo_arn: str = None @@ -1920,10 +1901,12 @@ def _generate_framework_script( logged into before installing dependencies (default: None). """ if codeartifact_repo_arn: - index = self._get_codeartifact_index(codeartifact_repo_arn) - index_option = "-i {}".format(index) + codeartifact_login_command = self._get_codeartifact_command( + codeartifact_repo_arn + ) else: - index_option = "" + codeartifact_login_command = \ + "echo 'CodeArtifact repository not specified. Skipping login.'" return dedent( """\ @@ -1936,16 +1919,23 @@ def _generate_framework_script( set -e if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." 
+ else + {codeartifact_login_command} + fi + # Some py3 containers has typing, which may breaks pip install pip uninstall --yes typing - pip install -r requirements.txt {index_option} + pip install -r requirements.txt fi {entry_point_command} {entry_point} "$@" """ ).format( - index_option=index_option, + codeartifact_login_command=codeartifact_login_command, entry_point_command=" ".join(self.command), entry_point=user_script, ) @@ -2039,7 +2029,9 @@ def _create_and_upload_runproc( from sagemaker.workflow.utilities import _pipeline_config, hash_object if _pipeline_config and _pipeline_config.pipeline_name: - runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn) + runproc_file_str = self._generate_framework_script( + user_script, codeartifact_repo_arn + ) runproc_file_hash = hash_object(runproc_file_str) s3_uri = s3.s3_path_join( "s3://", diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index fb55e2fe4c..95f7b1dca6 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -13,13 +13,11 @@ from __future__ import absolute_import import copy -import datetime -import boto3 -from botocore.stub import Stubber import pytest from mock import Mock, patch, MagicMock from packaging import version +from textwrap import dedent from sagemaker import LocalSession from sagemaker.dataset_definition.inputs import ( @@ -1106,28 +1104,8 @@ def test_pyspark_processor_configuration_path_pipeline_config( @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -def test_get_codeartifact_index(pipeline_session): +def test_get_codeartifact_command(pipeline_session): codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" - codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" - - client = boto3.client('codeartifact', region_name=REGION) - stubber = Stubber(client) - - 
get_auth_token_response = { - "authorizationToken": "mocked_token", - "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) - } - auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"} - stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params) - - get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} - repo_endpoint_expected_params = { - "domain": "test-domain", - "domainOwner": "012345678901", - "repository": "test-repository", - "format": "pypi" - } - stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params) processor = PyTorchProcessor( role=ROLE, @@ -1138,35 +1116,14 @@ def test_get_codeartifact_index(pipeline_session): sagemaker_session=pipeline_session, ) - with stubber: - codeartifact_index = processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) + codeartifact_command = processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn) - assert codeartifact_index == f"https://aws:mocked_token@{codeartifact_url}" + assert codeartifact_command == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -def test_get_codeartifact_index_bad_repo_arn(pipeline_session): +def test_get_codeartifact_command_bad_repo_arn(pipeline_session): codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain" - codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" - - client = boto3.client('codeartifact', region_name=REGION) - stubber = Stubber(client) - - get_auth_token_response = { - "authorizationToken": "mocked_token", - "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) - } - auth_token_expected_params = 
{"domain": "test-domain", "domainOwner": "012345678901"} - stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params) - - get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} - repo_endpoint_expected_params = { - "domain": "test-domain", - "domainOwner": "012345678901", - "repository": "test-repository", - "format": "pypi" - } - stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params) processor = PyTorchProcessor( role=ROLE, @@ -1177,35 +1134,52 @@ def test_get_codeartifact_index_bad_repo_arn(pipeline_session): sagemaker_session=pipeline_session, ) - with stubber: - with pytest.raises(ValueError): - processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) - + with pytest.raises(ValueError): + processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn) @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -def test_get_codeartifact_index_client_error(pipeline_session): - codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" - codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" - - client = boto3.client('codeartifact', region_name=REGION) - stubber = Stubber(client) - - get_auth_token_response = { - "authorizationToken": "mocked_token", - "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) - } - auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"} - stubber.add_client_error("get_authorization_token", service_error_code="404", expected_params=auth_token_expected_params) - - get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} - repo_endpoint_expected_params = { - "domain": "test-domain", - "domainOwner": "012345678901", - "repository": "test-repository", - "format": "pypi" - 
} - stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params) +def test_generate_framework_script(pipeline_session): + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + framework_script = processor._generate_framework_script(user_script="process.py") + assert framework_script == dedent( + """\ + #!/bin/bash + + cd /opt/ml/processing/input/code/ + tar -xzf sourcedir.tar.gz + + # Exit on any error. SageMaker uses error code to mark failed job. + set -e + + if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + echo 'CodeArtifact repository not specified. Skipping login.' + fi + + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + + python process.py "$@" + """ + ) + +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_generate_framework_script_with_codeartifact(pipeline_session): processor = PyTorchProcessor( role=ROLE, instance_type="ml.m4.xlarge", @@ -1215,10 +1189,38 @@ def test_get_codeartifact_index_client_error(pipeline_session): sagemaker_session=pipeline_session, ) - with stubber: - with pytest.raises(RuntimeError): - processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) + framework_script = processor._generate_framework_script( + user_script="process.py", + codeartifact_repo_arn="arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + ) + + assert framework_script == dedent( + """\ + #!/bin/bash + + cd /opt/ml/processing/input/code/ + tar -xzf sourcedir.tar.gz + + # Exit on any error. SageMaker uses error code to mark failed job. 
+ set -e + if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" + fi + + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + + python process.py "$@" + """ + ) def _get_script_processor(sagemaker_session): return ScriptProcessor( From 5a0971651a4851ba1b463a2809d28b40e118fa6b Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Fri, 15 Mar 2024 15:36:10 -0500 Subject: [PATCH 5/9] Fix lint issues --- src/sagemaker/processing.py | 4 +++- tests/unit/test_processing.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 740b319a16..6c4d532aad 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1882,7 +1882,9 @@ def _get_codeartifact_command(self, codeartifact_repo_arn: str) -> str: region, ) - return f"aws codeartifact login --tool pip --domain {domain} --domain-owner {owner} --repository {repository} --region {region}" # pylint: disable=line-too-long + return "aws codeartifact login --tool pip --domain {} --domain-owner {} --repository {} --region {}".format( # noqa: E501 pylint: disable=line-too-long + domain, owner, repository, region + ) def _generate_framework_script( self, user_script: str, codeartifact_repo_arn: str = None diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 95f7b1dca6..9b5761eda0 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -1117,8 +1117,8 @@ def test_get_codeartifact_command(pipeline_session): ) codeartifact_command = processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn) - - assert codeartifact_command 
== "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" + + assert codeartifact_command == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) @@ -1137,6 +1137,7 @@ def test_get_codeartifact_command_bad_repo_arn(pipeline_session): with pytest.raises(ValueError): processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn) + @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) def test_generate_framework_script(pipeline_session): processor = PyTorchProcessor( @@ -1177,7 +1178,8 @@ def test_generate_framework_script(pipeline_session): python process.py "$@" """ ) - + + @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) def test_generate_framework_script_with_codeartifact(pipeline_session): processor = PyTorchProcessor( @@ -1209,7 +1211,7 @@ def test_generate_framework_script_with_codeartifact(pipeline_session): if ! hash aws 2>/dev/null; then echo "AWS CLI is not installed. Skipping CodeArtifact login." 
else - "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" + aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2 fi # Some py3 containers has typing, which may breaks pip install @@ -1219,9 +1221,10 @@ def test_generate_framework_script_with_codeartifact(pipeline_session): fi python process.py "$@" - """ + """ # noqa: E501 ) + def _get_script_processor(sagemaker_session): return ScriptProcessor( role=ROLE, From 56a4ff96f75990336d34da47e2524379317a154f Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Fri, 15 Mar 2024 16:15:00 -0500 Subject: [PATCH 6/9] More lint fixes --- src/sagemaker/processing.py | 15 +++++---------- tests/unit/test_processing.py | 17 +++++++++++++---- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 6c4d532aad..36cb920dde 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1865,9 +1865,7 @@ def _get_codeartifact_command(self, codeartifact_repo_arn: str) -> str: ) m = re.match(arn_regex, codeartifact_repo_arn) if not m: - raise ValueError( - "invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn) - ) + raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn)) domain = m.group("domain") owner = m.group("account") repository = m.group("repository") @@ -1903,12 +1901,11 @@ def _generate_framework_script( logged into before installing dependencies (default: None). """ if codeartifact_repo_arn: - codeartifact_login_command = self._get_codeartifact_command( - codeartifact_repo_arn - ) + codeartifact_login_command = self._get_codeartifact_command(codeartifact_repo_arn) else: - codeartifact_login_command = \ + codeartifact_login_command = ( "echo 'CodeArtifact repository not specified. 
Skipping login.'" + ) return dedent( """\ @@ -2031,9 +2028,7 @@ def _create_and_upload_runproc( from sagemaker.workflow.utilities import _pipeline_config, hash_object if _pipeline_config and _pipeline_config.pipeline_name: - runproc_file_str = self._generate_framework_script( - user_script, codeartifact_repo_arn - ) + runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn) runproc_file_hash = hash_object(runproc_file_str) s3_uri = s3.s3_path_join( "s3://", diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 9b5761eda0..8f43ecd49f 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -1105,7 +1105,9 @@ def test_pyspark_processor_configuration_path_pipeline_config( @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) def test_get_codeartifact_command(pipeline_session): - codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + codeartifact_repo_arn = ( + "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + ) processor = PyTorchProcessor( role=ROLE, @@ -1116,9 +1118,14 @@ def test_get_codeartifact_command(pipeline_session): sagemaker_session=pipeline_session, ) - codeartifact_command = processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn) + codeartifact_command = processor._get_codeartifact_command( + codeartifact_repo_arn=codeartifact_repo_arn + ) - assert codeartifact_command == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 + assert ( + codeartifact_command + == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" + ) # noqa: E501 @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) @@ -1193,7 +1200,9 @@ def 
test_generate_framework_script_with_codeartifact(pipeline_session): framework_script = processor._generate_framework_script( user_script="process.py", - codeartifact_repo_arn="arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + codeartifact_repo_arn=( + "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + ), ) assert framework_script == dedent( From e12eab58400da93be87e4dde2eafdeef4d3031fa Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Mon, 18 Mar 2024 12:20:52 -0500 Subject: [PATCH 7/9] Lint fix --- tests/unit/test_processing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 8f43ecd49f..0abbecca85 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -1123,9 +1123,11 @@ def test_get_codeartifact_command(pipeline_session): ) assert ( - codeartifact_command - == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" - ) # noqa: E501 + codeartifact_command == ( + "aws codeartifact login --tool pip --domain test-domain ", + "--domain-owner 012345678901 --repository test-repository --region us-west-2" + ) + ) @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) From 01a82dc24aa982ab83878f86d2ed03df883a29fd Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Tue, 19 Mar 2024 10:31:47 -0500 Subject: [PATCH 8/9] Yet Another Lint Fix --- tests/unit/test_processing.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 0abbecca85..fb97fb4bf4 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -13,11 +13,11 @@ from __future__ import absolute_import import copy +from textwrap import dedent import pytest from mock import Mock, patch, MagicMock from packaging import 
version -from textwrap import dedent from sagemaker import LocalSession from sagemaker.dataset_definition.inputs import ( @@ -1122,12 +1122,8 @@ def test_get_codeartifact_command(pipeline_session): codeartifact_repo_arn=codeartifact_repo_arn ) - assert ( - codeartifact_command == ( - "aws codeartifact login --tool pip --domain test-domain ", - "--domain-owner 012345678901 --repository test-repository --region us-west-2" - ) - ) + assert codeartifact_command == \ + "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) From d68b19c2f5f9e70f1dc0553a9d4965e49cd23397 Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Tue, 19 Mar 2024 10:54:16 -0500 Subject: [PATCH 9/9] Black fix --- tests/unit/test_processing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index fb97fb4bf4..06d2cde02e 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -1122,8 +1122,10 @@ def test_get_codeartifact_command(pipeline_session): codeartifact_repo_arn=codeartifact_repo_arn ) - assert codeartifact_command == \ - "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 + assert ( + codeartifact_command + == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 # pylint: disable=line-too-long + ) @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) @@ -1228,7 +1230,7 @@ def test_generate_framework_script_with_codeartifact(pipeline_session): fi python process.py "$@" - """ # noqa: E501 + """ # noqa: E501 # pylint: disable=line-too-long )