Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/braket/circuits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
AngledGate, # noqa: F401
DoubleAngledGate, # noqa: F401
)
from braket.circuits.circuit import Circuit # noqa: F401
from braket.circuits.circuit import Circuit, QubitMatch # noqa: F401
from braket.circuits.circuit_diagram import CircuitDiagram # noqa: F401
from braket.circuits.compiler_directive import CompilerDirective # noqa: F401
from braket.circuits.free_parameter import FreeParameter # noqa: F401
Expand Down
93 changes: 92 additions & 1 deletion src/braket/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import warnings
from collections import Counter
from collections.abc import Callable, Iterable, Sequence
from enum import StrEnum
from numbers import Number
from typing import Any, TypeVar
from typing import Any, TypeVar, Union

import numpy as np
import oqpy
Expand Down Expand Up @@ -45,6 +46,7 @@
)
from braket.circuits.observable import Observable, euler_angle_parameter_names
from braket.circuits.observables import Sum, TensorProduct
from braket.circuits.operator import Operator
from braket.circuits.parameterizable import Parameterizable
from braket.circuits.result_type import (
ObservableParameterResultType,
Expand Down Expand Up @@ -73,6 +75,16 @@
AddableTypes = TypeVar("AddableTypes", SubroutineReturn, SubroutineCallable)


class QubitMatch(StrEnum):
"""Controls how multiple qubits are matched in count."""

ANY = "ANY"
ALL = "ALL"


OperatorIdentifier = Union[str, type[Operator], Operator]


class Circuit:
"""A representation of a quantum circuit that contains the instructions to be performed on a
quantum device and the requested result types.
Expand Down Expand Up @@ -243,6 +255,85 @@ def parameters(self) -> set[FreeParameter]:
"""
return self._parameters

@staticmethod
def _normalize_operator_name(identifier: OperatorIdentifier) -> str:
if isinstance(identifier, type):
return identifier.__name__.upper()
if isinstance(identifier, str):
return identifier.upper()
return identifier.name.upper()

@staticmethod
def _to_operator_names(
operators: OperatorIdentifier | Iterable[OperatorIdentifier] | None,
) -> list[str]:
if operators is None:
return []
if isinstance(operators, (str, type)) or isinstance(operators, Operator):
return [Circuit._normalize_operator_name(operators)]
return [Circuit._normalize_operator_name(op) for op in operators]

def count(
self,
operators: OperatorIdentifier | Iterable[OperatorIdentifier] | None = None,
qubits: QubitInput | Iterable[QubitInput] | None = None,
qubit_match: QubitMatch = QubitMatch.ANY,
include_types: Iterable[MomentType] = (MomentType.GATE,),
) -> Counter[str]:
"""
Count instructions in the circuit with optional filtering.

When both ``operators`` and ``qubits`` are specified, an instruction must satisfy
both filters to be counted (AND semantics).

Args:
operators: Filter by operator name or type. Defaults to None (no filter).
qubits: Filter by qubit. Matched against the union of target and control qubits.
qubit_match (QubitMatch): How multiple qubits relate. ANY = instruction on
any specified qubit; ALL = instruction on all specified qubits. Default ANY.
include_types (Iterable[MomentType]): Moment types to count. Default: GATE only.
Pass additional MomentType values to include noise, measures, etc.

Returns:
Counter[str]: Operator names mapped to occurrence counts.

Examples:
>>> circ = Circuit().h(0).cnot(0, 1).rx(0, 0.5)
>>> circ.count()
Counter({'H': 1, 'CNot': 1, 'Rx': 1})
>>> circ.count("h")
Counter({'H': 1})
>>> circ.count(["H", "CNot"])
Counter({'H': 1, 'CNot': 1})
>>> circ.count(qubits=0)
Counter({'H': 1, 'CNot': 1, 'Rx': 1})
"""
include_types_set = set(include_types)
operator_names_set = set(self._to_operator_names(operators))
_qs = QubitSet(qubits) if qubits is not None else None
filter_qubits = _qs if _qs else None # empty QubitSet treated as no filter

result: Counter[str] = Counter()

for key, instruction in self.moments.items():
if key.moment_type not in include_types_set:
continue

instr_qubits = instruction.target.union(instruction.control)
instr_name_upper = instruction.operator.name.upper()

qubit_pass = filter_qubits is None or (
any(q in instr_qubits for q in filter_qubits)
if qubit_match == QubitMatch.ANY
else all(q in instr_qubits for q in filter_qubits)
)
operator_pass = not operator_names_set or instr_name_upper in operator_names_set

if qubit_pass and operator_pass:
result[instruction.operator.name] += 1

return result

def with_euler_angles(self, observables: Sequence[Observable] | Sum) -> Circuit:
"""Returns a copy of the circuit with parametrized Euler angles on the observables' qubits

Expand Down
189 changes: 189 additions & 0 deletions test/unit_tests/braket/circuits/test_circuit_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
from collections import Counter

import pytest

from braket.circuits import Circuit, gates
from braket.circuits.circuit import QubitMatch
from braket.circuits.moments import MomentType
from braket.circuits.noises import BitFlip


@pytest.fixture
def mixed_circuit():
return Circuit().h(0).cnot(0, 1).rx(0, 0.5).h(1)


def test_qubit_match_values():
assert QubitMatch.ANY == "ANY"
assert QubitMatch.ALL == "ALL"


def test_qubit_match_accepts_string_literal():
assert "ANY" == QubitMatch.ANY
assert "ALL" == QubitMatch.ALL


def test_no_filters_returns_all_gates(mixed_circuit):
result = mixed_circuit.count()
assert result == Counter({"H": 2, "CNot": 1, "Rx": 1})


def test_no_filters_empty_circuit():
result = Circuit().count()
assert result == Counter()


def test_default_include_types_excludes_gate_noise():
circ = Circuit().h(0)
circ.apply_gate_noise(BitFlip(0.1))
result = circ.count()
assert result == Counter({"H": 1})
assert "BitFlip" not in result


def test_include_gate_noise_type():
circ = Circuit().h(0)
circ.apply_gate_noise(BitFlip(0.1))
result = circ.count(include_types=[MomentType.GATE, MomentType.GATE_NOISE])
assert result["H"] == 1
assert result["BitFlip"] == 1


def test_operator_filter_uppercase_string(mixed_circuit):
result = mixed_circuit.count(operators="H")
assert result == Counter({"H": 2})


def test_operator_filter_lowercase_string(mixed_circuit):
result = mixed_circuit.count(operators="h")
assert result == Counter({"H": 2})


def test_operator_filter_mixed_case_string(mixed_circuit):
result = mixed_circuit.count(operators="cNoT")
assert result == Counter({"CNot": 1})


def test_operator_filter_gate_class(mixed_circuit):
result = mixed_circuit.count(operators=gates.CNot)
assert result == Counter({"CNot": 1})


def test_operator_filter_gate_instance(mixed_circuit):
result = mixed_circuit.count(operators=gates.CNot())
assert result == Counter({"CNot": 1})


def test_operator_filter_multiple_or(mixed_circuit):
result = mixed_circuit.count(operators=["H", "CNot"])
assert result == Counter({"H": 2, "CNot": 1})


def test_operator_filter_multiple_mixed_identifiers(mixed_circuit):
result = mixed_circuit.count(operators=["h", gates.CNot])
assert result == Counter({"H": 2, "CNot": 1})


def test_operator_filter_unknown_name_returns_empty(mixed_circuit):
result = mixed_circuit.count(operators="ZZZ")
assert result == Counter()


def test_operators_empty_list_is_no_filter(mixed_circuit):
assert mixed_circuit.count(operators=[]) == mixed_circuit.count()


def test_qubits_empty_list_is_no_filter(mixed_circuit):
assert mixed_circuit.count(qubits=[]) == mixed_circuit.count()


def test_operator_empty_list_and_qubits_empty_list_is_no_filter(mixed_circuit):
assert mixed_circuit.count(
operators=[], qubits=[]
) == mixed_circuit.count()


def test_qubit_filter_single_qubit():
circ = Circuit().h(0).cnot(0, 1).rx(1, 0.5)
result = circ.count(qubits=0)
assert result == Counter({"H": 1, "CNot": 1})


def test_qubit_filter_multiple_qubits_any():
circ = Circuit().h(0).h(2).cnot(0, 1)
result = circ.count(qubits=[0, 2])
assert result == Counter({"H": 2, "CNot": 1})


def test_qubit_filter_multiple_qubits_all():
circ = Circuit().h(0).h(1).cnot(0, 1)
result = circ.count(qubits=[0, 1], qubit_match=QubitMatch.ALL)
assert result == Counter({"CNot": 1})


def test_qubit_filter_no_match():
circ = Circuit().h(0).h(1)
result = circ.count(qubits=5)
assert result == Counter()


def test_qubit_filter_qubit_not_in_circuit():
circ = Circuit().h(0).cnot(0, 1)
result = circ.count(qubits=99)
assert result == Counter()


def test_qubit_filter_partial_qubits_not_in_circuit_any():
circ = Circuit().h(0).cnot(0, 1)
result = circ.count(qubits=[0, 99])
assert result == Counter({"H": 1, "CNot": 1})


def test_qubit_filter_partial_qubits_not_in_circuit_all():
circ = Circuit().h(0).cnot(0, 1)
result = circ.count(qubits=[0, 99], qubit_match=QubitMatch.ALL)
assert result == Counter()


def test_both_filters_counts_intersection():
circ = Circuit().h(0).rx(1, 0.5).cnot(0, 1)
result = circ.count(operators="CNot", qubits=[1])
assert result == Counter({"CNot": 1})


def test_both_filters_no_overlap():
circ = Circuit().h(0).rx(1, 0.5)
result = circ.count(operators="H", qubits=[1])
assert result == Counter()


def test_only_qubit_filter_no_operator_filter():
circ = Circuit().h(0).cnot(0, 1)
result = circ.count(qubits=[0])
assert result == Counter({"H": 1, "CNot": 1})


def test_only_operator_filter_no_qubit_filter():
circ = Circuit().h(0).h(1)
result = circ.count(operators="H")
assert result == Counter({"H": 2})


def test_circuit_method_positional_single_operator(mixed_circuit):
assert mixed_circuit.count("H") == Counter({"H": 2})


def test_circuit_method_positional_list_of_operators(mixed_circuit):
assert mixed_circuit.count(["H", "CNot"]) == Counter({"H": 2, "CNot": 1})


def test_circuit_method_passes_kwargs(mixed_circuit):
expected = mixed_circuit.count(operators="H", qubits=[0])
actual = mixed_circuit.count(operators="H", qubits=[0])
assert expected == actual


def test_qubit_match_importable_from_braket_circuits():
from braket.circuits import QubitMatch as QM

assert QM.ANY == "ANY"
Loading