Skip to content

Commit

Permalink
feature: Add optional CodeArtifact login to FrameworkProcessing job s…
Browse files Browse the repository at this point in the history
…cript
  • Loading branch information
akuma12 authored and goelakash committed Oct 9, 2023
1 parent a9ac311 commit 61190de
Showing 1 changed file with 91 additions and 8 deletions.
99 changes: 91 additions & 8 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from textwrap import dedent
from typing import Dict, List, Optional, Union
from copy import copy
import re

import attr

Expand Down Expand Up @@ -1659,6 +1660,7 @@ def run( # type: ignore[override]
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
kms_key: Optional[str] = None,
codeartifact_repo_arn: Optional[str] = None,
):
"""Runs a processing job.
Expand Down Expand Up @@ -1759,12 +1761,21 @@ def run( # type: ignore[override]
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
logged into before installing dependencies (default: None).
Returns:
None or pipeline step arguments in case the Processor instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
"""
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
code, source_dir, dependencies, git_config, job_name, inputs, kms_key
code,
source_dir,
dependencies,
git_config,
job_name,
inputs,
kms_key,
codeartifact_repo_arn,
)

# Submit a processing job.
Expand All @@ -1781,7 +1792,15 @@ def run( # type: ignore[override]
)

def _pack_and_upload_code(
self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None
self,
code,
source_dir,
dependencies,
git_config,
job_name,
inputs,
kms_key=None,
codeartifact_repo_arn=None,
):
"""Pack local code bundle and upload to Amazon S3."""
if code.startswith("s3://"):
Expand Down Expand Up @@ -1822,12 +1841,65 @@ def _pack_and_upload_code(
script = estimator.uploaded_code.script_name
evaluated_kms_key = kms_key if kms_key else self.output_kms_key
s3_runproc_sh = self._create_and_upload_runproc(
script, evaluated_kms_key, entrypoint_s3_uri
script, evaluated_kms_key, entrypoint_s3_uri, codeartifact_repo_arn
)

return s3_runproc_sh, inputs, job_name

def _generate_framework_script(self, user_script: str) -> str:
def _get_codeartifact_index(self, codeartifact_repo_arn: str) -> str:
    """Build the authenticated CodeArtifact pip index URL for a repository ARN.

    The ARN is expected to follow the form
    ``arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}``.

    See:
        https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
        https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies

    Args:
        codeartifact_repo_arn (str): The ARN of the CodeArtifact repository.

    Returns:
        str: The repository's ``pypi`` endpoint with ``simple/`` appended after the
            repository name and ``aws:<authorization token>@`` inserted after
            ``https://``.

    Raises:
        ValueError: If ``codeartifact_repo_arn`` does not match the expected ARN form.
        RuntimeError: If the authorization token or the repository endpoint cannot
            be retrieved from CodeArtifact. The underlying error is chained as the cause.
    """
    arn_regex = (
        r"arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
        r":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
    )
    # fullmatch so that nothing may trail the repository name (e.g. a stray newline).
    m = re.fullmatch(arn_regex, codeartifact_repo_arn)
    if not m:
        raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
    domain = m.group("domain")
    owner = m.group("account")
    repository = m.group("repository")
    region = m.group("region")

    logger.info(
        "configuring pip to use codeartifact "
        "(domain: %s, domain owner: %s, repository: %s, region: %s)",
        domain,
        owner,
        repository,
        region,
    )
    try:
        client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region)
        auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
        token = auth_token_response["authorizationToken"]
        endpoint_response = client.get_repository_endpoint(
            domain=domain, domainOwner=owner, repository=repository, format="pypi"
        )
        unauthenticated_index = endpoint_response["repositoryEndpoint"]
    except Exception as err:  # any client/credential/response failure is reported uniformly
        logger.error("failed to configure pip to use codeartifact")
        raise RuntimeError("failed to configure pip to use codeartifact") from err

    # Plain string operations instead of re.sub: the repository name may contain
    # regex metacharacters such as "." and neither it nor the token should be
    # interpreted as a regex replacement template.
    trimmed = unauthenticated_index
    if trimmed.endswith("/"):
        trimmed = trimmed[:-1]
    if trimmed.endswith(repository):
        simple_index = "{}/simple/".format(trimmed)
    else:
        # Endpoint does not end with the repository name; leave it untouched.
        simple_index = unauthenticated_index

    # NOTE(review): the returned URL embeds the authorization token in clear text.
    # Callers should avoid logging it and should treat anything it is written into
    # as sensitive.
    return simple_index.replace("https://", "https://aws:{}@".format(token), 1)

def _generate_framework_script(
self, user_script: str, codeartifact_repo_arn: str = None
) -> str:
"""Generate the framework entrypoint file (as text) for a processing job.
This script implements the "framework" functionality for setting up your code:
Expand All @@ -1838,7 +1910,15 @@ def _generate_framework_script(self, user_script: str) -> str:
Args:
user_script (str): Relative path to ```code``` in the source bundle
- e.g. 'process.py'.
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
logged into before installing dependencies (default: None).
"""
if codeartifact_repo_arn:
index = self._get_codeartifact_index(codeartifact_repo_arn)
index_option = "-i {}".format(index)
else:
index_option = ""

return dedent(
"""\
#!/bin/bash
Expand All @@ -1853,12 +1933,13 @@ def _generate_framework_script(self, user_script: str) -> str:
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing
pip install -r requirements.txt
pip install -r requirements.txt {index_option}
fi
{entry_point_command} {entry_point} "$@"
"""
).format(
index_option=index_option,
entry_point_command=" ".join(self.command),
entry_point=user_script,
)
Expand Down Expand Up @@ -1934,7 +2015,9 @@ def _set_entrypoint(self, command, user_script_name):
)
self.entrypoint = self.framework_entrypoint_command + [user_script_location]

def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
def _create_and_upload_runproc(
self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None
):
"""Create runproc shell script and upload to S3 bucket.
If leveraging a pipeline session with optimized S3 artifact paths,
Expand All @@ -1950,7 +2033,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
from sagemaker.workflow.utilities import _pipeline_config, hash_object

if _pipeline_config and _pipeline_config.pipeline_name:
runproc_file_str = self._generate_framework_script(user_script)
runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn)
runproc_file_hash = hash_object(runproc_file_str)
s3_uri = s3.s3_path_join(
"s3://",
Expand All @@ -1969,7 +2052,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
)
else:
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
self._generate_framework_script(user_script),
self._generate_framework_script(user_script, codeartifact_repo_arn),
desired_s3_uri=entrypoint_s3_uri,
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
Expand Down

0 comments on commit 61190de

Please sign in to comment.