Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Add optional CodeArtifact login to FrameworkProcessing job script #4145

Merged
merged 19 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 86 additions & 7 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from textwrap import dedent
from typing import Dict, List, Optional, Union
from copy import copy
import re

import attr

Expand Down Expand Up @@ -1658,6 +1659,7 @@ def run( # type: ignore[override]
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
kms_key: Optional[str] = None,
codeartifact_repo_arn: Optional[str] = None,
):
"""Runs a processing job.

Expand Down Expand Up @@ -1758,12 +1760,21 @@ def run( # type: ignore[override]
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
logged into before installing dependencies (default: None).
Returns:
None or pipeline step arguments in case the Processor instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
"""
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
code, source_dir, dependencies, git_config, job_name, inputs, kms_key
code,
source_dir,
dependencies,
git_config,
job_name,
inputs,
kms_key,
codeartifact_repo_arn,
)

# Submit a processing job.
Expand All @@ -1780,7 +1791,15 @@ def run( # type: ignore[override]
)

def _pack_and_upload_code(
self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None
self,
code,
source_dir,
dependencies,
git_config,
job_name,
inputs,
kms_key=None,
codeartifact_repo_arn=None,
):
"""Pack local code bundle and upload to Amazon S3."""
if code.startswith("s3://"):
Expand Down Expand Up @@ -1821,12 +1840,53 @@ def _pack_and_upload_code(
script = estimator.uploaded_code.script_name
evaluated_kms_key = kms_key if kms_key else self.output_kms_key
s3_runproc_sh = self._create_and_upload_runproc(
script, evaluated_kms_key, entrypoint_s3_uri
script, evaluated_kms_key, entrypoint_s3_uri, codeartifact_repo_arn
)

return s3_runproc_sh, inputs, job_name

def _generate_framework_script(self, user_script: str) -> str:
def _get_codeartifact_command(self, codeartifact_repo_arn: str) -> str:
"""Build an AWS CLI CodeArtifact command to configure pip.

The codeartifact_repo_arn property must follow the form
# `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}`
https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies

Args:
codeartifact_repo_arn: arn of the codeartifact repository
Returns:
codeartifact command string
"""

arn_regex = (
"arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
)
m = re.match(arn_regex, codeartifact_repo_arn)
if not m:
raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
domain = m.group("domain")
owner = m.group("account")
repository = m.group("repository")
region = m.group("region")

logger.info(
"configuring pip to use codeartifact "
"(domain: %s, domain owner: %s, repository: %s, region: %s)",
domain,
owner,
repository,
region,
)

return "aws codeartifact login --tool pip --domain {} --domain-owner {} --repository {} --region {}".format( # noqa: E501 pylint: disable=line-too-long
domain, owner, repository, region
)

def _generate_framework_script(
self, user_script: str, codeartifact_repo_arn: str = None
) -> str:
"""Generate the framework entrypoint file (as text) for a processing job.

This script implements the "framework" functionality for setting up your code:
Expand All @@ -1837,7 +1897,16 @@ def _generate_framework_script(self, user_script: str) -> str:
Args:
user_script (str): Relative path to ```code``` in the source bundle
- e.g. 'process.py'.
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
logged into before installing dependencies (default: None).
"""
if codeartifact_repo_arn:
codeartifact_login_command = self._get_codeartifact_command(codeartifact_repo_arn)
else:
codeartifact_login_command = (
"echo 'CodeArtifact repository not specified. Skipping login.'"
)

return dedent(
"""\
#!/bin/bash
Expand All @@ -1849,6 +1918,13 @@ def _generate_framework_script(self, user_script: str) -> str:
set -e

if [[ -f 'requirements.txt' ]]; then
# Optionally log into CodeArtifact
if ! hash aws 2>/dev/null; then
echo "AWS CLI is not installed. Skipping CodeArtifact login."
else
{codeartifact_login_command}
fi

# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

Expand All @@ -1858,6 +1934,7 @@ def _generate_framework_script(self, user_script: str) -> str:
{entry_point_command} {entry_point} "$@"
"""
).format(
codeartifact_login_command=codeartifact_login_command,
entry_point_command=" ".join(self.command),
entry_point=user_script,
)
Expand Down Expand Up @@ -1933,7 +2010,9 @@ def _set_entrypoint(self, command, user_script_name):
)
self.entrypoint = self.framework_entrypoint_command + [user_script_location]

