
AWS SageMaker

class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerScheduler(session_name: str, client: Any | None = None, docker_client: DockerClient | None = None)

Bases: DockerWorkspaceMixin, Scheduler[Opts]

AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.

$ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello
aws_sagemaker://torchx_user/1234
$ torchx status aws_sagemaker://torchx_user/1234
...
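The app handle returned by torchx run has the form {scheduler}://{session_name}/{app_id}. The sketch below splits such a handle into its three parts. parse_handle and ParsedHandle are hypothetical names used only for illustration (TorchX does its own handle parsing internally), and the regular expression is an assumption about the layout, not TorchX's exact grammar.

```python
import re
from typing import NamedTuple


class ParsedHandle(NamedTuple):
    scheduler: str
    session_name: str
    app_id: str


# Assumed layout: {scheduler}://{session_name}/{app_id}
_HANDLE_RE = re.compile(
    r"^(?P<scheduler>[^:/]+)://(?P<session_name>[^/]*)/(?P<app_id>.+)$"
)


def parse_handle(handle: str) -> ParsedHandle:
    """Split a TorchX-style app handle into scheduler, session name, and app id."""
    match = _HANDLE_RE.match(handle)
    if match is None:
        raise ValueError(f"not a valid app handle: {handle!r}")
    return ParsedHandle(**match.groupdict())


handle = parse_handle("aws_sagemaker://torchx_user/1234")
print(handle.scheduler, handle.session_name, handle.app_id)  # aws_sagemaker torchx_user 1234
```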

Authentication credentials are loaded from the environment using boto3's standard credential handling.

Config Options

    usage:
        role=ROLE,instance_type=INSTANCE_TYPE,[instance_count=INSTANCE_COUNT],[keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS],[volume_size=VOLUME_SIZE],[volume_kms_key=VOLUME_KMS_KEY],[max_run=MAX_RUN],[input_mode=INPUT_MODE],[output_path=OUTPUT_PATH],[output_kms_key=OUTPUT_KMS_KEY],[base_job_name=BASE_JOB_NAME],[tags=TAGS],[subnets=SUBNETS],[security_group_ids=SECURITY_GROUP_IDS],[metric_definitions=METRIC_DEFINITIONS],[encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC],[use_spot_instances=USE_SPOT_INSTANCES],[max_wait=MAX_WAIT],[checkpoint_s3_uri=CHECKPOINT_S3_URI],[checkpoint_local_path=CHECKPOINT_LOCAL_PATH],[enable_network_isolation=ENABLE_NETWORK_ISOLATION],[environment=ENVIRONMENT],[max_retry_attempts=MAX_RETRY_ATTEMPTS],[source_dir=SOURCE_DIR],[hyperparameters=HYPERPARAMETERS],[training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE],[training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN],[disable_output_compression=DISABLE_OUTPUT_COMPRESSION],[enable_infra_check=ENABLE_INFRA_CHECK],[image_repo=IMAGE_REPO],[quiet=QUIET]

    required arguments:
        role=ROLE (str)
            An AWS IAM role (either name or full ARN). Amazon SageMaker training jobs, and the APIs that create Amazon SageMaker endpoints, use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role if it needs to access an AWS resource.
        instance_type=INSTANCE_TYPE (str)
            Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.

    optional arguments:
        instance_count=INSTANCE_COUNT (int, 1)
            Number of Amazon EC2 instances to use for training. Required if instance_groups is not set.
        keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS (int, None)
            The duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs.
        volume_size=VOLUME_SIZE (int, None)
            Size in GB of the storage volume to use for storing input and output data during training (default: 30).
        volume_kms_key=VOLUME_KMS_KEY (str, None)
            KMS key ID for encrypting EBS volume attached to the training instance.
        max_run=MAX_RUN (int, None)
            Timeout in seconds for training (default: 24 * 60 * 60, i.e. 86400 seconds or 24 hours).
        input_mode=INPUT_MODE (str, None)
            The input mode that the algorithm supports (default: 'File').
        output_path=OUTPUT_PATH (str, None)
            S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket.
        output_kms_key=OUTPUT_KMS_KEY (str, None)
            KMS key ID for encrypting the training output (default: your IAM role's KMS key for Amazon S3).
        base_job_name=BASE_JOB_NAME (str, None)
            Prefix for training job name when the train() method launches. If not specified, the trainer generates a default job name based on the training image name and current timestamp.
        tags=TAGS (dict, None)
            Dictionary of tags for labeling a training job (e.g., key1:val1,key2:val2).
        subnets=SUBNETS (list, None)
            List of subnet IDs. If not specified, the training job is created without a VPC config.
        security_group_ids=SECURITY_GROUP_IDS (list, None)
            List of security group IDs. If not specified, the training job is created without a VPC config.
        metric_definitions=METRIC_DEFINITIONS (dict, None)
            Dictionary that defines the metric(s) used to evaluate the training jobs. Each key is the metric name and the value is the regular expression used to extract the metric from the logs (e.g., metric_name:regex_pattern,other_metric:other_regex).
        encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC (bool, None)
            Specifies whether traffic between training containers is encrypted for the training job (default: False).
        use_spot_instances=USE_SPOT_INSTANCES (bool, None)
            Specifies whether to use SageMaker Managed Spot instances for training. If enabled, max_wait should also be set.
        max_wait=MAX_WAIT (int, None)
            Timeout in seconds to wait for a spot training job.
        checkpoint_s3_uri=CHECKPOINT_S3_URI (str, None)
            S3 URI where checkpoints written by the algorithm during training (if any) are persisted.
        checkpoint_local_path=CHECKPOINT_LOCAL_PATH (str, None)
            Local path that the algorithm writes its checkpoints to.
        enable_network_isolation=ENABLE_NETWORK_ISOLATION (bool, None)
            Specifies whether the container runs in network isolation mode (default: False).
        environment=ENVIRONMENT (dict, None)
            Environment variables to set for use during the training job.
        max_retry_attempts=MAX_RETRY_ATTEMPTS (int, None)
            Number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts.
        source_dir=SOURCE_DIR (str, None)
            Absolute, relative, or S3 URI path to a directory containing any training source code dependencies other than the entry point file (default: current working directory).
        hyperparameters=HYPERPARAMETERS (dict, None)
            Dictionary containing the hyperparameters to initialize this estimator with.
        training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE (str, None)
            Specifies how SageMaker accesses the Docker image that contains the training algorithm.
        training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN (str, None)
            Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted.
        disable_output_compression=DISABLE_OUTPUT_COMPRESSION (bool, None)
            When set to true, the model is uploaded to Amazon S3 without compression after training finishes.
        enable_infra_check=ENABLE_INFRA_CHECK (bool, None)
            Specifies whether to run SageMaker built-in infrastructure check jobs.
        image_repo=IMAGE_REPO (str, None)
            (remote jobs) The image repository to use when pushing patched images; you must have push access to it. Example: example.com/your/container
        quiet=QUIET (bool, False)
            whether to suppress verbose output for image building. Defaults to ``False``.
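When driving TorchX from Python rather than the CLI, the options above can be collected into a plain dict. The sketch below checks the two required options, the rule that max_wait should accompany use_spot_instances, and the documented 1 to 30 range for max_retry_attempts, and renders dict-valued options in the key1:val1,key2:val2 form used in the help text. validate_sagemaker_cfg and format_mapping are hypothetical helpers, not part of TorchX, and the role ARN is a placeholder.

```python
from typing import Any

# Required options, per the "required arguments" listing above.
REQUIRED_OPTIONS = ("role", "instance_type")


def validate_sagemaker_cfg(cfg: dict[str, Any]) -> list[str]:
    """Return a list of problems found in a SageMaker scheduler cfg dict (illustrative)."""
    problems = [f"missing required option: {key}" for key in REQUIRED_OPTIONS if key not in cfg]
    # The option docs say max_wait should also be set when spot instances are enabled.
    if cfg.get("use_spot_instances") and cfg.get("max_wait") is None:
        problems.append("use_spot_instances is set but max_wait is not")
    # The option docs allow between 1 and 30 retry attempts.
    attempts = cfg.get("max_retry_attempts")
    if attempts is not None and not 1 <= attempts <= 30:
        problems.append("max_retry_attempts must be between 1 and 30")
    return problems


def format_mapping(values: dict[str, str]) -> str:
    """Render a dict option in the key1:val1,key2:val2 form shown in the help text."""
    return ",".join(f"{key}:{value}" for key, value in values.items())


cfg = {
    "role": "arn:aws:iam::123456789012:role/MySageMakerRole",  # placeholder ARN
    "instance_type": "ml.c4.xlarge",
    "instance_count": 2,
    "use_spot_instances": True,
}
print(validate_sagemaker_cfg(cfg))  # ['use_spot_instances is set but max_wait is not']
print(format_mapping({"team": "ml", "project": "demo"}))  # team:ml,project:demo
```

TorchX resolves scheduler options itself when a job is submitted, so a local pre-check like this is only a convenience for catching mistakes before any AWS call is made.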

Compatibility

Feature                  Scheduler Support
Fetch Logs               Not supported
Distributed Jobs         ✔️
Cancel Job               ✔️
Describe Job             Partial support. AWSSageMakerScheduler returns job and replica status but does not provide the complete original AppSpec.
Workspaces / Patching    ✔️
Mounts                   Not supported
Elasticity               Not supported

describe(app_id: str) → torchx.schedulers.api.DescribeAppResponse | None

Returns app description, or None if it no longer exists.
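Because describe returns None once an app no longer exists, a polling loop has to treat None as a stopping condition alongside terminal states. The sketch below assumes only that the response exposes a state attribute (as DescribeAppResponse does) and takes the terminal-state check as a caller-supplied predicate. wait_for_terminal and the stub scheduler are hypothetical, and the plain strings stand in for TorchX's AppState values.

```python
import time
from types import SimpleNamespace
from typing import Any, Callable, Optional, Protocol


class SupportsDescribe(Protocol):
    """Anything with a describe(app_id) method, such as a TorchX scheduler."""

    def describe(self, app_id: str) -> Optional[Any]: ...


def wait_for_terminal(
    scheduler: SupportsDescribe,
    app_id: str,
    is_terminal: Callable[[Any], bool],
    poll_interval_s: float = 10.0,
    timeout_s: float = 3600.0,
) -> Optional[Any]:
    """Poll describe() until the app is terminal or gone; return the last response."""
    deadline = time.monotonic() + timeout_s
    while True:
        response = scheduler.describe(app_id)
        # describe() returns None once the app no longer exists.
        if response is None or is_terminal(response.state):
            return response
        if time.monotonic() >= deadline:
            raise TimeoutError(f"app {app_id} not terminal after {timeout_s}s")
        time.sleep(poll_interval_s)


class _StubScheduler:
    """Hypothetical stand-in: reports RUNNING twice, then SUCCEEDED."""

    def __init__(self) -> None:
        self._states = ["RUNNING", "RUNNING", "SUCCEEDED"]

    def describe(self, app_id: str) -> SimpleNamespace:
        return SimpleNamespace(app_id=app_id, state=self._states.pop(0))


TERMINAL = {"SUCCEEDED", "FAILED", "CANCELLED"}  # strings stand in for AppState
final = wait_for_terminal(_StubScheduler(), "1234", lambda s: s in TERMINAL, poll_interval_s=0)
print(final.state)  # SUCCEEDED
```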

list(cfg: Optional[Mapping[str, str | int | float | bool | list[str] | dict[str, str] | None]] = None) → list[torchx.schedulers.api.ListAppResponse]

Lists jobs on this scheduler.

log_iter(app_id: str, role_name: str, k: int = 0, regex: str | None = None, since: datetime.datetime | None = None, until: datetime.datetime | None = None, should_tail: bool = False, streams: torchx.schedulers.api.Stream | None = None) → Iterable[str]

Returns an iterator over log lines for the k-th replica of role_name.

Important

Not all schedulers support log iteration, tailing, or time-based cursors. Check the specific scheduler docs.

Lines include the trailing newline (\n). When should_tail=True, the iterator blocks until the app reaches a terminal state.

Parameters:
  • k – replica (node) index

  • regex – optional filter pattern

  • since – start cursor (scheduler-dependent)

  • until – end cursor (scheduler-dependent)

  • should_tail – if True, follow output like tail -f

  • streams – stdout, stderr, or combined

Raises:

NotImplementedError – if the scheduler does not support log iteration

schedule(dryrun_info: AppDryRunInfo[AWSSageMakerJob]) → str

Submits a previously dry-run request. Returns the app_id.

class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerJob(job_name: str, job_def: dict[str, Any], images_to_push: dict[str, tuple[str, str]])

Defines the key values required to schedule a job. An instance of this class is the value of request in the AppDryRunInfo object.

  • job_name: defines the job name shown in SageMaker

  • job_def: defines the job description that will be used to schedule the job on SageMaker

  • images_to_push: used by torchx to push to image_repo
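Before calling schedule, the dry-run request (the AWSSageMakerJob held in the request field of AppDryRunInfo) can be inspected. The sketch below uses a local dataclass that mirrors the three documented fields instead of importing TorchX. SageMakerJobStandIn, summarize_request, and the job_def contents are illustrative assumptions, not the real job definition that SageMaker receives.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SageMakerJobStandIn:
    """Local stand-in mirroring the documented AWSSageMakerJob fields (not the real class)."""

    job_name: str
    job_def: dict[str, Any]
    images_to_push: dict[str, tuple[str, str]] = field(default_factory=dict)


def summarize_request(job: SageMakerJobStandIn) -> str:
    """Summarize a dry-run request before the dry-run info is passed to schedule()."""
    keys = ", ".join(sorted(job.job_def))
    return f"{job.job_name}: job_def keys [{keys}], {len(job.images_to_push)} image(s) to push"


job = SageMakerJobStandIn(
    job_name="echo-1234",  # hypothetical job name
    job_def={"role": "MySageMakerRole", "instance_type": "ml.c4.xlarge", "instance_count": 1},
)
print(summarize_request(job))
# echo-1234: job_def keys [instance_count, instance_type, role], 0 image(s) to push
```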

Reference

torchx.schedulers.aws_sagemaker_scheduler.create_scheduler(session_name: str, **kwargs: object) → AWSSageMakerScheduler
