Source code for braket.jobs.hybrid_job

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

import functools
import importlib.util
import inspect
import os
import re
import shutil
import sys
import tempfile
import warnings
from collections.abc import Callable, Iterable, Iterator
from contextlib import contextmanager
from logging import Logger, getLogger
from pathlib import Path
from types import CodeType, ModuleType
from typing import Any

import cloudpickle

from braket.aws.aws_session import AwsSession
from braket.jobs._entry_point_template import run_entry_point, symlink_input_data
from braket.jobs.config import (
    CheckpointConfig,
    InstanceConfig,
    OutputDataConfig,
    S3DataSourceConfig,
    StoppingCondition,
)
from braket.jobs.image_uris import Framework, built_in_images, retrieve_image
from braket.jobs.local.local_job_container_setup import _get_env_input_data
from braket.jobs.quantum_job import QuantumJob
from braket.jobs.quantum_job_creation import _generate_default_job_name

# Channel name assigned to `input_data` when it is given as a single value (local path,
# S3 URI, or S3DataSourceConfig) instead of a dict mapping channel names to sources.
DEFAULT_INPUT_CHANNEL = "input"
# Reserved input channel carrying copies of the source files of functions defined inside
# the decorated entry point; user-supplied input data may not use this channel name.
INNER_FUNCTION_SOURCE_INPUT_CHANNEL = "_braket_job_decorator_inner_function_source"
# Name of the local folder (inside a temporary directory) where those source copies are
# written before being uploaded as input data.
INNER_FUNCTION_SOURCE_INPUT_FOLDER = "_inner_function_source_folder"