def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
def _create_and_upload_runproc(
self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None
):
"""Create runproc shell script and upload to S3 bucket.

If leveraging a pipeline session with optimized S3 artifact paths,
Expand All @@ -1949,7 +2028,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
from sagemaker.workflow.utilities import _pipeline_config, hash_object

if _pipeline_config and _pipeline_config.pipeline_name:
runproc_file_str = self._generate_framework_script(user_script)
runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn)
runproc_file_hash = hash_object(runproc_file_str)
s3_uri = s3.s3_path_join(
"s3://",
Expand All @@ -1968,7 +2047,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
)
else:
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
self._generate_framework_script(user_script),
self._generate_framework_script(user_script, codeartifact_repo_arn),
desired_s3_uri=entrypoint_s3_uri,
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
Expand Down
132 changes: 132 additions & 0 deletions tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import absolute_import

import copy
from textwrap import dedent

import pytest
from mock import Mock, patch, MagicMock
Expand Down Expand Up @@ -1102,6 +1103,137 @@ def test_pyspark_processor_configuration_path_pipeline_config(
)


@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
def test_get_codeartifact_command(pipeline_session):
codeartifact_repo_arn = (
"arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
)

processor = PyTorchProcessor(
role=ROLE,
instance_type="ml.m4.xlarge",
framework_version="2.0.1",
py_version="py310",
instance_count=1,
sagemaker_session=pipeline_session,
)

codeartifact_command = processor._get_codeartifact_command(
codeartifact_repo_arn=codeartifact_repo_arn
)

assert (
codeartifact_command
== "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 # pylint: disable=line-too-long
)


@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
def test_get_codeartifact_command_bad_repo_arn(pipeline_session):
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain"

processor = PyTorchProcessor(
role=ROLE,
instance_type="ml.m4.xlarge",
framework_version="2.0.1",
py_version="py310",
instance_count=1,
sagemaker_session=pipeline_session,
)

with pytest.raises(ValueError):
processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn)


@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
def test_generate_framework_script(pipeline_session):
processor = PyTorchProcessor(
role=ROLE,
instance_type="ml.m4.xlarge",
framework_version="2.0.1",
py_version="py310",
instance_count=1,
sagemaker_session=pipeline_session,
)

framework_script = processor._generate_framework_script(user_script="process.py")

assert framework_script == dedent(
"""\
#!/bin/bash

cd /opt/ml/processing/input/code/
tar -xzf sourcedir.tar.gz

# Exit on any error. SageMaker uses error code to mark failed job.
set -e

if [[ -f 'requirements.txt' ]]; then
# Optionally log into CodeArtifact
if ! hash aws 2>/dev/null; then
echo "AWS CLI is not installed. Skipping CodeArtifact login."
else
echo 'CodeArtifact repository not specified. Skipping login.'
fi

# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

pip install -r requirements.txt
fi

python process.py "$@"
"""
)


@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
def test_generate_framework_script_with_codeartifact(pipeline_session):
processor = PyTorchProcessor(
role=ROLE,
instance_type="ml.m4.xlarge",
framework_version="2.0.1",
py_version="py310",
instance_count=1,
sagemaker_session=pipeline_session,
)

framework_script = processor._generate_framework_script(
user_script="process.py",
codeartifact_repo_arn=(
"arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
),
)

assert framework_script == dedent(
"""\
#!/bin/bash

cd /opt/ml/processing/input/code/
tar -xzf sourcedir.tar.gz

# Exit on any error. SageMaker uses error code to mark failed job.
set -e

if [[ -f 'requirements.txt' ]]; then
# Optionally log into CodeArtifact
if ! hash aws 2>/dev/null; then
echo "AWS CLI is not installed. Skipping CodeArtifact login."
else
aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2
fi

# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

pip install -r requirements.txt
fi

python process.py "$@"
""" # noqa: E501 # pylint: disable=line-too-long
)


def _get_script_processor(sagemaker_session):
return ScriptProcessor(
role=ROLE,
Expand Down