diff --git a/src/braket/circuits/__init__.py b/src/braket/circuits/__init__.py index a2930ed79..5f5bdde21 100644 --- a/src/braket/circuits/__init__.py +++ b/src/braket/circuits/__init__.py @@ -28,7 +28,7 @@ AngledGate, # noqa: F401 DoubleAngledGate, # noqa: F401 ) -from braket.circuits.circuit import Circuit # noqa: F401 +from braket.circuits.circuit import Circuit, QubitMatch # noqa: F401 from braket.circuits.circuit_diagram import CircuitDiagram # noqa: F401 from braket.circuits.compiler_directive import CompilerDirective # noqa: F401 from braket.circuits.free_parameter import FreeParameter # noqa: F401 diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index b656763c3..b4ed94323 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -16,6 +16,7 @@ import warnings from collections import Counter from collections.abc import Callable, Iterable, Sequence +from enum import StrEnum from numbers import Number from typing import Any, TypeVar @@ -45,6 +46,7 @@ ) from braket.circuits.observable import Observable, euler_angle_parameter_names from braket.circuits.observables import Sum, TensorProduct +from braket.circuits.operator import Operator from braket.circuits.parameterizable import Parameterizable from braket.circuits.result_type import ( ObservableParameterResultType, @@ -73,6 +75,16 @@ AddableTypes = TypeVar("AddableTypes", SubroutineReturn, SubroutineCallable) +class QubitMatch(StrEnum): + """Controls how multiple qubits are matched in count.""" + + ANY = "ANY" + ALL = "ALL" + + +OperatorIdentifier = str | type[Operator] | Operator + + class Circuit: """A representation of a quantum circuit that contains the instructions to be performed on a quantum device and the requested result types. @@ -243,6 +255,85 @@ def parameters(self) -> set[FreeParameter]: """ return self._parameters + @staticmethod + def _normalize_operator_name(identifier: OperatorIdentifier) -> str: + if isinstance(identifier, type): + return identifier.__name__.upper() + if isinstance(identifier, str): + return identifier.upper() + return identifier.name.upper() + + @staticmethod + def _to_operator_names( + operators: OperatorIdentifier | Iterable[OperatorIdentifier] | None, + ) -> list[str]: + if operators is None: + return [] + if isinstance(operators, (str, type, Operator)): + return [Circuit._normalize_operator_name(operators)] + return [Circuit._normalize_operator_name(op) for op in operators] + + def count( + self, + operators: OperatorIdentifier | Iterable[OperatorIdentifier] | None = None, + qubits: QubitInput | Iterable[QubitInput] | None = None, + qubit_match: QubitMatch = QubitMatch.ANY, + include_types: Iterable[MomentType] = (MomentType.GATE,), + ) -> Counter[str]: + """ + Count instructions in the circuit with optional filtering. + + When both ``operators`` and ``qubits`` are specified, an instruction must satisfy + both filters to be counted (AND semantics). + + Args: + operators: Filter by operator name or type. Defaults to None (no filter). + qubits: Filter by qubit. Matched against the union of target and control qubits. + qubit_match (QubitMatch): How multiple qubits relate. ANY = instruction on + any specified qubit; ALL = instruction on all specified qubits. Default ANY. + include_types (Iterable[MomentType]): Moment types to count. Default: GATE only. + Pass additional MomentType values to include noise, measures, etc. + + Returns: + Counter[str]: Operator names mapped to occurrence counts. + + Examples: + >>> circ = Circuit().h(0).cnot(0, 1).rx(0, 0.5) + >>> circ.count() + Counter({'H': 1, 'CNot': 1, 'Rx': 1}) + >>> circ.count("h") + Counter({'H': 1}) + >>> circ.count(["H", "CNot"]) + Counter({'H': 1, 'CNot': 1}) + >>> circ.count(qubits=0) + Counter({'H': 1, 'CNot': 1, 'Rx': 1}) + """ + include_types_set = set(include_types) + operator_names_set = set(self._to_operator_names(operators)) + qs = QubitSet(qubits) if qubits is not None else None + filter_qubits = qs or None # empty QubitSet treated as no filter + + result: Counter[str] = Counter() + + for key, instruction in self.moments.items(): + if key.moment_type not in include_types_set: + continue + + instr_qubits = instruction.target.union(instruction.control) + instr_name_upper = instruction.operator.name.upper() + + qubit_pass = filter_qubits is None or ( + any(q in instr_qubits for q in filter_qubits) + if qubit_match == QubitMatch.ANY + else all(q in instr_qubits for q in filter_qubits) + ) + operator_pass = not operator_names_set or instr_name_upper in operator_names_set + + if qubit_pass and operator_pass: + result[instruction.operator.name] += 1 + + return result + def with_euler_angles(self, observables: Sequence[Observable] | Sum) -> Circuit: """Returns a copy of the circuit with parametrized Euler angles on the observables' qubits diff --git a/test/unit_tests/braket/circuits/test_circuit_analysis.py b/test/unit_tests/braket/circuits/test_circuit_analysis.py new file mode 100644 index 000000000..6d025df17 --- /dev/null +++ b/test/unit_tests/braket/circuits/test_circuit_analysis.py @@ -0,0 +1,71 @@ +from collections import Counter + +import pytest + +from braket.circuits import Circuit, gates +from braket.circuits.circuit import QubitMatch +from braket.circuits.moments import MomentType +from braket.circuits.noises import BitFlip + + +@pytest.fixture +def mixed_circuit(): + return Circuit().h(0).cnot(0, 1).rx(0, 0.5).h(1) + + +def test_no_filters_returns_all_gates(mixed_circuit): + assert mixed_circuit.count() == Counter({"H": 2, "CNot": 1, "Rx": 1}) + + +def test_operator_filter_multiple_mixed_identifiers(mixed_circuit): + assert mixed_circuit.count(operators=["h", gates.CNot]) == Counter({"H": 2, "CNot": 1}) + + +def test_operator_filter_gate_instance(mixed_circuit): + assert mixed_circuit.count(operators=gates.CNot()) == Counter({"CNot": 1}) + + +def test_include_gate_noise_type(): + circ = Circuit().h(0) + circ.apply_gate_noise(BitFlip(0.1)) + result = circ.count(include_types=[MomentType.GATE, MomentType.GATE_NOISE]) + assert result == Counter({"H": 1, "BitFlip": 1}) + + +def test_gate_noise_excluded_by_default(): + circ = Circuit().h(0) + circ.apply_gate_noise(BitFlip(0.1)) + assert circ.count() == Counter({"H": 1}) + + +def test_qubit_filter_single_qubit(mixed_circuit): + assert mixed_circuit.count(qubits=0) == Counter({"H": 1, "CNot": 1, "Rx": 1}) + + +def test_qubit_filter_multiple_qubits_any(): + circ = Circuit().h(0).h(2).cnot(0, 1) + assert circ.count(qubits=[0, 2]) == Counter({"H": 2, "CNot": 1}) + + +def test_qubit_filter_multiple_qubits_all(mixed_circuit): + assert mixed_circuit.count(qubits=[0, 1], qubit_match=QubitMatch.ALL) == Counter({"CNot": 1}) + + +def test_operator_and_qubit_filters_intersect(mixed_circuit): + assert mixed_circuit.count(operators="CNot", qubits=[1]) == Counter({"CNot": 1}) + + +def test_unknown_operator_returns_empty(mixed_circuit): + assert mixed_circuit.count(operators="ZZZ") == Counter() + + +def test_qubit_not_in_circuit_returns_empty(mixed_circuit): + assert mixed_circuit.count(qubits=99) == Counter() + + +def test_partial_qubits_not_in_circuit_any(mixed_circuit): + assert mixed_circuit.count(qubits=[0, 99]) == Counter({"H": 1, "CNot": 1, "Rx": 1}) + + +def test_partial_qubits_not_in_circuit_all(mixed_circuit): + assert mixed_circuit.count(qubits=[0, 99], qubit_match=QubitMatch.ALL) == Counter()