def hybrid_job(
    *,
    device: str | None,
    include_modules: str | ModuleType | Iterable[str | ModuleType] | None = None,
    dependencies: str | Path | list[str] | None = None,
    local: bool = False,
    job_name: str | None = None,
    image_uri: str | None = None,
    input_data: str | dict | S3DataSourceConfig | None = None,
    wait_until_complete: bool = False,
    instance_config: InstanceConfig | None = None,
    distribution: str | None = None,
    copy_checkpoints_from_job: str | None = None,
    checkpoint_config: CheckpointConfig | None = None,
    role_arn: str | None = None,
    stopping_condition: StoppingCondition | None = None,
    output_data_config: OutputDataConfig | None = None,
    aws_session: AwsSession | None = None,
    tags: dict[str, str] | None = None,
    logger: Logger = getLogger(__name__),
    quiet: bool | None = None,
    reservation_arn: str | None = None,
) -> Callable:
    """Defines a hybrid job by decorating the entry point function. The job will be created
    when the decorated function is called.

    The job created will be a `LocalQuantumJob` when `local` is set to `True`, otherwise an
    `AwsQuantumJob`. The following parameters will be ignored when running a job with
    `local` set to `True`: `wait_until_complete`, `instance_config`, `distribution`,
    `copy_checkpoints_from_job`, `stopping_condition`, `tags`, `logger`, and `quiet`.

    Remarks:
        Hybrid jobs created using this decorator have limited access to the source code of
        functions defined outside of the decorated function. Functionality that depends on
        source code analysis may not work properly when referencing functions defined
        outside of the decorated function.

    Args:
        device (str | None): Device ARN of the QPU device that receives priority quantum
            task queueing once the hybrid job begins running. Each QPU has a separate hybrid
            jobs queue so that only one hybrid job is running at a time. The device string is
            accessible in the hybrid job instance as the environment variable
            "AMZN_BRAKET_DEVICE_ARN". When using embedded simulators, you may provide the
            device argument as string of the form: "local:<provider>/<simulator_name>" or
            `None`. A `None` device is submitted as "local:none/none".

        include_modules (str | ModuleType | Iterable[str | ModuleType] | None): Either a
            single module or module name, or a list of modules or module names, referring to
            local modules to be included. Any references to members of these modules in the
            hybrid job algorithm code will be serialized as part of the algorithm code.
            Default: `[]`

        dependencies (str | Path | list[str] | None): Path (absolute or relative) to a
            requirements.txt file, or alternatively a list of strings, with each string being a
            `requirement specifier
            <https://pip.pypa.io/en/stable/reference/requirement-specifiers/#requirement-specifiers>`_,
            to be used for the hybrid job.

        local (bool): Whether to use local mode for the hybrid job. Default: `False`

        job_name (str | None): A string that specifies the name with which the job is created.
            Allowed pattern for job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`. Defaults to
            f'{decorated-function-name}-{timestamp}'.

        image_uri (str | None): A str that specifies the ECR image to use for executing the
            job. `retrieve_image()` function may be used for retrieving the ECR image URIs
            for the containers supported by Braket. Default: `<Braket base image_uri>`.

        input_data (str | dict | S3DataSourceConfig | None): Information about the training
            data. Dictionary maps channel names to local paths or S3 URIs. Contents found
            at any local paths will be uploaded to S3 at
            f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a local
            path, S3 URI, or S3DataSourceConfig is provided, it will be given a default
            channel name "input".
            Default: {}.

        wait_until_complete (bool): `True` if we should wait until the job completes.
            This would tail the job logs as it waits. Otherwise `False`. Ignored if using
            local mode. Default: `False`.

        instance_config (InstanceConfig | None): Configuration of the instance(s) for running
            the classical code for the hybrid job. Default:
            `InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.

        distribution (str | None): A str that specifies how the job should be distributed.
            If set to "data_parallel", the hyperparameters for the job will be set to use data
            parallelism features for PyTorch or TensorFlow. Default: `None`.

        copy_checkpoints_from_job (str | None): A str that specifies the job ARN whose
            checkpoint you want to use in the current job. Specifying this value will copy
            over the checkpoint data from `copy_checkpoints_from_job`'s checkpoint_config
            s3Uri to the current job's checkpoint_config s3Uri, making it available at
            checkpoint_config.localPath during the job execution. Default: `None`

        checkpoint_config (CheckpointConfig | None): Configuration that specifies the
            location where checkpoint data is stored.
            Default: `CheckpointConfig(localPath='/opt/jobs/checkpoints',
            s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints')`.

        role_arn (str | None): A str providing the IAM role ARN used to execute the
            script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`.

        stopping_condition (StoppingCondition | None): The maximum length of time, in seconds,
            and the maximum number of tasks that a job can run before being forcefully
            stopped. Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).

        output_data_config (OutputDataConfig | None): Specifies the location for the output
            of the job.
            Default: `OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
            kmsKeyId=None)`.

        aws_session (AwsSession | None): AwsSession for connecting to AWS Services.
            Default: AwsSession()

        tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this
            job. Default: {}.

        logger (Logger): Logger object with which to write logs, such as task statuses
            while waiting for task to be in a terminal state. Default: `getLogger(__name__)`

        quiet (bool | None): Sets the verbosity of the logger to low and does not report queue
            position. Default is `False`.

        reservation_arn (str | None): the reservation window arn provided by Braket
            Direct to reserve exclusive usage for the device to run the hybrid job on.
            Default: None.

    Returns:
        Callable: the callable for creating a Hybrid Job.
    """
    # Fail fast at decoration time if the local interpreter cannot match the container.
    _validate_python_version(image_uri, aws_session)

    def _hybrid_job(entry_point: Callable) -> Callable:
        @functools.wraps(entry_point)
        def job_wrapper(*args: Any, **kwargs: Any) -> QuantumJob:
            """Create a hybrid job that runs the decorated entry point with the given arguments.

            Args:
                *args (Any): Arbitrary arguments.
                **kwargs (Any): Arbitrary keyword arguments.

            Returns:
                QuantumJob: the created hybrid job: a `LocalQuantumJob` when `local` is
                `True`, otherwise an `AwsQuantumJob`.
            """
            with (
                _IncludeModules(include_modules),
                tempfile.TemporaryDirectory(dir="", prefix="decorator_job_") as temp_dir,
                persist_inner_function_source(entry_point) as inner_source_input,
            ):
                job_input_data = _add_inner_function_source_to_input_data(
                    input_data, inner_source_input
                )
                temp_dir_path = Path(temp_dir)
                entry_point_file_path = Path("entry_point.py")
                with open(
                    temp_dir_path / entry_point_file_path, "w", encoding="utf-8"
                ) as entry_point_file:
                    # The generated entry point first symlinks input data, then runs the
                    # cloudpickled entry point with its bound arguments.
                    template = "\n".join([
                        _process_input_data(input_data),
                        _serialize_entry_point(entry_point, args, kwargs),
                    ])
                    entry_point_file.write(template)

                if dependencies:
                    _process_dependencies(dependencies, temp_dir_path)

                job_args = {
                    "device": device or "local:none/none",
                    "source_module": temp_dir,
                    "entry_point": (
                        f"{temp_dir_path.name}.{entry_point_file_path.stem}:{entry_point.__name__}"
                    ),
                    "wait_until_complete": wait_until_complete,
                    "job_name": job_name or _generate_default_job_name(func=entry_point),
                    "hyperparameters": _log_hyperparameters(entry_point, args, kwargs),
                    "logger": logger,
                }
                optional_args = {
                    "image_uri": image_uri,
                    "input_data": job_input_data,
                    "instance_config": instance_config,
                    "distribution": distribution,
                    "checkpoint_config": checkpoint_config,
                    "copy_checkpoints_from_job": copy_checkpoints_from_job,
                    "role_arn": role_arn,
                    "stopping_condition": stopping_condition,
                    "output_data_config": output_data_config,
                    "aws_session": aws_session,
                    "tags": tags,
                    "quiet": quiet,
                    "reservation_arn": reservation_arn,
                }
                # Only forward options the caller set, so the job creators' defaults apply.
                job_args.update({key: val for key, val in optional_args.items() if val is not None})

                # Created inside the `with` so the temporary source directory still exists.
                return _create_job(job_args, local)

        return job_wrapper

    return _hybrid_job
