diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index eb2b5c4087..875ef44d8f 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=64"] +requires = ["setuptools>=64,<77"] build-backend = "setuptools.build_meta" [project] @@ -40,10 +40,10 @@ dependencies = [ "tblib>=1.7.0", ] requires-python = ">=3.9" +license = {text = "Apache-2.0"} classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/sagemaker-core/src/sagemaker/core/remote_function/client.py b/sagemaker-core/src/sagemaker/core/remote_function/client.py index 3cfa5e3b23..85e2cda868 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/client.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/client.py @@ -369,7 +369,6 @@ def wrapper(*args, **kwargs): s3_uri=s3_path_join( job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER ), - ) except ServiceError as serr: chained_e = serr.__cause__ @@ -406,7 +405,6 @@ def wrapper(*args, **kwargs): return serialization.deserialize_obj_from_s3( sagemaker_session=job_settings.sagemaker_session, s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), - ) if job.describe()["TrainingJobStatus"] == "Stopped": @@ -1008,7 +1006,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_return = serialization.deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), - ) except DeserializationError as e: client_exception = e @@ -1020,7 +1017,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_exception = serialization.deserialize_exception_from_s3( sagemaker_session=sagemaker_session, 
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), - ) except ServiceError as serr: chained_e = serr.__cause__ @@ -1110,7 +1106,6 @@ def result(self, timeout: float = None) -> Any: self._return = serialization.deserialize_obj_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), - ) self._state = _FINISHED return self._return @@ -1119,7 +1114,6 @@ def result(self, timeout: float = None) -> Any: self._exception = serialization.deserialize_exception_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), - ) except ServiceError as serr: chained_e = serr.__cause__ diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py index 8871f6727f..6a4aecfab0 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py @@ -15,14 +15,17 @@ import dataclasses import json +import logging import io import sys import hashlib +import hmac import pickle +import secrets -from typing import Any, Callable, Union +from typing import Any, Callable, Union, Optional import cloudpickle from tblib import pickling_support @@ -38,6 +41,8 @@ # Note: do not use os.path.join for s3 uris, fails on windows +logger = logging.getLogger(__name__) + def _get_python_version(): """Returns the current python version.""" @@ -49,6 +54,7 @@ class _MetaData: """Metadata about the serialized data or functions.""" sha256_hash: str + secret_arn: Optional[str] = None # ARN to AWS Secrets Manager secret containing HMAC key version: str = "2023-04-24" python_version: str = _get_python_version() serialization_module: str = "cloudpickle" @@ -66,7 +72,8 @@ def from_json(s): raise DeserializationError("Corrupt metadata file. 
It is not a valid json file.") sha256_hash = obj.get("sha256_hash") - metadata = _MetaData(sha256_hash=sha256_hash) + secret_arn = obj.get("secret_arn") # May be None for legacy format + metadata = _MetaData(sha256_hash=sha256_hash, secret_arn=secret_arn) metadata.version = obj.get("version") metadata.python_version = obj.get("python_version") metadata.serialization_module = obj.get("serialization_module") @@ -155,16 +162,21 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: # TODO: use dask serializer in case dask distributed is installed in users' environment. def serialize_func_to_s3( - func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + func: Callable, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes function and uploads it to S3. Args: + func: function to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - func: function to be serialized and persisted Raises: SerializationError: when fail to serialize function to bytes. 
""" @@ -173,6 +185,7 @@ def serialize_func_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(func), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -199,23 +212,31 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_obj_to_s3( - obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + obj: Any, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes data object and uploads it to S3. Args: + obj: object to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. 
""" @@ -224,6 +245,7 @@ def serialize_obj_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(obj), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -288,23 +310,31 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_exception_to_s3( - exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + exc: Exception, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes exception with traceback and uploads it to S3. Args: + exc: Exception to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - exc: Exception to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. """ @@ -314,6 +344,7 @@ def serialize_exception_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(exc), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -322,6 +353,7 @@ def _upload_payload_and_metadata_to_s3( bytes_to_upload: Union[bytes, io.BytesIO], s3_uri: str, sagemaker_session: Session, + job_name: str, s3_kms_key, ): """Uploads serialized payload and metadata to s3. 
@@ -331,14 +363,22 @@ def _upload_payload_and_metadata_to_s3( s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. """ _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - sha256_hash = _compute_hash(bytes_to_upload) + # Get or create HMAC secret in Secrets Manager + secret_arn, hmac_key = _get_or_create_hmac_secret(sagemaker_session, job_name) + + # Compute HMAC-SHA256 hash + sha256_hash = _compute_hmac(bytes_to_upload, hmac_key) + + # Store secret ARN in Parameter Store as trust anchor (Mitigation #3) + _store_secret_arn_in_parameter_store(sagemaker_session, job_name, secret_arn) _upload_bytes_to_s3( - _MetaData(sha256_hash).to_json(), + _MetaData(sha256_hash=sha256_hash, secret_arn=secret_arn).to_json(), f"{s3_uri}/metadata.json", s3_kms_key, sagemaker_session, @@ -365,7 +405,10 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> An bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -396,15 +439,260 @@ def _compute_hash(buffer: bytes) -> str: return hashlib.sha256(buffer).hexdigest() -def _perform_integrity_check(expected_hash_value: str, buffer: bytes): +def _get_or_create_hmac_secret(sagemaker_session: Session, job_name: str) -> tuple[str, str]: + """Get or create HMAC key in AWS Secrets Manager. 
+ + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + + Returns: + Tuple of (secret_arn, hmac_key) + """ + secret_name = f"sagemaker/remote-function/{job_name}/hmac-key" + secrets_client = sagemaker_session.boto_session.client('secretsmanager') + + try: + # Try to retrieve existing secret + response = secrets_client.get_secret_value(SecretId=secret_name) + return response['ARN'], response['SecretString'] + except secrets_client.exceptions.ResourceNotFoundException: + # Create new secret + hmac_key = secrets.token_hex(32) + + response = secrets_client.create_secret( + Name=secret_name, + SecretString=hmac_key, + Description=f"HMAC key for SageMaker remote function job {job_name}", + Tags=[ + {'Key': 'SageMaker:JobName', 'Value': job_name}, + {'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'} + ] + ) + return response['ARN'], hmac_key + + +def _get_hmac_key_from_secret(sagemaker_session: Session, secret_arn: str) -> str: + """Retrieve HMAC key from AWS Secrets Manager. + + Args: + sagemaker_session: SageMaker session + secret_arn: ARN of the secret containing HMAC key + + Returns: + HMAC key string + """ + secrets_client = sagemaker_session.boto_session.client('secretsmanager') + response = secrets_client.get_secret_value(SecretId=secret_arn) + return response['SecretString'] + + +def _compute_hmac(buffer: bytes, hmac_key: str) -> str: + """Compute HMAC-SHA256 hash. + + Args: + buffer: Data to hash + hmac_key: HMAC secret key + + Returns: + HMAC-SHA256 hex digest + """ + return hmac.new(hmac_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() + + +def _store_secret_arn_in_parameter_store( + sagemaker_session: Session, + job_name: str, + secret_arn: str +): + """Store secret ARN in Parameter Store as trust anchor. 
+ + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + secret_arn: ARN of the secret to store + """ + ssm_client = sagemaker_session.boto_session.client('ssm') + parameter_name = f"/sagemaker/remote-function/{job_name}/secret-arn" + + try: + ssm_client.put_parameter( + Name=parameter_name, + Value=secret_arn, + Type="String", + Description=f"Secret ARN for SageMaker remote function job {job_name}", + Tags=[ + {'Key': 'SageMaker:JobName', 'Value': job_name}, + {'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'} + ] + ) + except ssm_client.exceptions.ParameterAlreadyExists: + ssm_client.put_parameter( + Name=parameter_name, + Value=secret_arn, + Type="String", + Overwrite=True, + ) + + +def _get_secret_arn_from_parameter_store( + sagemaker_session: Session, + job_name: str +) -> str: + """Retrieve secret ARN from Parameter Store. + + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + + Returns: + Secret ARN string + + Raises: + DeserializationError: If parameter not found + """ + ssm_client = sagemaker_session.boto_session.client('ssm') + parameter_name = f"/sagemaker/remote-function/{job_name}/secret-arn" + + try: + response = ssm_client.get_parameter(Name=parameter_name) + return response['Parameter']['Value'] + except ssm_client.exceptions.ParameterNotFound: + raise DeserializationError( + f"Secret ARN not found in Parameter Store for job {job_name}. " + "This may indicate the job was not properly initialized or artifacts were tampered with." + ) + + +def _extract_job_name_from_secret_arn(secret_arn: str) -> str: + """Extract job name from a Secrets Manager ARN. 
+ + Secret name convention: sagemaker/remote-function/{job_name}/hmac-key + ARN format: arn:aws:secretsmanager:region:account:secret:sagemaker/remote-function/{job_name}/hmac-key-XXXXXX + + Args: + secret_arn: Full ARN of the secret + + Returns: + Extracted job name + + Raises: + DeserializationError: If ARN doesn't match expected format + """ + # Length guard to prevent ReDoS on crafted inputs. + # Real ARNs are ~165 chars (job names are max 63 chars per SageMaker API). + MAX_SECRET_ARN_LENGTH = 256 + if len(secret_arn) > MAX_SECRET_ARN_LENGTH: + raise DeserializationError( + f"Secret ARN exceeds maximum length of {MAX_SECRET_ARN_LENGTH} characters" + ) + + import re + # Use [^/]+ (cannot match a slash) to prevent path-traversal in the job name, + # and anchor with $ to ensure the ARN ends with the expected suffix. + match = re.search( + r":secret:sagemaker/remote-function/([^/]+)/hmac-key-[A-Za-z0-9]{6}$", secret_arn + ) + if not match: + raise DeserializationError( + f"Secret ARN does not match expected format " + f"'sagemaker/remote-function/{{job_name}}/hmac-key-XXXXXX': {secret_arn}" + ) + return match.group(1) + + +def _validate_secret_arn( + sagemaker_session: Session, + metadata_secret_arn: str, +): + """Validate secret ARN from metadata against trusted sources. + + Implements two mitigations (numbering matches the in-code comments): + Mitigation #1: Validate secret is in same AWS account + Mitigation #3: Validate secret ARN matches Parameter Store (trust anchor) + + The job_name is derived from the secret ARN's naming convention, then + independently validated against the SSM trust anchor. 
+ + Args: + sagemaker_session: SageMaker session + metadata_secret_arn: Secret ARN from S3 metadata (untrusted) + + Raises: + DeserializationError: If validation fails + """ + # Mitigation #1: Validate same account + sts_client = sagemaker_session.boto_session.client('sts') + current_account_id = sts_client.get_caller_identity()['Account'] + + # Parse account ID from ARN: arn:aws:secretsmanager:region:ACCOUNT_ID:secret:name + arn_parts = metadata_secret_arn.split(":") + if len(arn_parts) < 5: + raise DeserializationError(f"Invalid secret ARN format: {metadata_secret_arn}") + + metadata_account_id = arn_parts[4] + + if metadata_account_id != current_account_id: + raise DeserializationError( + f"Secret must be in the same AWS account. " + f"Expected account {current_account_id}, but got {metadata_account_id}. " + "This may indicate a cross-account attack attempt." + ) + + # Mitigation #3: Validate against Parameter Store (trust anchor) + job_name = _extract_job_name_from_secret_arn(metadata_secret_arn) + expected_secret_arn = _get_secret_arn_from_parameter_store(sagemaker_session, job_name) + + if metadata_secret_arn != expected_secret_arn: + raise DeserializationError( + f"Secret ARN mismatch. Expected: {expected_secret_arn}, " + f"Got: {metadata_secret_arn}. " + "Possible tampering detected - metadata may have been modified." + ) + + +def _perform_integrity_check( + expected_hash_value: str, + buffer: bytes, + sagemaker_session: Optional[Session] = None, + secret_arn: Optional[str] = None, +): """Performs integrity checks for serialized code/arguments uploaded to s3. Verifies whether the hash read from s3 matches the hash calculated during remote function execution. 
+ + Args: + expected_hash_value: Expected hash value from metadata + buffer: Serialized data buffer + sagemaker_session: SageMaker session (required for HMAC integrity check) + secret_arn: ARN of secret containing HMAC key (required) + + Raises: + DeserializationError: If integrity check fails or secret_arn is missing """ - actual_hash_value = _compute_hash(buffer=buffer) - if expected_hash_value != actual_hash_value: + if not secret_arn: + raise DeserializationError( + "Missing secret_arn in metadata. HMAC integrity check is required. " + "Legacy SHA-256 integrity check is no longer supported due to security " + "vulnerabilities. Please upgrade to the latest SDK version on both " + "client and remote sides." + ) + + if not sagemaker_session: + raise DeserializationError( + "sagemaker_session is required for HMAC integrity check" + ) + + # Validate secret ARN (Mitigations #1 and #3) + _validate_secret_arn(sagemaker_session, secret_arn) + + # Now safe to retrieve HMAC key + hmac_key = _get_hmac_key_from_secret(sagemaker_session, secret_arn) + actual_hash_value = _compute_hmac(buffer, hmac_key) + + if not hmac.compare_digest(expected_hash_value, actual_hash_value): raise DeserializationError( - "Integrity check for the serialized function or data failed. " - "Please restrict access to your S3 bucket" + "HMAC integrity check failed. Serialized data may have been tampered with. " + "Please restrict access to your S3 bucket." ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py index c7ee86f8a7..d09c3737f5 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py @@ -57,6 +57,7 @@ def __init__( s3_base_uri: str, s3_kms_key: str = None, context: Context = Context(), + job_name: str = None, ): """Construct a StoredFunction object. 
@@ -66,11 +67,13 @@ def __init__( s3_base_uri: the base uri to which serialized artifacts will be uploaded. s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. context: Build or run context of a pipeline step. + job_name: Remote function job name for secret management. """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key self.context = context + self.job_name = job_name or os.environ.get("TRAINING_JOB_NAME") # For pipeline steps, function code is at: base/step_name/build_timestamp/ # For results, path is: base/step_name/build_timestamp/execution_id/ @@ -110,6 +113,7 @@ def save(self, func, *args, **kwargs): func=func, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -123,7 +127,7 @@ def save(self, func, *args, **kwargs): obj=(args, kwargs), sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -141,7 +145,7 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.func, - + job_name=self.job_name, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), sagemaker_session=self.sagemaker_session, s3_kms_key=self.s3_kms_key, @@ -153,7 +157,7 @@ def save_pipeline_step_function(self, serialized_data): ) serialization._upload_payload_and_metadata_to_s3( bytes_to_upload=serialized_data.args, - + job_name=self.job_name, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), sagemaker_session=self.sagemaker_session, s3_kms_key=self.s3_kms_key, @@ -203,7 +207,7 @@ def load_and_invoke(self) -> Any: obj=result, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER), - + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) diff --git 
a/sagemaker-core/src/sagemaker/core/remote_function/errors.py b/sagemaker-core/src/sagemaker/core/remote_function/errors.py index 3f391570cf..6315c1c527 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/errors.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/errors.py @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg): f.write(failure_msg) -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, job_name=None) -> int: """Handle all exceptions raised during remote function execution. Args: @@ -79,6 +79,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: AWS service calls are delegated to. s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + job_name (str): Remote function job name for secret management. Returns : exit_code (int): Exit code to terminate current job. 
""" @@ -96,6 +97,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: exc=error, sagemaker_session=sagemaker_session, s3_uri=s3_path_join(s3_base_uri, "exception"), + job_name=job_name or os.environ.get("TRAINING_JOB_NAME"), s3_kms_key=s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py index 2e69f4f116..c43978f687 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py @@ -108,6 +108,7 @@ def _execute_remote_function( s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, context=context, + job_name=os.environ.get("TRAINING_JOB_NAME"), ) if run_in_context: diff --git a/sagemaker-core/src/sagemaker/core/remote_function/job.py b/sagemaker-core/src/sagemaker/core/remote_function/job.py index 6e727d4b9c..b6ac5572b7 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/job.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/job.py @@ -931,6 +931,7 @@ def compile( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=job_settings.s3_kms_key, + job_name=job_name, ) stored_function.save(func, *func_args, **func_kwargs) else: @@ -942,6 +943,7 @@ def compile( step_name=step_compilation_context.step_name, func_step_s3_dir=step_compilation_context.pipeline_build_time, ), + job_name=job_name, ) stored_function.save_pipeline_step_function(serialized_data) diff --git a/sagemaker-core/tests/integ/remote_function/conftest.py b/sagemaker-core/tests/integ/remote_function/conftest.py new file mode 100644 index 0000000000..8b00caa794 --- /dev/null +++ b/sagemaker-core/tests/integ/remote_function/conftest.py @@ -0,0 +1,63 @@ +"""Shared fixtures for remote function integration tests.""" + +import os +import shutil +import tempfile + +import cloudpickle +import pytest + +from 
sagemaker.core.helper.session_helper import Session +from sagemaker.core.s3 import S3Uploader + + +def _get_repo_root(): + return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + + +def _upload_core_source(sagemaker_session): + """Tar the sagemaker-core source and upload to S3. Returns (s3_prefix, tar_basename).""" + repo_root = _get_repo_root() + core_dir = os.path.join(repo_root, "sagemaker-core") + dist_dir = tempfile.mkdtemp(prefix="sagemaker_core_src_") + + archive_path = shutil.make_archive( + os.path.join(dist_dir, "sagemaker-core-src"), "gztar", root_dir=core_dir, base_dir="." + ) + + s3_prefix = f"s3://{sagemaker_session.default_bucket()}/remote-function-test/src" + S3Uploader.upload(archive_path, s3_prefix, sagemaker_session=sagemaker_session) + + return s3_prefix, os.path.basename(archive_path) + + +@pytest.fixture(scope="module") +def sagemaker_session(): + import boto3 + return Session(boto3.Session()) + + +@pytest.fixture(scope="module") +def role(sagemaker_session): + import boto3 + account_id = boto3.client("sts").get_caller_identity()["Account"] + return f"arn:aws:iam::{account_id}:role/Admin" + + +@pytest.fixture(scope="module") +def image_uri(sagemaker_session): + region = sagemaker_session.boto_region_name + return f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:2.0.0-cpu-py310" + + +@pytest.fixture(scope="module") +def dev_sdk_pre_execution_commands(sagemaker_session): + """Upload dev sagemaker-core source to S3 and return pre_execution_commands.""" + s3_prefix, tar_name = _upload_core_source(sagemaker_session) + cp_version = cloudpickle.__version__ + return [ + f"pip install cloudpickle=={cp_version}", + f"aws s3 cp {s3_prefix}/{tar_name} /tmp/{tar_name}", + "mkdir -p /tmp/sagemaker-core-src && tar xzf /tmp/{tar_name} -C /tmp/sagemaker-core-src".format(tar_name=tar_name), + "pip install --no-deps /tmp/sagemaker-core-src", + ] diff --git 
a/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py b/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py index b3d38c32a4..61d26f78e8 100644 --- a/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py +++ b/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py @@ -9,12 +9,6 @@ import tempfile import pytest -# Skip decorator for AWS configuration -# skip_if_no_aws_region = pytest.mark.skipif( -# not os.environ.get('AWS_DEFAULT_REGION'), -# reason="AWS credentials not configured" -# ) - # Add src to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) @@ -25,58 +19,47 @@ class TestRemoteFunctionDependencyInjection: """Integration tests for dependency injection in remote functions.""" @pytest.mark.integ - # @skip_if_no_aws_region - def test_remote_function_without_dependencies(self): - """Test remote function execution without explicit dependencies. - - This test verifies that when no dependencies are provided, the remote - function still executes successfully because sagemaker>=3.2.0 is - automatically injected. 
- """ + def test_remote_function_without_dependencies( + self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session + ): + """Test remote function execution without explicit dependencies.""" @remote( instance_type="ml.m5.large", - # No dependencies specified - sagemaker should be injected automatically + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, + pre_execution_commands=dev_sdk_pre_execution_commands, ) def simple_add(x, y): - """Simple function that adds two numbers.""" return x + y - - # Execute the function + result = simple_add(5, 3) - - # Verify result assert result == 8, f"Expected 8, got {result}" - print("✓ Remote function without dependencies executed successfully") @pytest.mark.integ - # @skip_if_no_aws_region - def test_remote_function_with_user_dependencies_no_sagemaker(self): - """Test remote function with user dependencies but no sagemaker. - - This test verifies that when user provides dependencies without sagemaker, - sagemaker>=3.2.0 is automatically appended. 
- """ - # Create a temporary requirements.txt without sagemaker + def test_remote_function_with_user_dependencies_no_sagemaker( + self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session + ): + """Test remote function with user dependencies but no sagemaker.""" with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: f.write("numpy>=1.20.0\npandas>=1.3.0\n") req_file = f.name - + try: @remote( instance_type="ml.m5.large", + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, dependencies=req_file, + pre_execution_commands=dev_sdk_pre_execution_commands, ) def compute_with_numpy(x): - """Function that uses numpy.""" import numpy as np return np.array([x, x*2, x*3]).sum() - - # Execute the function + result = compute_with_numpy(5) - - # Verify result (5 + 10 + 15 = 30) assert result == 30, f"Expected 30, got {result}" - print("✓ Remote function with user dependencies executed successfully") finally: os.remove(req_file) @@ -85,52 +68,55 @@ class TestRemoteFunctionVersionCompatibility: """Tests for version compatibility between local and remote environments.""" @pytest.mark.integ - # @skip_if_no_aws_region - def test_deserialization_with_injected_sagemaker(self): - """Test that deserialization works with injected sagemaker dependency. - - This test verifies that the remote environment can properly deserialize - functions when sagemaker>=3.2.0 is available. 
- """ + def test_deserialization_with_injected_sagemaker( + self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session + ): + """Test that deserialization works with injected sagemaker dependency.""" @remote( instance_type="ml.m5.large", + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, + pre_execution_commands=dev_sdk_pre_execution_commands, ) def complex_computation(data): - """Function that performs complex computation.""" result = sum(data) * len(data) return result - - # Execute with various data types + test_data = [1, 2, 3, 4, 5] result = complex_computation(test_data) - - # Verify result (sum=15, len=5, 15*5=75) assert result == 75, f"Expected 75, got {result}" - print("✓ Deserialization with injected sagemaker works correctly") @pytest.mark.integ - # @skip_if_no_aws_region - def test_multiple_remote_functions_with_dependencies(self): - """Test multiple remote functions with different dependency configurations. - - This test verifies that the dependency injection works correctly - when multiple remote functions are defined and executed. 
- """ - @remote(instance_type="ml.m5.large") + def test_multiple_remote_functions_with_dependencies( + self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session + ): + """Test multiple remote functions with different dependency configurations.""" + @remote( + instance_type="ml.m5.large", + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, + pre_execution_commands=dev_sdk_pre_execution_commands, + ) def func1(x): return x + 1 - - @remote(instance_type="ml.m5.large") + + @remote( + instance_type="ml.m5.large", + role=role, + image_uri=image_uri, + sagemaker_session=sagemaker_session, + pre_execution_commands=dev_sdk_pre_execution_commands, + ) def func2(x): return x * 2 - - # Execute both functions + result1 = func1(5) result2 = func2(5) - + assert result1 == 6, f"func1: Expected 6, got {result1}" assert result2 == 10, f"func2: Expected 10, got {result2}" - print("✓ Multiple remote functions with dependencies executed successfully") if __name__ == "__main__": diff --git a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py index 4810eba2e0..7bd24489e7 100644 --- a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py +++ b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py @@ -188,6 +188,7 @@ def test_executes_without_run_context(self, mock_stored_function_class): s3_base_uri="s3://bucket/path", s3_kms_key="key-123", context=mock_context, + job_name=None, ) mock_stored_func.load_and_invoke.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/test_serialization_security.py b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py new file mode 100644 index 0000000000..d01d64cea1 --- /dev/null +++ b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py @@ -0,0 +1,391 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for serialization security (HMAC + Secrets Manager + Parameter Store).""" +from __future__ import absolute_import + +import hashlib +import hmac as hmac_module +import json +from unittest.mock import Mock, patch, MagicMock + +import pytest + +from sagemaker.core.remote_function.core.serialization import ( + _MetaData, + _compute_hash, + _compute_hmac, + _extract_job_name_from_secret_arn, + _get_or_create_hmac_secret, + _get_hmac_key_from_secret, + _store_secret_arn_in_parameter_store, + _get_secret_arn_from_parameter_store, + _validate_secret_arn, + _perform_integrity_check, + _upload_payload_and_metadata_to_s3, + serialize_obj_to_s3, + deserialize_obj_from_s3, + serialize_func_to_s3, + serialize_exception_to_s3, + deserialize_func_from_s3, + deserialize_exception_from_s3, +) +from sagemaker.core.remote_function.errors import DeserializationError + + +MOCK_JOB_NAME = "test-remote-function-job" +MOCK_SECRET_ARN = "arn:aws:secretsmanager:us-west-2:123456789012:secret:sagemaker/remote-function/test-remote-function-job/hmac-key-AbCdEf" +MOCK_HMAC_KEY = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" +MOCK_ACCOUNT_ID = "123456789012" +MOCK_S3_URI = "s3://my-bucket/remote-function/test-remote-function-job/results" + + +def _mock_sagemaker_session(account_id=MOCK_ACCOUNT_ID): + """Create a mock SageMaker session with Secrets Manager, SSM, and STS clients.""" + session = Mock() + + # Mock Secrets Manager client + secrets_client = Mock() + 
secrets_client.get_secret_value.return_value = { + "ARN": MOCK_SECRET_ARN, + "SecretString": MOCK_HMAC_KEY, + } + secrets_client.create_secret.return_value = { + "ARN": MOCK_SECRET_ARN, + } + secrets_client.exceptions = Mock() + secrets_client.exceptions.ResourceNotFoundException = type( + "ResourceNotFoundException", (Exception,), {} + ) + + # Mock SSM client + ssm_client = Mock() + ssm_client.get_parameter.return_value = { + "Parameter": {"Value": MOCK_SECRET_ARN} + } + ssm_client.exceptions = Mock() + ssm_client.exceptions.ParameterNotFound = type( + "ParameterNotFound", (Exception,), {} + ) + + # Mock STS client + sts_client = Mock() + sts_client.get_caller_identity.return_value = {"Account": account_id} + + def client_factory(service_name): + if service_name == "secretsmanager": + return secrets_client + elif service_name == "ssm": + return ssm_client + elif service_name == "sts": + return sts_client + return Mock() + + session.boto_session.client = client_factory + return session, secrets_client, ssm_client, sts_client + + +class TestMetaData: + """Tests for _MetaData class.""" + + def test_metadata_with_secret_arn(self): + metadata = _MetaData(sha256_hash="abc123", secret_arn=MOCK_SECRET_ARN) + json_bytes = metadata.to_json() + parsed = _MetaData.from_json(json_bytes) + + assert parsed.sha256_hash == "abc123" + assert parsed.secret_arn == MOCK_SECRET_ARN + + def test_metadata_without_secret_arn_legacy(self): + metadata = _MetaData(sha256_hash="abc123") + json_bytes = metadata.to_json() + parsed = _MetaData.from_json(json_bytes) + + assert parsed.sha256_hash == "abc123" + assert parsed.secret_arn is None + + def test_metadata_missing_hash_raises(self): + with pytest.raises(DeserializationError, match="SHA256 hash"): + _MetaData.from_json(json.dumps({"version": "2023-04-24", "serialization_module": "cloudpickle"})) + + def test_metadata_invalid_json_raises(self): + with pytest.raises(DeserializationError, match="not a valid json"): + _MetaData.from_json(b"not 
json") + + +class TestComputeHmac: + """Tests for HMAC computation.""" + + def test_compute_hmac(self): + data = b"test data" + key = "test-key" + result = _compute_hmac(data, key) + expected = hmac_module.new(key.encode(), msg=data, digestmod=hashlib.sha256).hexdigest() + assert result == expected + + def test_compute_hmac_different_keys_produce_different_hashes(self): + data = b"test data" + hash1 = _compute_hmac(data, "key1") + hash2 = _compute_hmac(data, "key2") + assert hash1 != hash2 + + def test_compute_hash_plain_sha256(self): + data = b"test data" + result = _compute_hash(data) + expected = hashlib.sha256(data).hexdigest() + assert result == expected + + +class TestGetOrCreateHmacSecret: + """Tests for Secrets Manager integration.""" + + def test_get_existing_secret(self): + session, secrets_client, _, _ = _mock_sagemaker_session() + + arn, key = _get_or_create_hmac_secret(session, MOCK_JOB_NAME) + + assert arn == MOCK_SECRET_ARN + assert key == MOCK_HMAC_KEY + secrets_client.get_secret_value.assert_called_once_with( + SecretId=f"sagemaker/remote-function/{MOCK_JOB_NAME}/hmac-key" + ) + + def test_create_new_secret_when_not_found(self): + session, secrets_client, _, _ = _mock_sagemaker_session() + + # Simulate ResourceNotFoundException + secrets_client.get_secret_value.side_effect = ( + secrets_client.exceptions.ResourceNotFoundException("not found") + ) + + arn, key = _get_or_create_hmac_secret(session, MOCK_JOB_NAME) + + assert arn == MOCK_SECRET_ARN + assert len(key) == 64 # secrets.token_hex(32) produces 64 chars + secrets_client.create_secret.assert_called_once() + + +class TestParameterStore: + """Tests for Parameter Store trust anchor.""" + + def test_store_secret_arn(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + + _store_secret_arn_in_parameter_store(session, MOCK_JOB_NAME, MOCK_SECRET_ARN) + + ssm_client.put_parameter.assert_called_once() + call_kwargs = ssm_client.put_parameter.call_args[1] + assert call_kwargs["Name"] == 
f"/sagemaker/remote-function/{MOCK_JOB_NAME}/secret-arn" + assert call_kwargs["Value"] == MOCK_SECRET_ARN + assert "Tags" in call_kwargs + + def test_get_secret_arn(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + + result = _get_secret_arn_from_parameter_store(session, MOCK_JOB_NAME) + + assert result == MOCK_SECRET_ARN + ssm_client.get_parameter.assert_called_once_with( + Name=f"/sagemaker/remote-function/{MOCK_JOB_NAME}/secret-arn" + ) + + def test_get_secret_arn_not_found_raises(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + ssm_client.get_parameter.side_effect = ( + ssm_client.exceptions.ParameterNotFound("not found") + ) + + with pytest.raises(DeserializationError, match="Secret ARN not found"): + _get_secret_arn_from_parameter_store(session, MOCK_JOB_NAME) + + +class TestValidateSecretArn: + """Tests for secret ARN validation (Mitigations #1 and #3).""" + + def test_valid_secret_arn_passes(self): + """Valid ARN in same account matching Parameter Store should pass.""" + session, _, _, _ = _mock_sagemaker_session() + + # Should not raise + _validate_secret_arn(session, MOCK_SECRET_ARN) + + def test_cross_account_arn_rejected(self): + """Mitigation #1: Secret ARN from different account should be rejected.""" + session, _, _, _ = _mock_sagemaker_session(account_id=MOCK_ACCOUNT_ID) + + attacker_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:evil-secret" + + with pytest.raises(DeserializationError, match="same AWS account"): + _validate_secret_arn(session, attacker_arn) + + def test_tampered_arn_rejected(self): + """Mitigation #3: ARN not matching Parameter Store should be rejected.""" + session, _, ssm_client, _ = _mock_sagemaker_session() + + # Parameter Store returns the legitimate ARN + ssm_client.get_parameter.return_value = { + "Parameter": {"Value": MOCK_SECRET_ARN} + } + + # Attacker's ARN (same account but different secret) + tampered_arn = 
f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:attacker-created-secret" + + with pytest.raises(DeserializationError, match="does not match expected format"): + _validate_secret_arn(session, tampered_arn) + + def test_invalid_arn_format_rejected(self): + """Malformed ARN should be rejected.""" + session, _, _, _ = _mock_sagemaker_session() + + with pytest.raises(DeserializationError, match="Invalid secret ARN format"): + _validate_secret_arn(session, "not-an-arn") + + +class TestExtractJobNameFromSecretArn: + """Tests for _extract_job_name_from_secret_arn regex hardening.""" + + def test_valid_arn(self): + result = _extract_job_name_from_secret_arn(MOCK_SECRET_ARN) + assert result == MOCK_JOB_NAME + + def test_rejects_greedy_path_traversal(self): + """Greedy .+ allowed evil/hmac-key/../ in job name — now rejected.""" + malicious_arn = ( + "arn:aws:secretsmanager:us-east-1:123456789012:secret:" + "sagemaker/remote-function/evil/hmac-key/../sagemaker/" + "remote-function/legit-job/hmac-key-AbCdEf" + ) + with pytest.raises(DeserializationError, match="does not match expected format"): + _extract_job_name_from_secret_arn(malicious_arn) + + def test_rejects_arn_exceeding_max_length(self): + """Long input caused ReDoS — now rejected by length check.""" + long_arn = ( + "arn:aws:secretsmanager:us-east-1:123456789012:" + + ":secret:sagemaker/remote-function/y" * 10100 + + "\n:secret:sagemaker/remote-function/c/hmac-key-AbCdEf" + ) + with pytest.raises(DeserializationError, match="exceeds maximum length"): + _extract_job_name_from_secret_arn(long_arn) + + def test_rejects_arn_without_6char_suffix(self): + """ARN must end with hmac-key-XXXXXX (6 alphanumeric chars).""" + bad_arn = "arn:aws:secretsmanager:us-west-2:123456789012:secret:sagemaker/remote-function/job/hmac-key" + with pytest.raises(DeserializationError, match="does not match expected format"): + _extract_job_name_from_secret_arn(bad_arn) + + def test_rejects_arn_with_trailing_content(self): + """$ 
anchor prevents matching when extra content follows."""
+        bad_arn = MOCK_SECRET_ARN + "/extra"
+        with pytest.raises(DeserializationError, match="does not match expected format"):
+            _extract_job_name_from_secret_arn(bad_arn)
+class TestPerformIntegrityCheck:
+    """Tests for integrity check with HMAC."""
+    def test_hmac_integrity_check_passes(self):
+        """Valid HMAC should pass integrity check."""
+        session, _, _, _ = _mock_sagemaker_session()
+
+        payload = b"test payload"
+        expected_hmac = _compute_hmac(payload, MOCK_HMAC_KEY)
+
+        # Should not raise
+        _perform_integrity_check(
+            expected_hash_value=expected_hmac,
+            buffer=payload,
+            sagemaker_session=session,
+            secret_arn=MOCK_SECRET_ARN,
+        )
+
+    def test_hmac_integrity_check_fails_on_tampered_payload(self):
+        """Tampered payload should fail HMAC check."""
+        session, _, _, _ = _mock_sagemaker_session()
+
+        original_payload = b"original payload"
+        tampered_payload = b"tampered payload"
+        expected_hmac = _compute_hmac(original_payload, MOCK_HMAC_KEY)
+
+        with pytest.raises(DeserializationError, match="HMAC integrity check failed"):
+            _perform_integrity_check(
+                expected_hash_value=expected_hmac,
+                buffer=tampered_payload,
+                sagemaker_session=session,
+                secret_arn=MOCK_SECRET_ARN,
+            )
+
+    def test_legacy_sha256_check_rejected(self):
+        """Legacy SHA-256 check without secret_arn is no longer supported."""
+        payload = b"test payload"
+        expected_hash = _compute_hash(payload)
+
+        with pytest.raises(DeserializationError, match="HMAC integrity check is required"):
+            _perform_integrity_check(
+                expected_hash_value=expected_hash,
+                buffer=payload,
+            )
+
+    def test_legacy_sha256_tampered_payload_also_rejected(self):
+        """Legacy path is rejected regardless of hash correctness."""
+        payload = b"test payload"
+        expected_hash = _compute_hash(payload)
+
+        with pytest.raises(DeserializationError, match="HMAC integrity check is required"):
+            _perform_integrity_check(
+                expected_hash_value=expected_hash,
+                buffer=b"tampered",
+            )
+
+    def 
test_hmac_check_requires_session(self): + """HMAC check should require sagemaker_session.""" + with pytest.raises(DeserializationError, match="sagemaker_session is required"): + _perform_integrity_check( + expected_hash_value="hash", + buffer=b"data", + secret_arn=MOCK_SECRET_ARN, + ) + +class TestAttackScenarios: + """Tests simulating actual attack scenarios.""" + + def test_attacker_replaces_payload_and_metadata_plain_hash(self): + """Attacker replaces both files with plain SHA-256 (no secret_arn) - should be rejected.""" + malicious_payload = b"malicious code" + plain_hash = hashlib.sha256(malicious_payload).hexdigest() + + with pytest.raises(DeserializationError, match="HMAC integrity check is required"): + _perform_integrity_check( + expected_hash_value=plain_hash, + buffer=malicious_payload, + ) + + def test_attacker_points_to_cross_account_secret(self): + """Attacker points to their own secret in different account - should be rejected.""" + session, _, _, _ = _mock_sagemaker_session() + + attacker_secret_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:attacker-secret" + + with pytest.raises(DeserializationError, match="same AWS account"): + _validate_secret_arn(session, attacker_secret_arn) + + def test_attacker_creates_secret_in_same_account(self): + """Attacker creates secret in same account but ARN doesn't match Parameter Store.""" + session, _, ssm_client, _ = _mock_sagemaker_session() + + # Parameter Store has the legitimate ARN + ssm_client.get_parameter.return_value = { + "Parameter": {"Value": MOCK_SECRET_ARN} + } + + # Attacker's secret in same account (with valid suffix format) + attacker_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:sagemaker/remote-function/evil-job/hmac-key-XyZ123" + + with pytest.raises(DeserializationError, match="Secret ARN mismatch"): + _validate_secret_arn(session, attacker_arn)