diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index f44ac76c77..d3b8c129e0 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -9,7 +9,7 @@ coverage>=5.2, <6.2 mock==4.0.3 contextlib2==21.6.0 awslogs==0.14.0 -black==22.3.0 +black==24.3.0 stopit==1.1.2 # Update tox.ini to have correct version of airflow constraints file apache-airflow==2.8.4 diff --git a/src/sagemaker/amazon/record_pb2.py b/src/sagemaker/amazon/record_pb2.py index d06b38663b..9dd6893958 100644 --- a/src/sagemaker/amazon/record_pb2.py +++ b/src/sagemaker/amazon/record_pb2.py @@ -31,7 +31,7 @@ (_message.Message,), { "DESCRIPTOR": _FLOAT32TENSOR, - "__module__": "record_pb2" + "__module__": "record_pb2", # @@protoc_insertion_point(class_scope:aialgs.data.Float32Tensor) }, ) @@ -42,7 +42,7 @@ (_message.Message,), { "DESCRIPTOR": _FLOAT64TENSOR, - "__module__": "record_pb2" + "__module__": "record_pb2", # @@protoc_insertion_point(class_scope:aialgs.data.Float64Tensor) }, ) @@ -53,7 +53,7 @@ (_message.Message,), { "DESCRIPTOR": _INT32TENSOR, - "__module__": "record_pb2" + "__module__": "record_pb2", # @@protoc_insertion_point(class_scope:aialgs.data.Int32Tensor) }, ) @@ -64,7 +64,7 @@ (_message.Message,), { "DESCRIPTOR": _BYTES, - "__module__": "record_pb2" + "__module__": "record_pb2", # @@protoc_insertion_point(class_scope:aialgs.data.Bytes) }, ) @@ -75,7 +75,7 @@ (_message.Message,), { "DESCRIPTOR": _VALUE, - "__module__": "record_pb2" + "__module__": "record_pb2", # @@protoc_insertion_point(class_scope:aialgs.data.Value) }, ) @@ -90,7 +90,7 @@ (_message.Message,), { "DESCRIPTOR": _RECORD_FEATURESENTRY, - "__module__": "record_pb2" + "__module__": "record_pb2", # @@protoc_insertion_point(class_scope:aialgs.data.Record.FeaturesEntry) }, ), @@ -99,12 +99,12 @@ (_message.Message,), { "DESCRIPTOR": _RECORD_LABELENTRY, - "__module__": "record_pb2" + "__module__": "record_pb2", # @@protoc_insertion_point(class_scope:aialgs.data.Record.LabelEntry) }, ), "DESCRIPTOR": _RECORD, - "__module__": "record_pb2" + "__module__": "record_pb2", # @@protoc_insertion_point(class_scope:aialgs.data.Record) }, ) diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index 1413f3aa29..bb4059c03a 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -930,9 +930,9 @@ def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True): auto_ml_model_deploy_config = {} if auto_ml.auto_generate_endpoint_name is not None: - auto_ml_model_deploy_config[ - "AutoGenerateEndpointName" - ] = auto_ml.auto_generate_endpoint_name + auto_ml_model_deploy_config["AutoGenerateEndpointName"] = ( + auto_ml.auto_generate_endpoint_name + ) if not auto_ml.auto_generate_endpoint_name and auto_ml.endpoint_name is not None: auto_ml_model_deploy_config["EndpointName"] = auto_ml.endpoint_name @@ -1034,9 +1034,9 @@ def _prepare_auto_ml_stop_condition( if max_candidates is not None: stopping_condition["MaxCandidates"] = max_candidates if max_runtime_per_training_job_in_seconds is not None: - stopping_condition[ - "MaxRuntimePerTrainingJobInSeconds" - ] = max_runtime_per_training_job_in_seconds + stopping_condition["MaxRuntimePerTrainingJobInSeconds"] = ( + max_runtime_per_training_job_in_seconds + ) if total_job_runtime_in_seconds is not None: stopping_condition["MaxAutoMLJobRuntimeInSeconds"] = total_job_runtime_in_seconds diff --git a/src/sagemaker/automl/automlv2.py b/src/sagemaker/automl/automlv2.py index 
8b34f54a95..0819e5384e 100644 --- a/src/sagemaker/automl/automlv2.py +++ b/src/sagemaker/automl/automlv2.py @@ -1446,9 +1446,9 @@ def _load_config(cls, inputs, auto_ml, expand_role=True): auto_ml_model_deploy_config = {} if auto_ml.auto_generate_endpoint_name is not None: - auto_ml_model_deploy_config[ - "AutoGenerateEndpointName" - ] = auto_ml.auto_generate_endpoint_name + auto_ml_model_deploy_config["AutoGenerateEndpointName"] = ( + auto_ml.auto_generate_endpoint_name + ) if not auto_ml.auto_generate_endpoint_name and auto_ml.endpoint_name is not None: auto_ml_model_deploy_config["EndpointName"] = auto_ml.endpoint_name diff --git a/src/sagemaker/collection.py b/src/sagemaker/collection.py index 7633085506..653500fee5 100644 --- a/src/sagemaker/collection.py +++ b/src/sagemaker/collection.py @@ -377,9 +377,11 @@ def _convert_group_resource_response( { "Name": collection_name, "Arn": collection_arn, - "Type": resource_group["Identifier"]["ResourceType"] - if is_model_group - else "Collection", + "Type": ( + resource_group["Identifier"]["ResourceType"] + if is_model_group + else "Collection" + ), } ) return collection_details diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 15766d6936..8ce3b643bf 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -162,19 +162,19 @@ def _to_request_dict(self): profiler_config_request["DisableProfiler"] = self.disable_profiler if self.system_monitor_interval_millis is not None: - profiler_config_request[ - "ProfilingIntervalInMilliseconds" - ] = self.system_monitor_interval_millis + profiler_config_request["ProfilingIntervalInMilliseconds"] = ( + self.system_monitor_interval_millis + ) if self.framework_profile_params is not None: - profiler_config_request[ - "ProfilingParameters" - ] = self.framework_profile_params.profiling_parameters + profiler_config_request["ProfilingParameters"] = ( + self.framework_profile_params.profiling_parameters + ) if self.profile_params is not None: - profiler_config_request[ - "ProfilingParameters" - ] = self.profile_params.profiling_parameters + profiler_config_request["ProfilingParameters"] = ( + self.profile_params.profiling_parameters + ) return profiler_config_request diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index 8308215e81..efbb44460c 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -213,9 +213,7 @@ def _create_estimator( vpc_config: Optional[ Dict[ str, - List[ - str, - ], + List[str], ] ] = None, volume_kms_key=None, @@ -820,9 +818,9 @@ def _get_container_env(self): logger.warning("Ignoring invalid container log level: %s", self.container_log_level) return self.env - self.env[ - "SERVING_OPTS" - ] = f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"' + self.env["SERVING_OPTS"] = ( + f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"' + ) return self.env diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 8e03cbf132..066846564e 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2539,9 +2539,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config): # which is parsed in execution time # This does not check config because the EstimatorBase constuctor already did that check if estimator.encrypt_inter_container_traffic: - train_args[ - "encrypt_inter_container_traffic" - ] = estimator.encrypt_inter_container_traffic + 
train_args["encrypt_inter_container_traffic"] = ( + estimator.encrypt_inter_container_traffic + ) if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator): train_args["algorithm_arn"] = estimator.algorithm_arn @@ -2556,9 +2556,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config): train_args["debugger_hook_config"] = estimator.debugger_hook_config._to_request_dict() if estimator.tensorboard_output_config: - train_args[ - "tensorboard_output_config" - ] = estimator.tensorboard_output_config._to_request_dict() + train_args["tensorboard_output_config"] = ( + estimator.tensorboard_output_config._to_request_dict() + ) cls._add_spot_checkpoint_args(local_mode, estimator, train_args) diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py index 8fbffe3667..33f2f0bbdc 100644 --- a/src/sagemaker/experiments/run.py +++ b/src/sagemaker/experiments/run.py @@ -220,9 +220,9 @@ def __init__( trial_component_name=self._trial_component.trial_component_name, sagemaker_session=sagemaker_session, artifact_bucket=artifact_bucket, - artifact_prefix=_DEFAULT_ARTIFACT_PREFIX - if artifact_prefix is None - else artifact_prefix, + artifact_prefix=( + _DEFAULT_ARTIFACT_PREFIX if artifact_prefix is None else artifact_prefix + ), ) self._lineage_artifact_tracker = _LineageArtifactTracker( trial_component_arn=self._trial_component.trial_component_arn, diff --git a/src/sagemaker/explainer/explainer_config.py b/src/sagemaker/explainer/explainer_config.py index 6a174b27d5..72c226b15d 100644 --- a/src/sagemaker/explainer/explainer_config.py +++ b/src/sagemaker/explainer/explainer_config.py @@ -37,8 +37,8 @@ def _to_request_dict(self): request_dict = {} if self.clarify_explainer_config: - request_dict[ - "ClarifyExplainerConfig" - ] = self.clarify_explainer_config._to_request_dict() + request_dict["ClarifyExplainerConfig"] = ( + self.clarify_explainer_config._to_request_dict() + ) return request_dict diff --git a/src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py b/src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py index 6bba9e5dd6..3801f57be1 100644 --- a/src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py +++ b/src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py @@ -115,19 +115,19 @@ class FeatureProcessorLineageHandler: def create_lineage(self, tags: Optional[List[Dict[str, str]]] = None) -> None: """Create and Update Feature Processor Lineage""" - input_feature_group_contexts: List[ - FeatureGroupContexts - ] = self._retrieve_input_feature_group_contexts() + input_feature_group_contexts: List[FeatureGroupContexts] = ( + self._retrieve_input_feature_group_contexts() + ) output_feature_group_contexts: FeatureGroupContexts = ( self._retrieve_output_feature_group_contexts() ) input_raw_data_artifacts: List[Artifact] = self._retrieve_input_raw_data_artifacts() - transformation_code_artifact: Optional[ - Artifact - ] = S3LineageEntityHandler.create_transformation_code_artifact( - transformation_code=self.transformation_code, - pipeline_last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), - sagemaker_session=self.sagemaker_session, + transformation_code_artifact: Optional[Artifact] = ( + S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=self.transformation_code, + pipeline_last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) ) if 
transformation_code_artifact is not None: logger.info("Created Transformation Code Artifact: %s", transformation_code_artifact) @@ -362,40 +362,40 @@ def _update_pipeline_lineage( current_pipeline_version_context: Context = self._get_pipeline_version_context( last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] ) - upstream_feature_group_associations: Iterator[ - AssociationSummary - ] = LineageAssociationHandler.list_upstream_associations( - # pylint: disable=no-member - entity_arn=current_pipeline_version_context.context_arn, - source_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, - sagemaker_session=self.sagemaker_session, + upstream_feature_group_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + sagemaker_session=self.sagemaker_session, + ) ) - upstream_raw_data_associations: Iterator[ - AssociationSummary - ] = LineageAssociationHandler.list_upstream_associations( - # pylint: disable=no-member - entity_arn=current_pipeline_version_context.context_arn, - source_type=DATA_SET, - sagemaker_session=self.sagemaker_session, + upstream_raw_data_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=DATA_SET, + sagemaker_session=self.sagemaker_session, + ) ) - upstream_transformation_code: Iterator[ - AssociationSummary - ] = LineageAssociationHandler.list_upstream_associations( - # pylint: disable=no-member - entity_arn=current_pipeline_version_context.context_arn, - source_type=TRANSFORMATION_CODE, - sagemaker_session=self.sagemaker_session, + upstream_transformation_code: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=TRANSFORMATION_CODE, + sagemaker_session=self.sagemaker_session, + ) ) - downstream_feature_group_associations: Iterator[ - AssociationSummary - ] = LineageAssociationHandler.list_downstream_associations( - # pylint: disable=no-member - entity_arn=current_pipeline_version_context.context_arn, - destination_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, - sagemaker_session=self.sagemaker_session, + downstream_feature_group_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_downstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + destination_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + sagemaker_session=self.sagemaker_session, + ) ) is_upstream_feature_group_equal: bool = self._compare_upstream_feature_groups( @@ -598,9 +598,9 @@ def _update_last_transformation_code( last_transformation_code_artifact.properties["state"] == TRANSFORMATION_CODE_STATUS_ACTIVE ): - last_transformation_code_artifact.properties[ - "state" - ] = TRANSFORMATION_CODE_STATUS_INACTIVE + last_transformation_code_artifact.properties["state"] = ( + TRANSFORMATION_CODE_STATUS_INACTIVE + ) last_transformation_code_artifact.properties["exclusive_end_date"] = self.pipeline[ LAST_MODIFIED_TIME ].strftime("%s") diff --git a/src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py b/src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py index 
5c2b59c228..5d8b32adb8 100644 --- a/src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py +++ b/src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py @@ -172,9 +172,9 @@ def retrieve_pipeline_schedule_artifact( sagemaker_session=sagemaker_session, ) pipeline_schedule_artifact.properties["pipeline_name"] = pipeline_schedule.pipeline_name - pipeline_schedule_artifact.properties[ - "schedule_expression" - ] = pipeline_schedule.schedule_expression + pipeline_schedule_artifact.properties["schedule_expression"] = ( + pipeline_schedule.schedule_expression + ) pipeline_schedule_artifact.properties["state"] = pipeline_schedule.state pipeline_schedule_artifact.properties["start_date"] = pipeline_schedule.start_date pipeline_schedule_artifact.save() diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index fa5ae8900b..c28c27ed4e 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -112,16 +112,16 @@ def _retrieve_default_environment_variables( default_environment_variables.update(instance_specific_environment_variables) - retrieve_gated_env_var_for_instance_type: Callable[ - [str], Optional[str] - ] = lambda instance_type: _retrieve_gated_model_uri_env_var_value( - model_id=model_id, - model_version=model_version, - region=region, - tolerate_vulnerable_model=tolerate_vulnerable_model, - tolerate_deprecated_model=tolerate_deprecated_model, - sagemaker_session=sagemaker_session, - instance_type=instance_type, + retrieve_gated_env_var_for_instance_type: Callable[[str], Optional[str]] = ( + lambda instance_type: _retrieve_gated_model_uri_env_var_value( + model_id=model_id, + model_version=model_version, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + instance_type=instance_type, + ) ) gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type( @@ -213,10 +213,10 @@ def _retrieve_gated_model_uri_env_var_value( sagemaker_session=sagemaker_session, ) - s3_key: Optional[ - str - ] = model_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( # noqa E501 # pylint: disable=c0301 - instance_type + s3_key: Optional[str] = ( + model_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( # noqa E501 # pylint: disable=c0301 + instance_type + ) ) if s3_key is None: return None diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 9c1ea79504..242f49658a 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -243,9 +243,11 @@ "", (logging.StreamHandler,), { - "emit": lambda self, *args, **kwargs: logging.StreamHandler.emit(self, *args, **kwargs) - if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING) - else None + "emit": lambda self, *args, **kwargs: ( + logging.StreamHandler.emit(self, *args, **kwargs) + if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING) + else None + ) }, )() ) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index fc5113315d..3490db7b32 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -71,10 +71,8 @@ class ProprietaryModelFilterIdentifiers(str, Enum): } -_PAD_ALPHABETIC_OPERATOR = ( - 
lambda operator: f" {operator} " - if any(character.isalpha() for character in operator) - else operator +_PAD_ALPHABETIC_OPERATOR = lambda operator: ( # noqa E731 + f" {operator} " if any(character.isalpha() for character in operator) else operator ) ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = ( diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index e724bbd1e7..732493ce3b 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -436,19 +436,19 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, manifest_specs_cached_values[val] = getattr(model_manifest, val) if is_task_filter: - manifest_specs_cached_values[ - SpecialSupportedFilterKeys.TASK - ] = extract_framework_task_model(model_manifest.model_id)[1] + manifest_specs_cached_values[SpecialSupportedFilterKeys.TASK] = ( + extract_framework_task_model(model_manifest.model_id)[1] + ) if is_framework_filter: - manifest_specs_cached_values[ - SpecialSupportedFilterKeys.FRAMEWORK - ] = extract_framework_task_model(model_manifest.model_id)[0] + manifest_specs_cached_values[SpecialSupportedFilterKeys.FRAMEWORK] = ( + extract_framework_task_model(model_manifest.model_id)[0] + ) if is_model_type_filter: - manifest_specs_cached_values[ - SpecialSupportedFilterKeys.MODEL_TYPE - ] = extract_model_type_filter_representation(model_manifest.spec_key) + manifest_specs_cached_values[SpecialSupportedFilterKeys.MODEL_TYPE] = ( + extract_model_type_filter_representation(model_manifest.spec_key) + ) if Version(model_manifest.min_version) > Version(get_sagemaker_version()): return None diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 02ff39641a..63cfac0939 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -78,9 +78,9 @@ def get_jumpstart_gated_content_bucket( unavailable in that region. """ - old_gated_content_bucket: Optional[ - str - ] = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket() + old_gated_content_bucket: Optional[str] = ( + accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket() + ) info_logs: List[str] = [] @@ -129,9 +129,9 @@ def get_jumpstart_content_bucket( ValueError: If JumpStart is not launched in ``region``. 
""" - old_content_bucket: Optional[ - str - ] = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket() + old_content_bucket: Optional[str] = ( + accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket() + ) info_logs: List[str] = [] @@ -175,9 +175,9 @@ def get_formatted_manifest( manifest_dict = {} for header in manifest: header_obj = JumpStartModelHeader(header) - manifest_dict[ - JumpStartVersionedModelId(header_obj.model_id, header_obj.version) - ] = header_obj + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) return manifest_dict diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 02ade9bd9a..2ce37f68bd 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -832,7 +832,7 @@ def _initialize_and_validate_parameters(self, overridden_parameters): merged_parameters = {} default_parameters = {parameter.name: parameter for parameter in self.pipeline.parameters} if overridden_parameters is not None: - for (param_name, param_value) in overridden_parameters.items(): + for param_name, param_value in overridden_parameters.items(): if param_name not in default_parameters: error_msg = self._construct_validation_exception_message( "Unknown parameter '{}'".format(param_name) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 377bdcac85..32437a59c3 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -294,9 +294,9 @@ def train(self, input_data_config, output_data_config, hyperparameters, environm } training_env_vars.update(environment) if self.sagemaker_session.s3_resource is not None: - training_env_vars[ - S3_ENDPOINT_URL_ENV_NAME - ] = self.sagemaker_session.s3_resource.meta.client._endpoint.host + training_env_vars[S3_ENDPOINT_URL_ENV_NAME] = ( + self.sagemaker_session.s3_resource.meta.client._endpoint.host + ) compose_data = self._generate_compose_file( "train", additional_volumes=volumes, additional_env_vars=training_env_vars diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 7d48850077..36a848aa52 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -472,9 +472,9 @@ def update_pipeline( raise ClientError(error_response, "update_pipeline") LocalSagemakerClient._pipelines[pipeline.name].pipeline_description = pipeline_description LocalSagemakerClient._pipelines[pipeline.name].pipeline = pipeline - LocalSagemakerClient._pipelines[ - pipeline.name - ].last_modified_time = datetime.now().timestamp() + LocalSagemakerClient._pipelines[pipeline.name].last_modified_time = ( + datetime.now().timestamp() + ) return {"PipelineArn": pipeline.name} def describe_pipeline(self, PipelineName): diff --git a/src/sagemaker/local/pipeline.py b/src/sagemaker/local/pipeline.py index 9e97dd2059..8d613d6469 100644 --- a/src/sagemaker/local/pipeline.py +++ b/src/sagemaker/local/pipeline.py @@ -307,10 +307,10 @@ def execute(self): "ProcessingOutputConfig" in job_describe_response and "Outputs" in job_describe_response["ProcessingOutputConfig"] ): - job_describe_response["ProcessingOutputConfig"][ - "Outputs" - ] = self._convert_list_to_dict( - job_describe_response, "ProcessingOutputConfig.Outputs", "OutputName" + job_describe_response["ProcessingOutputConfig"]["Outputs"] = ( + self._convert_list_to_dict( + job_describe_response, "ProcessingOutputConfig.Outputs", "OutputName" + ) ) if "ProcessingInputs" in job_describe_response: 
job_describe_response["ProcessingInputs"] = self._convert_list_to_dict( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index c5f8b86f60..fd21b6342e 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -667,9 +667,9 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, image_config=self.image_config, - accept_eula=accept_eula - if accept_eula is not None - else getattr(self, "accept_eula", None), + accept_eula=( + accept_eula if accept_eula is not None else getattr(self, "accept_eula", None) + ), ) def is_repack(self) -> bool: @@ -1009,9 +1009,9 @@ def _compilation_job_config( """Placeholder Docstring""" input_model_config = { "S3Uri": self.model_data, - "DataInputConfig": json.dumps(input_shape) - if isinstance(input_shape, dict) - else input_shape, + "DataInputConfig": ( + json.dumps(input_shape) if isinstance(input_shape, dict) else input_shape + ), "Framework": framework.upper(), } @@ -1449,9 +1449,9 @@ def deploy( ): tags = add_jumpstart_uri_tags( tags=tags, - inference_model_uri=self.model_data - if isinstance(self.model_data, (str, dict)) - else None, + inference_model_uri=( + self.model_data if isinstance(self.model_data, (str, dict)) else None + ), inference_script_uri=self.source_dir, ) @@ -1566,13 +1566,13 @@ def deploy( # [TODO]: Refactor to a module startup_parameters = {} if model_data_download_timeout: - startup_parameters[ - "ModelDataDownloadTimeoutInSeconds" - ] = model_data_download_timeout + startup_parameters["ModelDataDownloadTimeoutInSeconds"] = ( + model_data_download_timeout + ) if container_startup_health_check_timeout: - startup_parameters[ - "ContainerStartupHealthCheckTimeoutInSeconds" - ] = container_startup_health_check_timeout + startup_parameters["ContainerStartupHealthCheckTimeoutInSeconds"] = ( + container_startup_health_check_timeout + ) inference_component_spec = { "ModelName": self.name, diff --git a/src/sagemaker/model_card/model_card.py b/src/sagemaker/model_card/model_card.py index e6f32eaf8f..33af98723f 100644 --- a/src/sagemaker/model_card/model_card.py +++ b/src/sagemaker/model_card/model_card.py @@ -507,9 +507,11 @@ def from_model_package_arn(cls, model_package_arn: str, sagemaker_session: Sessi "SourceAlgorithms" ] source_algorithms = [ - SourceAlgorithm(sa["AlgorithmName"], sa["ModelDataUrl"]) - if "ModelDataUrl" in sa - else SourceAlgorithm(sa["AlgorithmName"]) + ( + SourceAlgorithm(sa["AlgorithmName"], sa["ModelDataUrl"]) + if "ModelDataUrl" in sa + else SourceAlgorithm(sa["AlgorithmName"]) + ) for sa in source_algorithms_response ] @@ -948,18 +950,22 @@ def _create_training_details(training_job_data: dict, cls: "TrainingDetails", ** "training_environment": Environment( container_image=[training_job_data["AlgorithmSpecification"]["TrainingImage"]] ), - "training_metrics": [ - TrainingMetric(i["MetricName"], i["Value"]) - for i in training_job_data["FinalMetricDataList"] - ] - if "FinalMetricDataList" in training_job_data - else [], - "hyper_parameters": [ - HyperParameter(key, value) - for key, value in training_job_data["HyperParameters"].items() - ] - if "HyperParameters" in training_job_data - else [], + "training_metrics": ( + [ + TrainingMetric(i["MetricName"], i["Value"]) + for i in training_job_data["FinalMetricDataList"] + ] + if "FinalMetricDataList" in training_job_data + else [] + ), + "hyper_parameters": ( + [ + HyperParameter(key, value) + for key, value in training_job_data["HyperParameters"].items() + ] + if "HyperParameters" in training_job_data + else [] + ), } 
kwargs.update({"training_job_details": TrainingJobDetails(**job)}) instance = cls(**kwargs) diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index c313efcf5e..436377fea5 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -2739,9 +2739,9 @@ def _build_create_data_quality_job_definition_request( app_specification["ImageUri"] = image_uri if post_analytics_processor_script_s3_uri: - app_specification[ - "PostAnalyticsProcessorSourceUri" - ] = post_analytics_processor_script_s3_uri + app_specification["PostAnalyticsProcessorSourceUri"] = ( + post_analytics_processor_script_s3_uri + ) if record_preprocessor_script_s3_uri: app_specification["RecordPreprocessorSourceUri"] = record_preprocessor_script_s3_uri @@ -3519,9 +3519,9 @@ def _build_create_model_quality_job_definition_request( ) if post_analytics_processor_script_s3_uri: - app_specification[ - "PostAnalyticsProcessorSourceUri" - ] = post_analytics_processor_script_s3_uri + app_specification["PostAnalyticsProcessorSourceUri"] = ( + post_analytics_processor_script_s3_uri + ) if record_preprocessor_script_s3_uri: app_specification["RecordPreprocessorSourceUri"] = record_preprocessor_script_s3_uri @@ -4107,9 +4107,9 @@ def _to_request_dict(self): if self.probability_attribute is not None: batch_transform_input_data["ProbabilityAttribute"] = self.probability_attribute if self.probability_threshold_attribute is not None: - batch_transform_input_data[ - "ProbabilityThresholdAttribute" - ] = self.probability_threshold_attribute + batch_transform_input_data["ProbabilityThresholdAttribute"] = ( + self.probability_threshold_attribute + ) if self.exclude_features_attribute is not None: batch_transform_input_data["ExcludeFeaturesAttribute"] = self.exclude_features_attribute diff --git a/src/sagemaker/network.py b/src/sagemaker/network.py index 2942e71062..9beb6544a6 100644 --- a/src/sagemaker/network.py +++ b/src/sagemaker/network.py @@ -62,9 +62,9 @@ def _to_request_dict(self): network_config_request = {"EnableNetworkIsolation": enable_network_isolation} if self.encrypt_inter_container_traffic is not None: - network_config_request[ - "EnableInterContainerTrafficEncryption" - ] = self.encrypt_inter_container_traffic + network_config_request["EnableInterContainerTrafficEncryption"] = ( + self.encrypt_inter_container_traffic + ) if self.security_group_ids is not None or self.subnets is not None: network_config_request["VpcConfig"] = {} diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index de33f61b82..06d2ecfcde 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -76,16 +76,16 @@ def retrieve_all_examples( "Must specify JumpStart `model_id` and `model_version` when retrieving payloads." 
) - unserialized_payload_dict: Optional[ - Dict[str, JumpStartSerializablePayload] - ] = artifacts._retrieve_example_payloads( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, - sagemaker_session=sagemaker_session, - model_type=model_type, + unserialized_payload_dict: Optional[Dict[str, JumpStartSerializablePayload]] = ( + artifacts._retrieve_example_payloads( + model_id, + model_version, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + model_type=model_type, + ) ) if unserialized_payload_dict is None: diff --git a/src/sagemaker/predictor_async.py b/src/sagemaker/predictor_async.py index cdf9b141b3..ef70b93599 100644 --- a/src/sagemaker/predictor_async.py +++ b/src/sagemaker/predictor_async.py @@ -313,12 +313,14 @@ def check_failure_file(): failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key) failure_response = self.predictor._handle_response(response=failure_object) - raise AsyncInferenceModelError( - message=failure_response - ) if failure_file_found.is_set() else PollingTimeoutError( - message="Inference could still be running", - output_path=output_path, - seconds=waiter_config.delay * waiter_config.max_attempts, + raise ( + AsyncInferenceModelError(message=failure_response) + if failure_file_found.is_set() + else PollingTimeoutError( + message="Inference could still be running", + output_path=output_path, + seconds=waiter_config.delay * waiter_config.max_attempts, + ) ) def update_endpoint( diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index a854de9135..5814ee45ff 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -863,9 +863,9 @@ def compile( request_dict["EnableNetworkIsolation"] = job_settings.enable_network_isolation if job_settings.encrypt_inter_container_traffic is not None: - request_dict[ - "EnableInterContainerTrafficEncryption" - ] = job_settings.encrypt_inter_container_traffic + request_dict["EnableInterContainerTrafficEncryption"] = ( + job_settings.encrypt_inter_container_traffic + ) if job_settings.vpc_config: request_dict["VpcConfig"] = job_settings.vpc_config @@ -1080,9 +1080,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): if user_workspace_s3uri: input_data_config.append( dict( - ChannelName=REMOTE_FUNCTION_WORKSPACE - if not step_compilation_context - else step_compilation_context.pipeline_build_time, + ChannelName=( + REMOTE_FUNCTION_WORKSPACE + if not step_compilation_context + else step_compilation_context.pipeline_build_time + ), DataSource={ "S3DataSource": { "S3Uri": user_workspace_s3uri, diff --git a/src/sagemaker/serve/__init__.py b/src/sagemaker/serve/__init__.py index bb5cafc8dd..886e59d4f1 100644 --- a/src/sagemaker/serve/__init__.py +++ b/src/sagemaker/serve/__init__.py @@ -1,4 +1,5 @@ """Placeholder docstring""" + from __future__ import absolute_import import logging diff --git a/src/sagemaker/serve/builder/schema_builder.py b/src/sagemaker/serve/builder/schema_builder.py index 2a0d4892e4..3fd1816d0e 100644 --- a/src/sagemaker/serve/builder/schema_builder.py +++ b/src/sagemaker/serve/builder/schema_builder.py @@ -1,4 +1,5 @@ """Placeholder docstring""" + from __future__ import absolute_import import io import logging @@ -230,18 +231,26 @@ def __repr__(self): def generate_marshalling_map(self) -> dict: """Generate marshalling map for the schema builder""" return { - "input_serializer": 
self.input_serializer.__class__.__name__ - if hasattr(self, "input_serializer") - else None, - "output_serializer": self.output_serializer.__class__.__name__ - if hasattr(self, "output_serializer") - else None, - "input_deserializer": self._input_deserializer.__class__.__name__ - if hasattr(self, "_input_deserializer") - else None, - "output_deserializer": self._output_deserializer.__class__.__name__ - if hasattr(self, "_output_deserializer") - else None, + "input_serializer": ( + self.input_serializer.__class__.__name__ + if hasattr(self, "input_serializer") + else None + ), + "output_serializer": ( + self.output_serializer.__class__.__name__ + if hasattr(self, "output_serializer") + else None + ), + "input_deserializer": ( + self._input_deserializer.__class__.__name__ + if hasattr(self, "_input_deserializer") + else None + ), + "output_deserializer": ( + self._output_deserializer.__class__.__name__ + if hasattr(self, "_output_deserializer") + else None + ), "custom_input_translator": hasattr(self, "custom_input_translator"), "custom_output_translator": hasattr(self, "custom_output_translator"), } diff --git a/src/sagemaker/serve/builder/triton_schema_builder.py b/src/sagemaker/serve/builder/triton_schema_builder.py index dd51061353..3015dc6e19 100644 --- a/src/sagemaker/serve/builder/triton_schema_builder.py +++ b/src/sagemaker/serve/builder/triton_schema_builder.py @@ -1,4 +1,5 @@ """Placeholder docstring""" + from __future__ import absolute_import from sagemaker.serve.marshalling.triton_translator import ( diff --git a/src/sagemaker/serve/detector/image_detector.py b/src/sagemaker/serve/detector/image_detector.py index 15be8eae89..63831f5950 100644 --- a/src/sagemaker/serve/detector/image_detector.py +++ b/src/sagemaker/serve/detector/image_detector.py @@ -1,4 +1,5 @@ """Detects the image to deploy model""" + from __future__ import absolute_import from typing import Tuple, List import platform diff --git a/src/sagemaker/serve/detector/pickle_dependencies.py b/src/sagemaker/serve/detector/pickle_dependencies.py index 8ba95a3cc0..5a1cd43869 100644 --- a/src/sagemaker/serve/detector/pickle_dependencies.py +++ b/src/sagemaker/serve/detector/pickle_dependencies.py @@ -1,4 +1,5 @@ """Load a pickled object to detect the dependencies it requires""" + from __future__ import absolute_import from pathlib import Path from typing import List diff --git a/src/sagemaker/serve/detector/pickler.py b/src/sagemaker/serve/detector/pickler.py index 3adddddc21..8ea2514c73 100644 --- a/src/sagemaker/serve/detector/pickler.py +++ b/src/sagemaker/serve/detector/pickler.py @@ -1,4 +1,5 @@ """Save the object using cloudpickle""" + from __future__ import absolute_import from typing import Any from pathlib import Path diff --git a/src/sagemaker/serve/marshalling/custom_payload_translator.py b/src/sagemaker/serve/marshalling/custom_payload_translator.py index d5562bba51..154d47c80e 100644 --- a/src/sagemaker/serve/marshalling/custom_payload_translator.py +++ b/src/sagemaker/serve/marshalling/custom_payload_translator.py @@ -1,4 +1,5 @@ """Defines CustomPayloadTranslator class that holds custom serialization/deserialization code""" + from __future__ import absolute_import import abc from typing import IO diff --git a/src/sagemaker/serve/marshalling/triton_translator.py b/src/sagemaker/serve/marshalling/triton_translator.py index 034f536fa6..f7a941c2f2 100644 --- a/src/sagemaker/serve/marshalling/triton_translator.py +++ b/src/sagemaker/serve/marshalling/triton_translator.py @@ -1,4 +1,5 @@ """Implements class 
converts data from and to np.ndarray""" + from __future__ import absolute_import import logging diff --git a/src/sagemaker/serve/mode/function_pointers.py b/src/sagemaker/serve/mode/function_pointers.py index df2f5cc6cd..ba4103eae8 100644 --- a/src/sagemaker/serve/mode/function_pointers.py +++ b/src/sagemaker/serve/mode/function_pointers.py @@ -1,4 +1,5 @@ """Placeholder docstring""" + from __future__ import absolute_import from enum import Enum diff --git a/src/sagemaker/serve/mode/local_container_mode.py b/src/sagemaker/serve/mode/local_container_mode.py index aba6aeb25d..362a3804de 100644 --- a/src/sagemaker/serve/mode/local_container_mode.py +++ b/src/sagemaker/serve/mode/local_container_mode.py @@ -1,4 +1,5 @@ """Module that defines the LocalContainerMode class""" + from __future__ import absolute_import from pathlib import Path import logging diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py index 8f22c4aa1c..0fdc425b31 100644 --- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py +++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py @@ -1,4 +1,5 @@ """Placeholder docstring""" + from __future__ import absolute_import from pathlib import Path diff --git a/src/sagemaker/serve/model_server/djl_serving/server.py b/src/sagemaker/serve/model_server/djl_serving/server.py index 3664abf56a..8b152e5b81 100644 --- a/src/sagemaker/serve/model_server/djl_serving/server.py +++ b/src/sagemaker/serve/model_server/djl_serving/server.py @@ -1,4 +1,5 @@ """Module for Local DJL Serving""" + from __future__ import absolute_import import requests diff --git a/src/sagemaker/serve/model_server/djl_serving/utils.py b/src/sagemaker/serve/model_server/djl_serving/utils.py index 1b016a6eae..03719542d2 100644 --- a/src/sagemaker/serve/model_server/djl_serving/utils.py +++ b/src/sagemaker/serve/model_server/djl_serving/utils.py @@ -1,4 +1,5 @@ """DJL ModelBuilder Utils""" + from __future__ import absolute_import from urllib.error import HTTPError import math diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py index 36e14a56af..b78e01f5c3 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/server.py +++ b/src/sagemaker/serve/model_server/multi_model_server/server.py @@ -1,4 +1,5 @@ """Module for the MultiModel Local and Remote servers""" + from __future__ import absolute_import import requests diff --git a/src/sagemaker/serve/model_server/tgi/server.py b/src/sagemaker/serve/model_server/tgi/server.py index d518ccee11..ef39e890c8 100644 --- a/src/sagemaker/serve/model_server/tgi/server.py +++ b/src/sagemaker/serve/model_server/tgi/server.py @@ -1,4 +1,5 @@ """Module for Local TGI Serving""" + from __future__ import absolute_import import requests diff --git a/src/sagemaker/serve/model_server/tgi/utils.py b/src/sagemaker/serve/model_server/tgi/utils.py index ef3a8b7525..05c6f33c5a 100644 --- a/src/sagemaker/serve/model_server/tgi/utils.py +++ b/src/sagemaker/serve/model_server/tgi/utils.py @@ -1,4 +1,5 @@ """TGI ModelBuilder Utils""" + from __future__ import absolute_import from typing import Dict diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 2161763fc9..2675f6ea6a 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -1,4 +1,5 @@ """This module is for SageMaker 
inference.py.""" + from __future__ import absolute_import import os import io diff --git a/src/sagemaker/serve/model_server/torchserve/server.py b/src/sagemaker/serve/model_server/torchserve/server.py index 8d403b8f91..5aef136355 100644 --- a/src/sagemaker/serve/model_server/torchserve/server.py +++ b/src/sagemaker/serve/model_server/torchserve/server.py @@ -1,4 +1,5 @@ """Module for Local Torch Server""" + from __future__ import absolute_import import requests diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 1b6f7d0e58..4e82ec66b2 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -1,4 +1,5 @@ """This module is for SageMaker inference.py.""" + from __future__ import absolute_import import os import io diff --git a/src/sagemaker/serve/model_server/triton/config_template.py b/src/sagemaker/serve/model_server/triton/config_template.py index 227cd5c7c1..e351049b43 100644 --- a/src/sagemaker/serve/model_server/triton/config_template.py +++ b/src/sagemaker/serve/model_server/triton/config_template.py @@ -1,4 +1,5 @@ """Placeholder docstring""" + from __future__ import absolute_import CONFIG_TEMPLATE = """name: \"model\" diff --git a/src/sagemaker/serve/model_server/triton/model.py b/src/sagemaker/serve/model_server/triton/model.py index 0b8e4a27a5..5d8e9b0a2d 100644 --- a/src/sagemaker/serve/model_server/triton/model.py +++ b/src/sagemaker/serve/model_server/triton/model.py @@ -1,4 +1,5 @@ """This module is for Triton Python backend.""" + from __future__ import absolute_import import os import logging diff --git a/src/sagemaker/serve/model_server/triton/server.py b/src/sagemaker/serve/model_server/triton/server.py index 2ebec25406..62dfb4759a 100644 --- a/src/sagemaker/serve/model_server/triton/server.py +++ b/src/sagemaker/serve/model_server/triton/server.py @@ -1,4 +1,5 @@ """Placeholder docerting""" + from __future__ import absolute_import import uuid import logging diff --git a/src/sagemaker/serve/model_server/triton/triton_builder.py b/src/sagemaker/serve/model_server/triton/triton_builder.py index d47d0cc356..ed0ec49204 100644 --- a/src/sagemaker/serve/model_server/triton/triton_builder.py +++ b/src/sagemaker/serve/model_server/triton/triton_builder.py @@ -1,4 +1,5 @@ """Placeholder docstring""" + from __future__ import absolute_import import os import logging diff --git a/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/framework_handler.py b/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/framework_handler.py index ac862aa507..1e909ff42b 100644 --- a/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/framework_handler.py +++ b/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/framework_handler.py @@ -1,4 +1,5 @@ """Experimental""" + from __future__ import absolute_import from abc import ABC, abstractmethod from typing import Type diff --git a/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/pytorch_handler.py b/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/pytorch_handler.py index 0a07e4d93b..7aa00f65b9 100644 --- a/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/pytorch_handler.py +++ b/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/pytorch_handler.py @@ -1,4 +1,5 @@ """Experimental""" + from __future__ import absolute_import import platform import logging diff --git 
a/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/xgboost_handler.py b/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/xgboost_handler.py index 6d40c138d9..19e83f58ce 100644 --- a/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/xgboost_handler.py +++ b/src/sagemaker/serve/save_retrive/version_1_0_0/save/framework/xgboost_handler.py @@ -1,4 +1,5 @@ """Experimental""" + from __future__ import absolute_import import platform import logging diff --git a/src/sagemaker/serve/save_retrive/version_1_0_0/save/utils.py b/src/sagemaker/serve/save_retrive/version_1_0_0/save/utils.py index 305b66fa92..8df442d9f8 100644 --- a/src/sagemaker/serve/save_retrive/version_1_0_0/save/utils.py +++ b/src/sagemaker/serve/save_retrive/version_1_0_0/save/utils.py @@ -1,4 +1,5 @@ """Validates the integrity of pickled file with HMAC signing.""" + from __future__ import absolute_import import secrets import hmac diff --git a/src/sagemaker/serve/spec/inference_spec.py b/src/sagemaker/serve/spec/inference_spec.py index 9ecec4fdfe..b61d7d55ea 100644 --- a/src/sagemaker/serve/spec/inference_spec.py +++ b/src/sagemaker/serve/spec/inference_spec.py @@ -1,4 +1,5 @@ """Implements class that holds custom load and invoke function of a model""" + from __future__ import absolute_import import abc diff --git a/src/sagemaker/serve/utils/exceptions.py b/src/sagemaker/serve/utils/exceptions.py index 8132820cc0..72b9083072 100644 --- a/src/sagemaker/serve/utils/exceptions.py +++ b/src/sagemaker/serve/utils/exceptions.py @@ -1,4 +1,5 @@ """Placeholder Docstring""" + from __future__ import absolute_import diff --git a/src/sagemaker/serve/utils/logging_agent.py b/src/sagemaker/serve/utils/logging_agent.py index f140a4c403..196359116e 100644 --- a/src/sagemaker/serve/utils/logging_agent.py +++ b/src/sagemaker/serve/utils/logging_agent.py @@ -1,4 +1,5 @@ """Module for pulling logs from container""" + from __future__ import absolute_import import logging from threading import Thread diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py index 10fc6bb4aa..e0ff8f8ee1 100644 --- a/src/sagemaker/serve/utils/predictors.py +++ b/src/sagemaker/serve/utils/predictors.py @@ -1,4 +1,5 @@ """Defines the predictors used in local container mode""" + from __future__ import absolute_import import io from typing import Type diff --git a/src/sagemaker/serve/utils/tuning.py b/src/sagemaker/serve/utils/tuning.py index de02708278..22f3c06d47 100644 --- a/src/sagemaker/serve/utils/tuning.py +++ b/src/sagemaker/serve/utils/tuning.py @@ -1,4 +1,5 @@ """Holds mixin logic to support deployment of Model ID""" + from __future__ import absolute_import import logging diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py index 60b7b5cd6f..661093f249 100644 --- a/src/sagemaker/serve/utils/types.py +++ b/src/sagemaker/serve/utils/types.py @@ -1,4 +1,5 @@ """Types used for SageMaker ModelBuilder""" + from __future__ import absolute_import from enum import Enum diff --git a/src/sagemaker/serve/utils/uploader.py b/src/sagemaker/serve/utils/uploader.py index 39459f3513..a6b07f2fa5 100644 --- a/src/sagemaker/serve/utils/uploader.py +++ b/src/sagemaker/serve/utils/uploader.py @@ -1,4 +1,5 @@ """Upload model artifacts to S3""" + from __future__ import absolute_import import logging import os diff --git a/src/sagemaker/serve/validations/check_image_and_hardware_type.py b/src/sagemaker/serve/validations/check_image_and_hardware_type.py index 935d9e185e..0046e47a80 
100644 --- a/src/sagemaker/serve/validations/check_image_and_hardware_type.py +++ b/src/sagemaker/serve/validations/check_image_and_hardware_type.py @@ -1,4 +1,5 @@ """Validate if image_uri is compatible with instance_type""" + from __future__ import absolute_import import logging diff --git a/src/sagemaker/serve/validations/check_image_uri.py b/src/sagemaker/serve/validations/check_image_uri.py index 2d559bc2f3..2f50faaeed 100644 --- a/src/sagemaker/serve/validations/check_image_uri.py +++ b/src/sagemaker/serve/validations/check_image_uri.py @@ -1,4 +1,5 @@ """Validates that a given image_uri is not a 1p image.""" + from __future__ import absolute_import # Generated by running the parse_registry_accounts.py script diff --git a/src/sagemaker/serve/validations/check_integrity.py b/src/sagemaker/serve/validations/check_integrity.py index ffd4bcfcd6..01b958611c 100644 --- a/src/sagemaker/serve/validations/check_integrity.py +++ b/src/sagemaker/serve/validations/check_integrity.py @@ -1,4 +1,5 @@ """Validates the integrity of pickled file with HMAC signing.""" + from __future__ import absolute_import import secrets import hmac diff --git a/src/sagemaker/serve/validations/parse_registry_accounts.py b/src/sagemaker/serve/validations/parse_registry_accounts.py index 14261975b2..04ed6402b3 100644 --- a/src/sagemaker/serve/validations/parse_registry_accounts.py +++ b/src/sagemaker/serve/validations/parse_registry_accounts.py @@ -1,4 +1,5 @@ """Script to scrape account IDs for 1p image URIs defined in src/sagemaker/image_uri_config.""" + from __future__ import absolute_import import os diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index b0ff9bcb86..9e593706c1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3470,9 +3470,9 @@ def _map_training_config( training_job_definition["EnableNetworkIsolation"] = enable_network_isolation if encrypt_inter_container_traffic: - training_job_definition[ - "EnableInterContainerTrafficEncryption" - ] = encrypt_inter_container_traffic + training_job_definition["EnableInterContainerTrafficEncryption"] = ( + encrypt_inter_container_traffic + ) if use_spot_instances: # use_spot_instances may be a Pipeline ParameterBoolean object @@ -6405,7 +6405,7 @@ def _intercept_create_request( self, request: typing.Dict, create, - func_name: str = None + func_name: str = None, # pylint: disable=unused-argument ): """This function intercepts the create job request. diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index 82634071cc..5c6c64852d 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -538,14 +538,14 @@ def _prepare_history_server_env_variables(self, spark_event_logs_s3_uri): history_server_env_variables = {} if spark_event_logs_s3_uri: - history_server_env_variables[ - _HistoryServer.arg_event_logs_s3_uri - ] = spark_event_logs_s3_uri + history_server_env_variables[_HistoryServer.arg_event_logs_s3_uri] = ( + spark_event_logs_s3_uri + ) # this variable will be previously set by run() method elif self._spark_event_logs_s3_uri is not None: - history_server_env_variables[ - _HistoryServer.arg_event_logs_s3_uri - ] = self._spark_event_logs_s3_uri + history_server_env_variables[_HistoryServer.arg_event_logs_s3_uri] = ( + self._spark_event_logs_s3_uri + ) else: raise ValueError( "SPARK_EVENT_LOGS_S3_URI not present. 
You can specify spark_event_logs_s3_uri " @@ -557,9 +557,9 @@ def _prepare_history_server_env_variables(self, spark_event_logs_s3_uri): history_server_env_variables["AWS_REGION"] = region if self._is_notebook_instance(): - history_server_env_variables[ - _HistoryServer.arg_remote_domain_name - ] = self._get_notebook_instance_domain() + history_server_env_variables[_HistoryServer.arg_remote_domain_name] = ( + self._get_notebook_instance_domain() + ) return history_server_env_variables diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 967bff1b99..4b0f38f36f 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -568,9 +568,9 @@ def to_input_req(self): ] = self.max_number_of_training_jobs_not_improving if self.target_objective_metric_value is not None: - completion_criteria_config[ - TARGET_OBJECTIVE_METRIC_VALUE - ] = self.target_objective_metric_value + completion_criteria_config[TARGET_OBJECTIVE_METRIC_VALUE] = ( + self.target_objective_metric_value + ) if self.complete_on_convergence is not None: completion_criteria_config[CONVERGENCE_DETECTED] = {} @@ -867,9 +867,11 @@ def _prepare_static_hyperparameters_for_tuning(self, include_cls_metadata=False) estimator_name: self._prepare_static_hyperparameters( estimator, self._hyperparameter_ranges_dict[estimator_name], - include_cls_metadata.get(estimator_name, False) - if isinstance(include_cls_metadata, dict) - else include_cls_metadata, + ( + include_cls_metadata.get(estimator_name, False) + if isinstance(include_cls_metadata, dict) + else include_cls_metadata + ), ) for (estimator_name, estimator) in self.estimator_dict.items() } @@ -887,9 +889,11 @@ def _prepare_auto_parameters_for_tuning(self): static_auto_parameters_dict = { estimator_name: self._prepare_auto_parameters( self.static_hyperparameters_dict[estimator_name], - self.hyperparameters_to_keep_static_dict.get(estimator_name, None) - if self.hyperparameters_to_keep_static_dict - else None, + ( + self.hyperparameters_to_keep_static_dict.get(estimator_name, None) + if self.hyperparameters_to_keep_static_dict + else None + ), ) for estimator_name in sorted(self.estimator_dict.keys()) } @@ -1268,10 +1272,10 @@ def _attach_with_training_details_list(cls, sagemaker_session, estimator_cls, jo objective_metric_name_dict[estimator_name] = training_details["TuningObjective"][ "MetricName" ] - hyperparameter_ranges_dict[ - estimator_name - ] = cls._prepare_parameter_ranges_from_job_description( # noqa: E501 # pylint: disable=line-too-long - training_details["HyperParameterRanges"] + hyperparameter_ranges_dict[estimator_name] = ( + cls._prepare_parameter_ranges_from_job_description( # noqa: E501 # pylint: disable=line-too-long + training_details["HyperParameterRanges"] + ) ) metric_definitions = training_details["AlgorithmSpecification"].get( @@ -2111,9 +2115,9 @@ def _add_estimator( self.objective_metric_name_dict[estimator_name] = objective_metric_name self._hyperparameter_ranges_dict[estimator_name] = hyperparameter_ranges if hyperparameters_to_keep_static is not None: - self.hyperparameters_to_keep_static_dict[ - estimator_name - ] = hyperparameters_to_keep_static + self.hyperparameters_to_keep_static_dict[estimator_name] = ( + hyperparameters_to_keep_static + ) if metric_definitions is not None: self.metric_definitions_dict[estimator_name] = metric_definitions @@ -2190,9 +2194,9 @@ def _get_tuner_args(cls, tuner, inputs): tuning_config["auto_parameters"] = tuner.auto_parameters if tuner.completion_criteria_config is not None: - tuning_config[ - 
"completion_criteria_config" - ] = tuner.completion_criteria_config.to_input_req() + tuning_config["completion_criteria_config"] = ( + tuner.completion_criteria_config.to_input_req() + ) tuner_args = { "job_name": tuner._current_job_name, @@ -2222,12 +2226,16 @@ def _get_tuner_args(cls, tuner, inputs): tuner.objective_type, tuner.objective_metric_name_dict[estimator_name], tuner.hyperparameter_ranges_dict()[estimator_name], - tuner.instance_configs_dict.get(estimator_name, None) - if tuner.instance_configs_dict is not None - else None, - tuner.auto_parameters_dict.get(estimator_name, None) - if tuner.auto_parameters_dict is not None - else None, + ( + tuner.instance_configs_dict.get(estimator_name, None) + if tuner.instance_configs_dict is not None + else None + ), + ( + tuner.auto_parameters_dict.get(estimator_name, None) + if tuner.auto_parameters_dict is not None + else None + ), ) for estimator_name in sorted(tuner.estimator_dict.keys()) ] @@ -2303,9 +2311,9 @@ def _prepare_training_config( training_config["image_uri"] = estimator.training_image_uri() training_config["enable_network_isolation"] = estimator.enable_network_isolation() - training_config[ - "encrypt_inter_container_traffic" - ] = estimator.encrypt_inter_container_traffic + training_config["encrypt_inter_container_traffic"] = ( + estimator.encrypt_inter_container_traffic + ) training_config["use_spot_instances"] = estimator.use_spot_instances training_config["checkpoint_s3_uri"] = estimator.checkpoint_s3_uri diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 793849ff93..3678c3d97e 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -68,13 +68,13 @@ def prepare_framework(estimator, s3_operations): ] estimator._hyperparameters[sagemaker.model.DIR_PARAM_NAME] = code_dir estimator._hyperparameters[sagemaker.model.SCRIPT_PARAM_NAME] = script - estimator._hyperparameters[ - sagemaker.model.CONTAINER_LOG_LEVEL_PARAM_NAME - ] = estimator.container_log_level + estimator._hyperparameters[sagemaker.model.CONTAINER_LOG_LEVEL_PARAM_NAME] = ( + estimator.container_log_level + ) estimator._hyperparameters[sagemaker.model.JOB_NAME_PARAM_NAME] = estimator._current_job_name - estimator._hyperparameters[ - sagemaker.model.SAGEMAKER_REGION_PARAM_NAME - ] = estimator.sagemaker_session.boto_region_name + estimator._hyperparameters[sagemaker.model.SAGEMAKER_REGION_PARAM_NAME] = ( + estimator.sagemaker_session.boto_region_name + ) def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None): @@ -422,7 +422,7 @@ def _extract_training_config_list_from_estimator_dict( ) train_config_dict = {} - for (estimator_name, estimator) in tuner.estimator_dict.items(): + for estimator_name, estimator in tuner.estimator_dict.items(): train_config_dict[estimator_name] = training_base_config( estimator=estimator, inputs=inputs.get(estimator_name) if inputs else None, @@ -439,9 +439,9 @@ def _extract_training_config_list_from_estimator_dict( train_config.pop("HyperParameters", None) train_config["StaticHyperParameters"] = tuner.static_hyperparameters_dict[estimator_name] - train_config["AlgorithmSpecification"][ - "MetricDefinitions" - ] = tuner.metric_definitions_dict.get(estimator_name) + train_config["AlgorithmSpecification"]["MetricDefinitions"] = ( + tuner.metric_definitions_dict.get(estimator_name) + ) train_config["DefinitionName"] = estimator_name train_config["TuningObjective"] = { @@ -461,7 +461,7 @@ def _merge_s3_operations(s3_operations_list): """Merge a list 
of S3 operation dictionaries into one""" s3_operations_merged = {} for s3_operations in s3_operations_list: - for (key, operations) in s3_operations.items(): + for key, operations in s3_operations.items(): if key not in s3_operations_merged: s3_operations_merged[key] = [] for operation in operations: diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index c53e7e9666..11fbb2c00b 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -306,9 +306,9 @@ def to_request(self) -> RequestType: if isinstance( self.clarify_check_config, (ModelBiasCheckConfig, ModelExplainabilityCheckConfig) ): - request_dict[ - "ModelName" - ] = self.clarify_check_config.model_config.get_predictor_config()["model_name"] + request_dict["ModelName"] = ( + self.clarify_check_config.model_config.get_predictor_config()["model_name"] + ) return request_dict def _generate_processing_job_analysis_config(self) -> dict: @@ -348,9 +348,9 @@ def _generate_processing_job_analysis_config(self) -> dict: predictor_config.update(predicted_label_config) else: _set(model_scores, "label", predictor_config) - analysis_config[ - "methods" - ] = self.clarify_check_config.explainability_config.get_explainability_config() + analysis_config["methods"] = ( + self.clarify_check_config.explainability_config.get_explainability_config() + ) analysis_config["predictor"] = predictor_config return analysis_config diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index 947578d433..3419cb1447 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -144,9 +144,11 @@ def expr(self): if self.s3_uri: return { "Std:JsonGet": { - "S3Uri": self.s3_uri.expr - if isinstance(self.s3_uri, PipelineVariable) - else self.s3_uri, + "S3Uri": ( + self.s3_uri.expr + if isinstance(self.s3_uri, PipelineVariable) + else self.s3_uri + ), "Path": self.json_path, } } diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 510ccd76bf..62167b96e7 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -395,9 +395,11 @@ def definition(self) -> str: "Version": self._version, "Metadata": self._metadata, "Parameters": list_to_request(self.parameters), - "PipelineExperimentConfig": self.pipeline_experiment_config.to_request() - if self.pipeline_experiment_config is not None - else None, + "PipelineExperimentConfig": ( + self.pipeline_experiment_config.to_request() + if self.pipeline_experiment_config is not None + else None + ), "Steps": list_to_request(compiled_steps), } diff --git a/tests/data/dummy_code_bundle_with_reqs/local_module.py b/tests/data/dummy_code_bundle_with_reqs/local_module.py index 542ab9f05d..03037550f9 100644 --- a/tests/data/dummy_code_bundle_with_reqs/local_module.py +++ b/tests/data/dummy_code_bundle_with_reqs/local_module.py @@ -1,2 +1,3 @@ """A dummy Python module to check importing local files works OK""" + DUMMY_CONSTANT = 1 diff --git a/tests/data/marketplace/iris/scoring_logic.py b/tests/data/marketplace/iris/scoring_logic.py index 03a89dc8f2..f9e2f1bb35 100644 --- a/tests/data/marketplace/iris/scoring_logic.py +++ b/tests/data/marketplace/iris/scoring_logic.py @@ -63,6 +63,7 @@ def predict(self, x, return_names=True): app = Flask(__name__) model = IrisModel(model_path="/opt/ml/model/model-artifacts.joblib") + # Create a path for health checks @app.route("/ping") def endpoint_ping(): diff --git 
a/tests/data/mxnet_mnist/mnist_hosting_with_custom_handlers.py b/tests/data/mxnet_mnist/mnist_hosting_with_custom_handlers.py index 37f6c4ccb9..5e017766e8 100644 --- a/tests/data/mxnet_mnist/mnist_hosting_with_custom_handlers.py +++ b/tests/data/mxnet_mnist/mnist_hosting_with_custom_handlers.py @@ -22,6 +22,7 @@ # --- this example demonstrates how to extend default behavior during model hosting --- + # --- Model preparation --- # it is possible to specify own code to load the model, otherwise a default model loading takes place def model_fn(path_to_model_files): diff --git a/tests/data/pipeline/test_source_dir/script_1.py b/tests/data/pipeline/test_source_dir/script_1.py index 4a427b1898..64143395c8 100644 --- a/tests/data/pipeline/test_source_dir/script_1.py +++ b/tests/data/pipeline/test_source_dir/script_1.py @@ -1,6 +1,7 @@ """ Integ test file script_1.py """ + import pathlib if __name__ == "__main__": diff --git a/tests/data/tensorflow_mnist/mnist_v2.py b/tests/data/tensorflow_mnist/mnist_v2.py index 6b506fb58d..05467dee49 100644 --- a/tests/data/tensorflow_mnist/mnist_v2.py +++ b/tests/data/tensorflow_mnist/mnist_v2.py @@ -26,6 +26,7 @@ checkpoint manager. """ + # define a model class LeNet(tf.keras.Model): def __init__(self): diff --git a/tests/integ/auto_ml_v2_utils.py b/tests/integ/auto_ml_v2_utils.py index a227235877..65f304f338 100644 --- a/tests/integ/auto_ml_v2_utils.py +++ b/tests/integ/auto_ml_v2_utils.py @@ -118,9 +118,12 @@ def create_auto_ml_job_v2_if_not_exist(sagemaker_session, auto_ml_job_name, prob inputs = [ AutoMLDataChannel( s3_data_type="S3Prefix", - s3_uri=s3_uri - if DATA_CONFIGS[problem_type]["path"] != os.path.join(DATA_DIR, "cifar10_subset") - else s3_uri + "/", + s3_uri=( + s3_uri + if DATA_CONFIGS[problem_type]["path"] + != os.path.join(DATA_DIR, "cifar10_subset") + else s3_uri + "/" + ), channel_type="training", content_type=DATA_CONFIGS[problem_type]["content_type"], ) diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index c905ff3a48..43db78527a 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -2221,9 +2221,9 @@ def test_ingest_in_memory_multi_process_with_collection_types( [3.0, 4.0], ["a", "b"], ] - pandas_data_frame_with_collection_type.loc[ - len(pandas_data_frame_with_collection_type) - ] = new_row_data + pandas_data_frame_with_collection_type.loc[len(pandas_data_frame_with_collection_type)] = ( + new_row_data + ) with pytest.raises(IngestionError): feature_group.ingest( data_frame=pandas_data_frame_with_collection_type, @@ -2284,9 +2284,9 @@ def test_ingest_in_memory_single_process_with_collection_types( [3.0, 4.0], ["a", "b"], ] - pandas_data_frame_with_collection_type.loc[ - len(pandas_data_frame_with_collection_type) - ] = new_row_data + pandas_data_frame_with_collection_type.loc[len(pandas_data_frame_with_collection_type)] = ( + new_row_data + ) with pytest.raises(IngestionError): feature_group.ingest( data_frame=pandas_data_frame_with_collection_type, @@ -2333,9 +2333,9 @@ def test_ingest_standard_multi_process_with_collection_types( [3.0, 4.0], ["a", "b"], ] - pandas_data_frame_with_collection_type.loc[ - len(pandas_data_frame_with_collection_type) - ] = new_row_data + pandas_data_frame_with_collection_type.loc[len(pandas_data_frame_with_collection_type)] = ( + new_row_data + ) ingestion_manager = feature_group.ingest( data_frame=pandas_data_frame_with_collection_type, diff --git a/tests/integ/test_serve.py b/tests/integ/test_serve.py index 951cfc68ca..08ad232b4a 100644 
--- a/tests/integ/test_serve.py +++ b/tests/integ/test_serve.py @@ -129,6 +129,7 @@ def xgb_happy_infer(fns, model_path, x_test): # TODO: introduce cleanup option in serve settings so that we can clean up after ourselves + # XGB Integ Tests @pytest.mark.skipif( SKIP_COND_MET, reason="The goal of these test are to test the serving components of our feature" diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 07418f8ddb..fdc29b4d90 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -530,9 +530,9 @@ def test_jumpstart_validate_all_hyperparameters( ) assert str(e.value) == "Cannot find hyperparameter for 'sagemaker_submit_directory'." - hyperparameter_to_test[ - "sagemaker_submit_directory" - ] = "/opt/ml/input/data/code/sourcedir.tar.gz" + hyperparameter_to_test["sagemaker_submit_directory"] = ( + "/opt/ml/input/data/code/sourcedir.tar.gz" + ) del hyperparameter_to_test["epochs"] with pytest.raises(JumpStartHyperparametersError) as e: diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 5a11a5e88d..ce5f15b287 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1090,7 +1090,6 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( mock_attach.assert_not_called() def test_jumpstart_estimator_kwargs_match_parent_class(self): - """If you add arguments to , this test will fail. Please add the new argument to the skip set below, and reach out to JumpStart team.""" diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 231f7f2ad7..53119e532a 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -1104,9 +1104,9 @@ def _test_model_bias_monitor_update_schedule(model_bias_monitor, sagemaker_sessi assert model_bias_monitor.max_runtime_in_seconds == MAX_RUNTIME_IN_SECONDS assert model_bias_monitor.env == ENVIRONMENT assert model_bias_monitor.network_config == NETWORK_CONFIG - expected_arguments[ - "RoleArn" - ] = NEW_ROLE_ARN # all but role arn are from existing job definition + expected_arguments["RoleArn"] = ( + NEW_ROLE_ARN # all but role arn are from existing job definition + ) sagemaker_session.sagemaker_client.create_model_bias_job_definition.assert_called_once_with( **expected_arguments ) @@ -1627,9 +1627,9 @@ def _test_model_explainability_monitor_update_schedule( assert model_explainability_monitor.max_runtime_in_seconds == MAX_RUNTIME_IN_SECONDS assert model_explainability_monitor.env == ENVIRONMENT assert model_explainability_monitor.network_config == NETWORK_CONFIG - expected_arguments[ - "RoleArn" - ] = NEW_ROLE_ARN # all but role arn are from existing job definition + expected_arguments["RoleArn"] = ( + NEW_ROLE_ARN # all but role arn are from existing job definition + ) sagemaker_session.sagemaker_client.create_model_explainability_job_definition.assert_called_once_with( **expected_arguments ) diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index 2035aa9d3a..d31b9f8527 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ 
b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -1164,9 +1164,9 @@ def _test_data_quality_monitor_update_schedule(data_quality_monitor, sagemaker_s assert data_quality_monitor.max_runtime_in_seconds == MAX_RUNTIME_IN_SECONDS assert data_quality_monitor.env == ENVIRONMENT assert data_quality_monitor.network_config == NETWORK_CONFIG - expected_arguments[ - "RoleArn" - ] = NEW_ROLE_ARN # all but role arn are from existing job definition + expected_arguments["RoleArn"] = ( + NEW_ROLE_ARN # all but role arn are from existing job definition + ) sagemaker_session.sagemaker_client.create_data_quality_job_definition.assert_called_once_with( **expected_arguments ) @@ -1786,9 +1786,9 @@ def _test_model_quality_monitor_update_schedule(model_quality_monitor, sagemaker assert model_quality_monitor.max_runtime_in_seconds == MAX_RUNTIME_IN_SECONDS assert model_quality_monitor.env == ENVIRONMENT assert model_quality_monitor.network_config == NETWORK_CONFIG - expected_arguments[ - "RoleArn" - ] = NEW_ROLE_ARN # all but role arn are from existing job definition + expected_arguments["RoleArn"] = ( + NEW_ROLE_ARN # all but role arn are from existing job definition + ) sagemaker_session.sagemaker_client.create_model_quality_job_definition.assert_called_once_with( **expected_arguments ) diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py index 98280af51b..a0742240ea 100644 --- a/tests/unit/sagemaker/remote_function/core/test_serialization.py +++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py @@ -439,8 +439,7 @@ def test_serialize_deserialize_service_error(): def test_serialize_deserialize_exception_with_traceback(): s3_uri = random_s3_uri() - class CustomError(Exception): - ... + class CustomError(Exception): ... # noqa E701 def func_a(): raise TypeError @@ -469,8 +468,7 @@ def func_b(): def test_serialize_deserialize_custom_exception_with_traceback(): s3_uri = random_s3_uri() - class CustomError(Exception): - ... + class CustomError(Exception): ... # noqa: E701 def func_a(): raise TypeError @@ -500,8 +498,7 @@ def func_b(): def test_serialize_deserialize_remote_function_error_with_traceback(): s3_uri = random_s3_uri() - class CustomError(Exception): - ... + class CustomError(Exception): ... 
# noqa: E701 def func_a(): raise TypeError diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index eeadd5a495..1d199b7401 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -143,8 +143,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_detect_fw_version, ): # setup mocks - mock_detect_container.side_effect = ( - lambda model, region, instance_type: mock_image_uri + mock_detect_container.side_effect = lambda model, region, instance_type: ( + mock_image_uri if model == mock_fw_model and region == mock_session.boto_region_name and instance_type == "ml.c5.xlarge" @@ -154,14 +154,16 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_detect_fw_version.return_value = framework, version mock_prepare_for_torchserve.side_effect = ( - lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key - if model_path == MODEL_PATH - and shared_libs == [] - and dependencies == {"auto": False} - and session == session - and image_uri == mock_image_uri - and inference_spec is None - else None + lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: ( + mock_secret_key + if model_path == MODEL_PATH + and shared_libs == [] + and dependencies == {"auto": False} + and session == session + and image_uri == mock_image_uri + and inference_spec is None + else None + ) ) # Mock _ServeSettings @@ -172,13 +174,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode - if inference_spec is None and model_server == ModelServer.TORCHSERVE - else None + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = ( - lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + ( model_data, ENV_VAR_PAIR, ) @@ -191,8 +191,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = ( - lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_model_obj if image_uri == mock_image_uri and image_config == MOCK_IMAGE_CONFIG and vpc_config == MOCK_VPC_CONFIG @@ -247,8 +247,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_detect_fw_version, ): # setup mocks - mock_detect_container.side_effect = ( - lambda model, region, instance_type: mock_1p_dlc_image_uri + mock_detect_container.side_effect = lambda model, region, instance_type: ( + mock_1p_dlc_image_uri if model == mock_fw_model and region == mock_session.boto_region_name and instance_type == "ml.c5.xlarge" @@ -257,14 +257,16 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_detect_fw_version.return_value = framework, version 
mock_prepare_for_torchserve.side_effect = ( - lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key - if model_path == MODEL_PATH - and shared_libs == [] - and dependencies == {"auto": False} - and session == session - and image_uri == mock_1p_dlc_image_uri - and inference_spec is None - else None + lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: ( + mock_secret_key + if model_path == MODEL_PATH + and shared_libs == [] + and dependencies == {"auto": False} + and session == session + and image_uri == mock_1p_dlc_image_uri + and inference_spec is None + else None + ) ) # Mock _ServeSettings @@ -275,13 +277,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode - if inference_spec is None and model_server == ModelServer.TORCHSERVE - else None + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = ( - lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + ( model_data, ENV_VAR_PAIR, ) @@ -294,8 +294,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = ( - lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_model_obj if image_uri == mock_1p_dlc_image_uri and model_data == model_data and role == mock_role_arn @@ -347,14 +347,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( ): # setup mocks mock_native_model = Mock() - mock_inference_spec.load = ( - lambda model_path: mock_native_model if model_path == MODEL_PATH else None + mock_inference_spec.load = lambda model_path: ( + mock_native_model if model_path == MODEL_PATH else None ) mock_detect_fw_version.return_value = framework, version - mock_detect_container.side_effect = ( - lambda model, region, instance_type: mock_image_uri + mock_detect_container.side_effect = lambda model, region, instance_type: ( + mock_image_uri if model == mock_native_model and region == mock_session.boto_region_name and instance_type == "ml.c5.xlarge" @@ -362,14 +362,16 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( ) mock_prepare_for_torchserve.side_effect = ( - lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key - if model_path == MODEL_PATH - and shared_libs == [] - and dependencies == {"auto": False} - and session == mock_session - and image_uri == mock_image_uri - and inference_spec == mock_inference_spec - else None + lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: ( + mock_secret_key + if model_path == MODEL_PATH + and shared_libs == [] + and dependencies == {"auto": False} + and session == mock_session + and image_uri == mock_image_uri + and inference_spec == mock_inference_spec + else None + 
) ) # Mock _ServeSettings @@ -380,13 +382,13 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec == mock_inference_spec and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = ( - lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + ( model_data, ENV_VAR_PAIR, ) @@ -399,8 +401,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = ( - lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_model_obj if image_uri == mock_image_uri and model_data == model_data and role == mock_role_arn @@ -447,8 +449,8 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_detect_fw_version, ): # setup mocks - mock_detect_container.side_effect = ( - lambda model, region, instance_type: mock_image_uri + mock_detect_container.side_effect = lambda model, region, instance_type: ( + mock_image_uri if model == mock_fw_model and region == mock_session.boto_region_name and instance_type == "ml.c5.xlarge" @@ -458,14 +460,16 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_detect_fw_version.return_value = framework, version mock_prepare_for_torchserve.side_effect = ( - lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key - if model_path == MODEL_PATH - and shared_libs == [] - and dependencies == {"auto": False} - and session == session - and image_uri == mock_image_uri - and inference_spec is None - else None + lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: ( + mock_secret_key + if model_path == MODEL_PATH + and shared_libs == [] + and dependencies == {"auto": False} + and session == session + and image_uri == mock_image_uri + and inference_spec is None + else None + ) ) # Mock _ServeSettings @@ -476,13 +480,11 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode - if inference_spec is None and model_server == ModelServer.TORCHSERVE - else None + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = ( - lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + ( model_data, ENV_VAR_PAIR, ) @@ -495,8 +497,8 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = ( 
- lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_model_obj if image_uri == mock_image_uri and model_data == model_data and role == mock_role_arn @@ -551,8 +553,8 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_save_xgb, ): # setup mocks - mock_detect_container.side_effect = ( - lambda model, region, instance_type: mock_image_uri + mock_detect_container.side_effect = lambda model, region, instance_type: ( + mock_image_uri if model == mock_fw_model and region == mock_session.boto_region_name and instance_type == "ml.c5.xlarge" @@ -562,14 +564,16 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_detect_fw_version.return_value = "xgboost", version mock_prepare_for_torchserve.side_effect = ( - lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key - if model_path == MODEL_PATH - and shared_libs == [] - and dependencies == {"auto": False} - and session == session - and image_uri == mock_image_uri - and inference_spec is None - else None + lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: ( + mock_secret_key + if model_path == MODEL_PATH + and shared_libs == [] + and dependencies == {"auto": False} + and session == session + and image_uri == mock_image_uri + and inference_spec is None + else None + ) ) # Mock _ServeSettings @@ -580,13 +584,11 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode - if inference_spec is None and model_server == ModelServer.TORCHSERVE - else None + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = ( - lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + ( model_data, ENV_VAR_PAIR, ) @@ -599,8 +601,8 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = ( - lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_model_obj if image_uri == mock_image_uri and model_data == model_data and role == mock_role_arn @@ -655,12 +657,12 @@ def test_build_happy_path_with_local_container_mode( ): # setup mocks mock_native_model = Mock() - mock_inference_spec.load = ( - lambda model_path: mock_native_model if model_path == MODEL_PATH else None + mock_inference_spec.load = lambda model_path: ( + mock_native_model if model_path == MODEL_PATH else None ) - mock_detect_container.side_effect = ( - lambda model, region, instance_type: mock_image_uri + mock_detect_container.side_effect = lambda model, region, instance_type: ( + mock_image_uri if model == mock_native_model and 
region == mock_session.boto_region_name and instance_type == "ml.c5.xlarge" @@ -668,14 +670,16 @@ def test_build_happy_path_with_local_container_mode( ) mock_prepare_for_torchserve.side_effect = ( - lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key - if model_path == MODEL_PATH - and shared_libs == [] - and dependencies == {"auto": False} - and session == mock_session - and image_uri == mock_image_uri - and inference_spec == mock_inference_spec - else None + lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: ( + mock_secret_key + if model_path == MODEL_PATH + and shared_libs == [] + and dependencies == {"auto": False} + and session == mock_session + and image_uri == mock_image_uri + and inference_spec == mock_inference_spec + else None + ) ) # Mock _ServeSettings @@ -687,21 +691,23 @@ def test_build_happy_path_with_local_container_mode( mock_mode = Mock() mock_localContainerMode.side_effect = ( - lambda inference_spec, schema_builder, session, model_path, env_vars, model_server: mock_mode - if inference_spec == mock_inference_spec - and schema_builder == schema_builder - and model_server == ModelServer.TORCHSERVE - and session == mock_session - and model_path == MODEL_PATH - and env_vars == {} - and model_server == ModelServer.TORCHSERVE - else None + lambda inference_spec, schema_builder, session, model_path, env_vars, model_server: ( + mock_mode + if inference_spec == mock_inference_spec + and schema_builder == schema_builder + and model_server == ModelServer.TORCHSERVE + and session == mock_session + and model_path == MODEL_PATH + and env_vars == {} + and model_server == ModelServer.TORCHSERVE + else None + ) ) mock_mode.prepare.side_effect = lambda: None mock_model_obj = Mock() - mock_sdk_model.side_effect = ( - lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_model_obj if image_uri == mock_image_uri and model_data is None and role == mock_role_arn @@ -751,14 +757,14 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo ): # setup mocks mock_native_model = Mock() - mock_inference_spec.load = ( - lambda model_path: mock_native_model if model_path == MODEL_PATH else None + mock_inference_spec.load = lambda model_path: ( + mock_native_model if model_path == MODEL_PATH else None ) mock_detect_fw_version.return_value = framework, version - mock_detect_container.side_effect = ( - lambda model, region, instance_type: mock_image_uri + mock_detect_container.side_effect = lambda model, region, instance_type: ( + mock_image_uri if model == mock_native_model and region == mock_session.boto_region_name and instance_type == "ml.c5.xlarge" @@ -766,14 +772,16 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo ) mock_prepare_for_torchserve.side_effect = ( - lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: mock_secret_key - if model_path == MODEL_PATH - and shared_libs == [] - and dependencies == {"auto": False} - and session == mock_session - and image_uri == mock_image_uri - and inference_spec == mock_inference_spec - else None + lambda model_path, shared_libs, dependencies, session, image_uri, inference_spec: ( + mock_secret_key + if model_path == MODEL_PATH + and shared_libs == [] + and dependencies 
== {"auto": False} + and session == mock_session + and image_uri == mock_image_uri + and inference_spec == mock_inference_spec + else None + ) ) # Mock _ServeSettings @@ -785,26 +793,28 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo mock_lc_mode = Mock() mock_localContainerMode.side_effect = ( - lambda inference_spec, schema_builder, session, model_path, env_vars, model_server: mock_lc_mode - if inference_spec == mock_inference_spec - and schema_builder == schema_builder - and model_server == ModelServer.TORCHSERVE - and session == mock_session - and model_path == MODEL_PATH - and env_vars == {} - and model_server == ModelServer.TORCHSERVE - else None + lambda inference_spec, schema_builder, session, model_path, env_vars, model_server: ( + mock_lc_mode + if inference_spec == mock_inference_spec + and schema_builder == schema_builder + and model_server == ModelServer.TORCHSERVE + and session == mock_session + and model_path == MODEL_PATH + and env_vars == {} + and model_server == ModelServer.TORCHSERVE + else None + ) ) mock_lc_mode.prepare.side_effect = lambda: None mock_sagemaker_endpoint_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_sagemaker_endpoint_mode + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_sagemaker_endpoint_mode if inference_spec == mock_inference_spec and model_server == ModelServer.TORCHSERVE else None ) - mock_sagemaker_endpoint_mode.prepare.side_effect = ( - lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( + mock_sagemaker_endpoint_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + ( model_data, ENV_VAR_PAIR, ) @@ -817,8 +827,8 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo ) mock_model_obj = Mock() - mock_sdk_model.side_effect = ( - lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_model_obj if image_uri == mock_image_uri and model_data is None and role == mock_role_arn @@ -848,8 +858,8 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo mock_predictor = Mock() builder._original_deploy = Mock() - builder._original_deploy.side_effect = ( - lambda *args, **kwargs: mock_predictor + builder._original_deploy.side_effect = lambda *args, **kwargs: ( + mock_predictor if kwargs.get("initial_instance_count") == 1 and kwargs.get("instance_type") == mock_instance_type else None @@ -893,8 +903,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co # setup mocks mock_detect_fw_version.return_value = framework, version - mock_detect_container.side_effect = ( - lambda model, region, instance_type: mock_image_uri + mock_detect_container.side_effect = lambda model, region, instance_type: ( + mock_image_uri if model == mock_fw_model and region == mock_session.boto_region_name and instance_type == "ml.c5.xlarge" @@ -902,14 +912,16 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co ) mock_prepare_for_torchserve.side_effect = ( - lambda model_path, shared_libs, dependencies, image_uri, session, inference_spec: mock_secret_key - if model_path == MODEL_PATH - and shared_libs == 
[] - and dependencies == {"auto": False} - and session == mock_session - and inference_spec is None - and image_uri == mock_image_uri - else None + lambda model_path, shared_libs, dependencies, image_uri, session, inference_spec: ( + mock_secret_key + if model_path == MODEL_PATH + and shared_libs == [] + and dependencies == {"auto": False} + and session == mock_session + and inference_spec is None + and image_uri == mock_image_uri + else None + ) ) # Mock _ServeSettings @@ -920,13 +932,11 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode - if inference_spec is None and model_server == ModelServer.TORCHSERVE - else None + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = ( - lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + ( model_data, ENV_VAR_PAIR, ) @@ -940,27 +950,31 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co mock_lc_mode = Mock() mock_localContainerMode.side_effect = ( - lambda inference_spec, schema_builder, session, model_path, env_vars, model_server: mock_lc_mode - if inference_spec is None - and schema_builder == schema_builder - and model_server == ModelServer.TORCHSERVE - and session == mock_session - and model_path == MODEL_PATH - and env_vars == ENV_VARS - else None + lambda inference_spec, schema_builder, session, model_path, env_vars, model_server: ( + mock_lc_mode + if inference_spec is None + and schema_builder == schema_builder + and model_server == ModelServer.TORCHSERVE + and session == mock_session + and model_path == MODEL_PATH + and env_vars == ENV_VARS + else None + ) ) mock_lc_mode.prepare.side_effect = lambda: None mock_lc_mode.create_server.side_effect = ( - lambda image_uri, container_timeout_seconds, secret_key, predictor: None - if image_uri == mock_image_uri - and secret_key == mock_secret_key - and container_timeout_seconds == 60 - else None + lambda image_uri, container_timeout_seconds, secret_key, predictor: ( + None + if image_uri == mock_image_uri + and secret_key == mock_secret_key + and container_timeout_seconds == 60 + else None + ) ) mock_model_obj = Mock() - mock_sdk_model.side_effect = ( - lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_model_obj if image_uri == mock_image_uri and model_data == model_data and role == mock_role_arn @@ -1708,10 +1722,8 @@ def test_build_mlflow_model_local_input_happy( mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode - if inference_spec is None and model_server == ModelServer.TORCHSERVE - else None + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == 
ModelServer.TORCHSERVE else None ) mock_mode.prepare.return_value = ( model_data, @@ -1785,10 +1797,8 @@ def test_build_mlflow_model_s3_input_happy( mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode - if inference_spec is None and model_server == ModelServer.TORCHSERVE - else None + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) mock_mode.prepare.return_value = ( model_data, @@ -1861,10 +1871,8 @@ def test_build_mlflow_model_s3_input_non_mlflow_case( mock_path_exists.side_effect = lambda path: True if path == "test_path" else False mock_mode = Mock() - mock_sageMakerEndpointMode.side_effect = ( - lambda inference_spec, model_server: mock_mode - if inference_spec is None and model_server == ModelServer.TORCHSERVE - else None + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) mock_mode.prepare.return_value = ( model_data, diff --git a/tests/unit/sagemaker/serve/model_server/triton/test_server.py b/tests/unit/sagemaker/serve/model_server/triton/test_server.py index 0e6ea624d7..c80c4296e7 100644 --- a/tests/unit/sagemaker/serve/model_server/triton/test_server.py +++ b/tests/unit/sagemaker/serve/model_server/triton/test_server.py @@ -40,8 +40,8 @@ class TritonServerTests(TestCase): @patch("sagemaker.serve.model_server.triton.server.importlib") def test_start_invoke_destroy_local_triton_server_gpu(self, mock_importlib): mock_triton_client = Mock() - mock_importlib.import_module.side_effect = ( - lambda module_name: mock_triton_client if module_name == "tritonclient.http" else None + mock_importlib.import_module.side_effect = lambda module_name: ( + mock_triton_client if module_name == "tritonclient.http" else None ) mock_container = Mock() @@ -99,8 +99,8 @@ def test_start_invoke_destroy_local_triton_server_gpu(self, mock_importlib): @patch("sagemaker.serve.model_server.triton.server.importlib") def test_start_invoke_destroy_local_triton_server_cpu(self, mock_importlib): mock_triton_client = Mock() - mock_importlib.import_module.side_effect = ( - lambda module_name: mock_triton_client if module_name == "tritonclient.http" else None + mock_importlib.import_module.side_effect = lambda module_name: ( + mock_triton_client if module_name == "tritonclient.http" else None ) mock_container = Mock() @@ -158,8 +158,8 @@ def test_start_invoke_destroy_local_triton_server_cpu(self, mock_importlib): def test_upload_artifacts_sagemaker_triton_server(self, mock_upload, mock_platform): mock_session = Mock() mock_platform.python_version.return_value = "3.8" - mock_upload.side_effect = ( - lambda session, repo, bucket, prefix: S3_URI + mock_upload.side_effect = lambda session, repo, bucket, prefix: ( + S3_URI if session == mock_session and repo == MODEL_PATH + "/model_repository" and bucket == "mock_model_data_uri" diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 475639bd1d..de91703d63 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -554,9 +554,9 @@ def test_fit_mwms( expected_train_args = _create_train_job(framework_version, py_version=py_version) 
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs - expected_train_args[ - "image_uri" - ] = f"763104351884.dkr.ecr.{REGION}.amazonaws.com/tensorflow-training:{framework_version}-cpu-{py_version}" + expected_train_args["image_uri"] = ( + f"763104351884.dkr.ecr.{REGION}.amazonaws.com/tensorflow-training:{framework_version}-cpu-{py_version}" + ) expected_train_args["job_name"] = f"tensorflow-training-{TIMESTAMP}" expected_train_args["hyperparameters"][TensorFlow.LAUNCH_MWMS_ENV_NAME] = json.dumps(True) expected_train_args["hyperparameters"]["sagemaker_job_name"] = json.dumps( diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index a165fa523e..ac42bb53ab 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -299,9 +299,9 @@ def test_default( ) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs expected_train_args["enable_sagemaker_metrics"] = False - expected_train_args["hyperparameters"][ - TrainingCompilerConfig.HP_ENABLE_COMPILER - ] = json.dumps(True) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = ( + json.dumps(True) + ) expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( False ) @@ -358,9 +358,9 @@ def test_byoc( ) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs expected_train_args["enable_sagemaker_metrics"] = False - expected_train_args["hyperparameters"][ - TrainingCompilerConfig.HP_ENABLE_COMPILER - ] = json.dumps(True) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = ( + json.dumps(True) + ) expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( False ) @@ -409,9 +409,9 @@ def test_debug_compiler_config( ) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs expected_train_args["enable_sagemaker_metrics"] = False - expected_train_args["hyperparameters"][ - TrainingCompilerConfig.HP_ENABLE_COMPILER - ] = json.dumps(True) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = ( + json.dumps(True) + ) expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( True ) @@ -460,9 +460,9 @@ def test_disable_compiler_config( ) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs expected_train_args["enable_sagemaker_metrics"] = False - expected_train_args["hyperparameters"][ - TrainingCompilerConfig.HP_ENABLE_COMPILER - ] = json.dumps(False) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = ( + json.dumps(False) + ) expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( False ) diff --git a/tests/unit/sagemaker/workflow/test_retry.py b/tests/unit/sagemaker/workflow/test_retry.py index cf58d615e4..9d871317a1 100644 --- a/tests/unit/sagemaker/workflow/test_retry.py +++ b/tests/unit/sagemaker/workflow/test_retry.py @@ -115,7 +115,7 @@ def test_invalid_retry_policy(): (5, 2.0, 10, 30), ] - for (interval_sec, backoff_rate, max_attempts, expire_after) in retry_policies: + for interval_sec, backoff_rate, max_attempts, expire_after in retry_policies: try: RetryPolicy( interval_seconds=interval_sec, diff --git 
a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 82ac2cd8bc..43d52e521d 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -846,9 +846,7 @@ def test_valid_model_config(content_type, accept_type): content_template = ( '{"instances":$features}' if content_type == "application/jsonlines" - else "$records" - if content_type == "application/json" - else None + else "$records" if content_type == "application/json" else None ) record_template = "$features_kvp" if content_type == "application/json" else None model_config = ModelConfig( @@ -1071,9 +1069,7 @@ def test_model_config_with_time_series(self, content_type, accept_type): content_template = ( '{"instances":$features}' if content_type == "application/jsonlines" - else "$records" - if content_type == "application/json" - else None + else "$records" if content_type == "application/json" else None ) record_template = "$features_kvp" if content_type == "application/json" else None # create mock config for TimeSeriesModelConfig diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 8790f158a8..9f0d68f01d 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -97,9 +97,9 @@ def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, ) sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error # bucket exists - sagemaker_session.boto_session.resource("s3").Bucket( - name=DEFAULT_BUCKET_NAME - ).creation_date = datetime_obj + sagemaker_session.boto_session.resource("s3").Bucket(name=DEFAULT_BUCKET_NAME).creation_date = ( + datetime_obj + ) # This should not raise ClientError as no head_bucket call is expected for custom bucket sagemaker_session.default_bucket() assert sagemaker_session._default_bucket == "custom-bucket-override" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0b9d62ca83..19f9d0ae3d 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2460,9 +2460,9 @@ def boto_session_complete(): boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS boto_mock.client("sagemaker").describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT - boto_mock.client( - "sagemaker" - ).describe_transform_job.return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT + boto_mock.client("sagemaker").describe_transform_job.return_value = ( + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT + ) return boto_mock @@ -2482,9 +2482,9 @@ def boto_session_stopped(): boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS boto_mock.client("sagemaker").describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT - boto_mock.client( - "sagemaker" - ).describe_transform_job.return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT + boto_mock.client("sagemaker").describe_transform_job.return_value = ( + STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT + ) return boto_mock diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index ff39535adf..f0325b79e9 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -830,7 +830,7 @@ def _assert_parameter_ranges(expected, actual, is_framework_estimator): continuous_ranges = [] integer_ranges = [] categorical_ranges = [] - for (name, param_range) in expected.items(): + for name, param_range in expected.items(): if 
isinstance(param_range, ContinuousParameter): continuous_ranges.append(param_range.as_tuning_range(name)) elif isinstance(param_range, IntegerParameter):
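
Note: the hunks above are mechanical restyling from a formatter upgrade rather than behavioural changes. The recurring rewrites visible throughout are: long subscript assignments keep the key on one line and parenthesize the right-hand side instead of splitting the key across lines; multi-line conditional expressions and lambda bodies get their own enclosing parentheses; redundant parentheses around tuple targets in `for` loops are dropped; a blank line is enforced after module docstrings; and `...`-only class bodies are collapsed onto the `class` line. The sketch below is an illustrative, hypothetical before/after in plain Python: every name in it is made up, and the lines are kept short for readability, so it shows the shapes of the rewrites rather than byte-for-byte formatter output.

```python
"""Illustrative sketch of the restyling patterns applied in the hunks above.

All names here (source, items, config, use_group_type) are hypothetical and
do not come from the SDK; the comments show the pre-upgrade shapes.
"""


def illustrate() -> dict:
    source = {"generated_name_enabled_for_this_particular_job": True}
    items = {"alpha": [1, 2], "beta": [3]}
    use_group_type = False

    config: dict = {}

    # Long subscript assignment: the key used to be split across lines, e.g.
    #     config[
    #         "SomeVeryLongConfigurationKeyName"
    #     ] = source[...]
    # The newer style keeps the subscript intact and parenthesizes the RHS.
    config["SomeVeryLongConfigurationKeyName"] = (
        source["generated_name_enabled_for_this_particular_job"]
    )

    # Multi-line conditional expressions used as a value now get wrapped in
    # their own parentheses instead of floating across lines unwrapped.
    config["Type"] = (
        "GroupResource" if use_group_type else "SingleResource"
    )

    # Tuple targets in for loops lose their redundant parentheses:
    # old: for (key, values) in items.items():
    for key, values in items.items():
        config.setdefault(key, []).extend(values)

    return config


if __name__ == "__main__":
    print(illustrate())
```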