@contextmanager
def persist_inner_function_source(entry_point: Callable) -> Iterator[dict[str, str]]:
    """Persist the source code of the cloudpickled function by saving its source code as input
    data and replace the source file path with the saved one.

    The source file(s) of functions defined inside ``entry_point`` are copied into a temporary
    directory, and the nested code objects of ``entry_point`` are re-pointed at the paths those
    copies will have in the job environment. The original code object is restored when the
    context exits, so the decorated function can be called (and submitted) again.

    Args:
        entry_point (Callable): The job decorated function.

    Yields:
        dict[str, str]: if the inner function exists, a mapping of the input channel to the copy
        directory. Otherwise an empty dict
    """
    inner_source_mapping = _get_inner_function_source(entry_point.__code__)

    if not inner_source_mapping:
        yield {}
        return

    with tempfile.TemporaryDirectory(dir="", prefix="decorator_job_inner_source_") as temp_dir:
        copy_dir = f"{temp_dir}/{INNER_FUNCTION_SOURCE_INPUT_FOLDER}"
        os.mkdir(copy_dir)
        path_mapping = _save_inner_source_to_file(inner_source_mapping, copy_dir)
        original_code = entry_point.__code__
        entry_point.__code__ = _replace_inner_function_source_path(original_code, path_mapping)
        try:
            yield {INNER_FUNCTION_SOURCE_INPUT_CHANNEL: copy_dir}
        finally:
            # The rewritten paths only exist inside the job container. Without restoring,
            # a later call would look up those rewritten paths locally and fail.
            entry_point.__code__ = original_code
