# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import annotations
from collections.abc import Iterable
from braket.emulation.passes.passes import _EmulatorPass
from braket.tasks.quantum_task import TaskSpecification
class EmulatorValidationError(Exception):
    """Custom exception for validation errors raised by emulators."""
class PassManager:
    """Ordered collection of emulator passes applied to a task specification.

    Passes are kept in insertion order and invoked one after another by
    ``validate``. Managers can be combined with ``+`` (returns a new manager)
    or ``+=`` (mutates this manager) using a single pass, another
    ``PassManager``, or an iterable of passes.
    """

    def __init__(self, passes: _EmulatorPass | Iterable[_EmulatorPass] | None = None):
        """
        Args:
            passes (_EmulatorPass | Iterable[_EmulatorPass] | None): A single pass, an
                iterable of passes, or None for an empty manager. Iterables are copied
                into a new list, so later changes to the caller's collection do not
                affect this manager.
        """
        if passes is None:
            self._passes = []
        elif isinstance(passes, Iterable):
            self._passes = list(passes)
        else:
            self._passes = [passes]

    def validate(self, task_specification: TaskSpecification) -> None:
        """
        Pass the input program through every pass, in insertion order.

        The passes are intended to perform only validation, without modifying
        the input program.

        Args:
            task_specification (TaskSpecification): The program to validate with this
                emulator's validation passes.

        Raises:
            EmulatorValidationError: If any pass raises; the original exception is
                chained as the cause.
        """
        try:
            for emulator_pass in self._passes:
                emulator_pass(task_specification)
        except Exception as e:
            self._raise_exception(e)

    def _raise_exception(self, exception: Exception) -> None:
        """
        Wrap an exception raised by a pass in an EmulatorValidationError.

        Kept as a separate hook so the exception message can be modified if needed.

        Args:
            exception (Exception): The exception to wrap and re-raise.

        Raises:
            EmulatorValidationError: Always, with ``exception`` as its cause.
        """
        raise EmulatorValidationError(str(exception)) from exception

    @staticmethod
    def _as_pass_list(passes) -> list:
        """Normalize a pass, a PassManager, or an iterable of passes into a new list."""
        if isinstance(passes, PassManager):
            # Copy so the result never aliases another manager's internal list.
            return list(passes._passes)
        if isinstance(passes, _EmulatorPass):
            # A pass is always one element, even if it happens to be iterable.
            return [passes]
        if isinstance(passes, Iterable):
            return list(passes)
        # Any other object (e.g. a plain callable) is treated as a single pass.
        return [passes]

    def __iadd__(
        self, passes: _EmulatorPass | PassManager | Iterable[_EmulatorPass]
    ) -> PassManager:
        """Incrementally add a pass, a PassManager's passes, or an iterable of passes."""
        # Extend rather than append so a PassManager operand contributes its
        # individual passes instead of being stored as one nested list.
        self._passes.extend(self._as_pass_list(passes))
        return self

    def __add__(self, passes: _EmulatorPass | PassManager | Iterable[_EmulatorPass]) -> PassManager:
        """Return a new PassManager with this manager's passes followed by ``passes``."""
        return PassManager([*self._passes, *self._as_pass_list(passes)])

    def __len__(self) -> int:
        """Return the number of passes held by this manager."""
        return len(self._passes)