AWS SageMaker¶
- class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerScheduler(session_name: str, client: Any | None = None, docker_client: DockerClient | None = None)[source]¶
Bases: DockerWorkspaceMixin, Scheduler[Opts]
AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.
$ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello
aws_sagemaker://torchx_user/1234
$ torchx status aws_sagemaker://torchx_user/1234
...
Authentication is loaded from the environment using the boto3 credential handling.
Config Options
usage: role=ROLE,instance_type=INSTANCE_TYPE,[instance_count=INSTANCE_COUNT],[keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS],[volume_size=VOLUME_SIZE],[volume_kms_key=VOLUME_KMS_KEY],[max_run=MAX_RUN],[input_mode=INPUT_MODE],[output_path=OUTPUT_PATH],[output_kms_key=OUTPUT_KMS_KEY],[base_job_name=BASE_JOB_NAME],[tags=TAGS],[subnets=SUBNETS],[security_group_ids=SECURITY_GROUP_IDS],[metric_definitions=METRIC_DEFINITIONS],[encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC],[use_spot_instances=USE_SPOT_INSTANCES],[max_wait=MAX_WAIT],[checkpoint_s3_uri=CHECKPOINT_S3_URI],[checkpoint_local_path=CHECKPOINT_LOCAL_PATH],[enable_network_isolation=ENABLE_NETWORK_ISOLATION],[environment=ENVIRONMENT],[max_retry_attempts=MAX_RETRY_ATTEMPTS],[source_dir=SOURCE_DIR],[hyperparameters=HYPERPARAMETERS],[training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE],[training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN],[disable_output_compression=DISABLE_OUTPUT_COMPRESSION],[enable_infra_check=ENABLE_INFRA_CHECK],[image_repo=IMAGE_REPO],[quiet=QUIET]

required arguments:
    role=ROLE (str)
        An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code may use the IAM role if it needs to access an AWS resource.
    instance_type=INSTANCE_TYPE (str)
        Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.

optional arguments:
    instance_count=INSTANCE_COUNT (int, 1)
        Number of Amazon EC2 instances to use for training. Required if instance_groups is not set.
    keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS (int, None)
        The duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs.
    volume_size=VOLUME_SIZE (int, None)
        Size in GB of the storage volume to use for storing input and output data during training (default: 30).
    volume_kms_key=VOLUME_KMS_KEY (str, None)
        KMS key ID for encrypting the EBS volume attached to the training instance.
    max_run=MAX_RUN (int, None)
        Timeout in seconds for training (default: 24 * 60 * 60).
    input_mode=INPUT_MODE (str, None)
        The input mode that the algorithm supports (default: 'File').
    output_path=OUTPUT_PATH (str, None)
        S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket.
    output_kms_key=OUTPUT_KMS_KEY (str, None)
        KMS key ID for encrypting the training output (default: your IAM role's KMS key for Amazon S3).
    base_job_name=BASE_JOB_NAME (str, None)
        Prefix for the training job name when the train() method launches. If not specified, the trainer generates a default job name based on the training image name and the current timestamp.
    tags=TAGS (dict, None)
        Dictionary of tags for labeling a training job (e.g., key1:val1,key2:val2).
    subnets=SUBNETS (list, None)
        List of subnet IDs. If not specified, the training job will be created without a VPC config.
    security_group_ids=SECURITY_GROUP_IDS (list, None)
        List of security group IDs. If not specified, the training job will be created without a VPC config.
    metric_definitions=METRIC_DEFINITIONS (dict, None)
        Dictionary that defines the metric(s) used to evaluate the training jobs. Each key is the metric name and the value is the regular expression used to extract the metric from the logs (e.g., metric_name:regex_pattern,other_metric:other_regex).
    encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC (bool, None)
        Specifies whether traffic between training containers is encrypted for the training job (default: False).
    use_spot_instances=USE_SPOT_INSTANCES (bool, None)
        Specifies whether to use SageMaker Managed Spot instances for training. If enabled, the max_wait arg should also be set.
    max_wait=MAX_WAIT (int, None)
        Timeout in seconds waiting for a spot training job.
    checkpoint_s3_uri=CHECKPOINT_S3_URI (str, None)
        S3 URI in which to persist checkpoints that the algorithm persists (if any) during training.
    checkpoint_local_path=CHECKPOINT_LOCAL_PATH (str, None)
        Local path that the algorithm writes its checkpoints to.
    enable_network_isolation=ENABLE_NETWORK_ISOLATION (bool, None)
        Specifies whether the container will run in network isolation mode (default: False).
    environment=ENVIRONMENT (dict, None)
        Environment variables to be set for use during the training job.
    max_retry_attempts=MAX_RETRY_ATTEMPTS (int, None)
        Number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts.
    source_dir=SOURCE_DIR (str, None)
        Absolute, relative, or S3 URI path to a directory with any other training source code dependencies aside from the entry point file (default: current working directory).
    hyperparameters=HYPERPARAMETERS (dict, None)
        Dictionary containing the hyperparameters to initialize this estimator with.
    training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE (str, None)
        Specifies how SageMaker accesses the Docker image that contains the training algorithm.
    training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN (str, None)
        Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted.
    disable_output_compression=DISABLE_OUTPUT_COMPRESSION (bool, None)
        When set to true, the model is uploaded to Amazon S3 without compression after training finishes.
    enable_infra_check=ENABLE_INFRA_CHECK (bool, None)
        Specifies whether to run SageMaker built-in infra check jobs.
    image_repo=IMAGE_REPO (str, None)
        (remote jobs) The image repository to use when pushing patched images; must have push access. Ex: example.com/your/container
    quiet=QUIET (bool, False)
        Whether to suppress verbose output for image building.
Defaults to ``False``.
Compatibility
Fetch Logs: ❌
Distributed Jobs: ✔️
Cancel Job: ✔️
Describe Job: Partial support. AWSSageMakerScheduler returns job and replica status but does not provide the complete original AppSpec.
Workspaces / Patching: ✔️
Mounts: ❌
Elasticity: ❌
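The dict-valued config options above (tags, environment, metric_definitions, hyperparameters) accept a comma-separated key:value string. A minimal sketch of how such a string maps to a dict; parse_kv_opt is a hypothetical helper for illustration, not the actual torchx option parser:

```python
def parse_kv_opt(raw: str) -> dict:
    """Parse the 'key1:val1,key2:val2' string form used by the
    dict-valued scheduler options (e.g. tags, metric_definitions)."""
    result = {}
    for pair in raw.split(","):
        key, _, value = pair.partition(":")  # split on the first ':' only,
        result[key.strip()] = value.strip()  # so regex values may contain ':'
    return result

print(parse_kv_opt("team:ml,env:prod"))  # {'team': 'ml', 'env': 'prod'}
```

Splitting on the first colon only matters for metric_definitions, whose values are regular expressions that may themselves contain colons.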
- describe(app_id: str) → torchx.schedulers.api.DescribeAppResponse | None[source]¶
Returns the app description, or None if it no longer exists.
- list(cfg: Optional[Mapping[str, str | int | float | bool | list[str] | dict[str, str] | None]] = None) → list[torchx.schedulers.api.ListAppResponse][source]¶
Lists jobs on this scheduler.
- log_iter(app_id: str, role_name: str, k: int = 0, regex: str | None = None, since: datetime.datetime | None = None, until: datetime.datetime | None = None, should_tail: bool = False, streams: torchx.schedulers.api.Stream | None = None) → Iterable[str][source]¶
Returns an iterator over log lines for the k-th replica of role_name.
Important
Not all schedulers support log iteration, tailing, or time-based cursors. Check the specific scheduler docs.
Lines include the trailing newline (\n). When should_tail=True, the iterator blocks until the app reaches a terminal state.
- Parameters:
k – replica (node) index
regex – optional filter pattern
since – start cursor (scheduler-dependent)
until – end cursor (scheduler-dependent)
should_tail – if True, follow output like tail -f
streams – stdout, stderr, or combined
- Raises:
NotImplementedError – if the scheduler does not support log iteration
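Since the compatibility table above marks Fetch Logs as unsupported for this scheduler, callers should be prepared for NotImplementedError when iterating logs. A hedged sketch of the guard pattern (collect_logs and the stub iterator are illustrative names, not torchx API):

```python
def collect_logs(log_iter_fn, app_id: str, role_name: str) -> list:
    """Guarded log collection for schedulers that may not support
    log iteration."""
    try:
        return list(log_iter_fn(app_id, role_name, 0))
    except NotImplementedError:
        # Fall back to an external source, e.g. the CloudWatch log
        # streams that SageMaker training jobs write to.
        return []

def unsupported(app_id, role_name, k):
    """Stub mimicking a scheduler with no log-fetch support."""
    raise NotImplementedError("log iteration not supported")

print(collect_logs(unsupported, "app-1234", "worker"))  # []
```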
- schedule(dryrun_info: AppDryRunInfo[AWSSageMakerJob]) → str[source]¶
Submits a previously dry-run request. Returns the app_id.
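The two-phase contract can be sketched with stand-in functions: dryrun builds the request without side effects so it can be inspected, and schedule submits it and returns the app_id. All names and request fields below are hypothetical; the real scheduler builds a full SageMaker job definition and calls the SageMaker API.

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class FakeDryRunInfo:
    """Stand-in for AppDryRunInfo: wraps the built request."""
    request: dict

def fake_dryrun(job_name: str, instance_type: str) -> FakeDryRunInfo:
    # Build the job definition up front; nothing is submitted yet.
    return FakeDryRunInfo(request={
        "TrainingJobName": job_name,
        "ResourceConfig": {"InstanceType": instance_type, "InstanceCount": 1},
    })

def fake_schedule(info: FakeDryRunInfo) -> str:
    # The real scheduler would submit the training job here; we only
    # return the identifier the caller would use with status/describe.
    return info.request["TrainingJobName"]

app_id = fake_schedule(fake_dryrun("echo-1234", "ml.c4.xlarge"))
```

Separating the two phases lets `torchx run --dryrun` print the exact request before anything is created.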
- class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerJob(job_name: str, job_def: dict[str, Any], images_to_push: dict[str, tuple[str, str]])[source]¶
Defines the key values required to schedule a job. This will be the value of the request field in the AppDryRunInfo object.
job_name: defines the job name shown in SageMaker
job_def: defines the job description that will be used to schedule the job on SageMaker
images_to_push: used by torchx to push to image_repo
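A sketch of the shape of those three fields, using a plain dict with hypothetical example values (real values are produced by the scheduler's dryrun step):

```python
# Hypothetical illustration of an AWSSageMakerJob's contents.
job = {
    "job_name": "echo-abc123",                      # name shown in SageMaker
    "job_def": {"TrainingJobName": "echo-abc123"},  # request used to schedule
    "images_to_push": {                             # image id -> (repo, tag)
        "sha256:0123abcd": ("example.com/your/container", "echo-abc123"),
    },
}
```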
Reference¶
- torchx.schedulers.aws_sagemaker_scheduler.create_scheduler(session_name: str, **kwargs: object) → AWSSageMakerScheduler[source]¶