def _replace_inner_function_source_path( code_object: CodeType, path_mapping: dict[str, str] ) -> CodeType: """Recursively replace source code file path of the code object and of its child node's code objects. Args: code_object (CodeType): Code object which source code file path to be replaced. path_mapping (dict[str, str]): Mapping between local file path to path in a job environment. Returns: CodeType: Code object with the source code file path replaced """ new_co_consts = [] for const in code_object.co_consts: new_const = const if inspect.iscode(const): new_path = path_mapping[const.co_filename] new_const = const.replace(co_filename=new_path) new_const = _replace_inner_function_source_path(new_const, path_mapping) new_co_consts.append(new_const) return code_object.replace(co_consts=tuple(new_co_consts)) def _save_inner_source_to_file(inner_source: dict[str, str], input_data_dir: str) -> dict[str, str]: """Saves the source code as input data for a job and returns a dictionary that maps the local source file path of a function to the one to be used in the job environment. Args: inner_source (dict[str, str]): Mapping between source file name and source code. input_data_dir (str): The path of the folder to be uploaded to job as input data. Returns: dict[str, str]: Mapping between local file path to path in a job environment. """ path_mapping = {} for i, (local_path, source_code) in enumerate(inner_source.items()): copy_file_name = f"source_{i}.py" with open(f"{input_data_dir}/{copy_file_name}", "w", encoding="utf-8") as f: f.write(source_code) path_mapping[local_path] = os.path.join( _get_env_input_data()["AMZN_BRAKET_INPUT_DIR"], INNER_FUNCTION_SOURCE_INPUT_CHANNEL, copy_file_name, ) return path_mapping def _get_inner_function_source(code_object: CodeType) -> dict[str, str]: """Returns a dictionary that maps the source file name to source code for all source files used by the inner functions inside the job decorated function. 
Args: code_object (CodeType): Code object of a inner function. Returns: dict[str, str]: Mapping between source file name and source code. """ inner_source = {} for const in code_object.co_consts: if inspect.iscode(const): source_file_path = inspect.getfile(code_object) lines, _ = inspect.findsource(code_object) inner_source.update({source_file_path: "".join(lines)}) inner_source.update(_get_inner_function_source(const)) return inner_source def _add_inner_function_source_to_input_data(input_data: dict, inner_source_input: dict) -> dict: """Add the path of inner function source file as the input data of the job. Args: input_data (dict): Provided input data of the job. inner_source_input (dict): A dict that points to the path of inner function source file. Returns: dict: input_data with inner function source file added. """ if input_data is None: job_input_data = inner_source_input elif isinstance(input_data, dict): if INNER_FUNCTION_SOURCE_INPUT_CHANNEL in input_data: raise ValueError(f"input channel cannot be {INNER_FUNCTION_SOURCE_INPUT_CHANNEL}") job_input_data = {**input_data, **inner_source_input} else: job_input_data = {DEFAULT_INPUT_CHANNEL: input_data, **inner_source_input} return job_input_data def _validate_python_version(image_uri: str | None, aws_session: AwsSession | None = None) -> None: """Validate python version at job definition time""" aws_session = aws_session or AwsSession() # user provides a custom image_uri if image_uri and image_uri not in built_in_images(aws_session.region): print( "Skipping python version validation, make sure versions match " "between local environment and container." 
) else: # set default image_uri to base image_uri = image_uri or retrieve_image(Framework.BASE, aws_session.region) tag = aws_session.get_full_image_tag(image_uri) major_version, minor_version = re.search(r"-py(\d)(\d+)-", tag).groups() if (sys.version_info.major, sys.version_info.minor) != ( int(major_version), int(minor_version), ): raise RuntimeError( "Python version must match between local environment and container. " f"Client is running Python {sys.version_info.major}.{sys.version_info.minor} " f"locally, but container uses Python {major_version}.{minor_version}." ) def _process_dependencies(dependencies: str | Path | list[str], temp_dir: Path) -> None: if isinstance(dependencies, str | Path): # requirements file shutil.copy(Path(dependencies).resolve(), temp_dir / "requirements.txt") else: # list of packages with open(temp_dir / "requirements.txt", "w", encoding="utf-8") as f: f.write("\n".join(dependencies)) class _IncludeModules: def __init__(self, modules: str | ModuleType | Iterable[str | ModuleType] | None = None): modules = modules or [] if isinstance(modules, str | ModuleType): modules = [modules] self._modules = [ (importlib.import_module(module) if isinstance(module, str) else module) for module in modules ] def __enter__(self): """Register included modules with cloudpickle to be pickled by value""" for module in self._modules: cloudpickle.register_pickle_by_value(module) def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 """Unregister included modules with cloudpickle to be pickled by value""" for module in self._modules: cloudpickle.unregister_pickle_by_value(module) def _serialize_entry_point(entry_point: Callable, args: tuple, kwargs: dict) -> str: """Create an entry point from a function""" wrapped_entry_point = functools.partial(entry_point, *args, **kwargs) try: serialized = cloudpickle.dumps(wrapped_entry_point) except Exception as e: raise RuntimeError( "Serialization failed for decorator hybrid job. 
If you are referencing " "an object from outside the function scope, either directly or through " "function parameters, try instantiating the object inside the decorated " "function instead." ) from e return run_entry_point.format( serialized=serialized, function_name=entry_point.__name__, ) def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> dict: """Capture function arguments as hyperparameters""" signature = inspect.signature(entry_point) bound_args = signature.bind(*args, **kwargs) bound_args.apply_defaults() hyperparameters = {} for param, value in bound_args.arguments.items(): param_kind = signature.parameters[param].kind if param_kind in { inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, }: hyperparameters[param] = value elif param_kind == inspect.Parameter.VAR_KEYWORD: hyperparameters.update(**value) else: warnings.warn( "Positional only arguments will not be logged to the hyperparameters file.", stacklevel=1, ) return {name: _sanitize(value) for name, value in hyperparameters.items()} def _sanitize(hyperparameter: Any) -> str: """Sanitize forbidden characters from hp strings""" string_hp = str(hyperparameter) sanitized = ( string_hp # replace forbidden characters with close matches .replace("\n", " ") .replace("$", "?") .replace("(", "{") .replace("&", "+") .replace("`", "'") # not technically forbidden, but to avoid mismatched parens .replace(")", "}") ) # max allowed length for a hyperparameter is 2500 if len(sanitized) > 2500: # show as much as possible, including the final 20 characters return f"{sanitized[: 2500 - 23]}...{sanitized[-20:]}" return sanitized def _process_input_data(input_data: dict) -> list[str]: """Create symlinks to data. 
Logic chart for how the service moves files into the data directory on the instance: input data matches exactly one file: cwd/filename -> channel/filename input data matches exactly one directory: cwd/dirname/* -> channel/* else (multiple matches, possibly including exact): cwd/prefix_match -> channel/prefix_match, for each match """ input_data = input_data or {} if not isinstance(input_data, dict): input_data = {"input": input_data} def matches(prefix: str) -> list[str]: return [str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(prefix)] def is_prefix(path: str) -> bool: return len(matches(path)) > 1 or not Path(path).exists() prefix_channels = set() directory_channels = set() file_channels = set() for channel, data in input_data.items(): if AwsSession.is_s3_uri(str(data)) or isinstance(data, S3DataSourceConfig): channel_arg = f'channel="{channel}"' if channel != "input" else "" print( "Input data channels mapped to an S3 source will not be available in " f"the working directory. Use `get_input_data_dir({channel_arg})` to read " f"input data from S3 source inside the job container." 
) elif is_prefix(str(data)): prefix_channels.add(channel) elif Path(data).is_dir(): directory_channels.add(channel) else: file_channels.add(channel) return symlink_input_data.format( prefix_matches={channel: matches(input_data[channel]) for channel in prefix_channels}, input_data_items=[ (channel, data) for channel, data in input_data.items() if channel in prefix_channels | directory_channels | file_channels ], prefix_channels=prefix_channels, directory_channels=directory_channels, ) def _create_job(job_args: dict[str, Any], local: bool = False) -> QuantumJob: """Create an AWS or Local hybrid job""" if local: from braket.jobs.local import LocalQuantumJob # noqa: PLC0415 for aws_only_arg in [ "wait_until_complete", "copy_checkpoints_from_job", "instance_config", "distribution", "stopping_condition", "tags", "logger", ]: job_args.pop(aws_only_arg, None) return LocalQuantumJob.create(**job_args) from braket.aws import AwsQuantumJob # noqa: PLC0415 return AwsQuantumJob.create(**job_args)