Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions python/cuquantum/tensornet/circuit_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(self, circuit, *, dtype='complex128', backend="auto", options=None)
circuit, self.backend_name, self.dtype, check_diagonal=self.check_diagonal, decompose_gates=self.decompose_gates)
self.n_qubits = len(self.qubits)
self._metadata = None
self._forward_inverse_metadata_cache = {}

@property
def qubits(self):
Expand Down Expand Up @@ -434,21 +435,55 @@ def _get_forward_inverse_metadata(self, lightcone, coned_qubits):
- ``next_frontier``: The next mode label to use.
- ``inverse_gates``: A sequence of (operand, qubits) for the inverse circuit.
"""
parser = self.parser
if lightcone:
circuit = parser.get_lightcone_circuit(self.circuit, coned_qubits)
_, gates, gates_are_diagonal = parser.unfold_circuit(circuit, self.backend_name, self.dtype, decompose_gates=self.decompose_gates, check_diagonal=self.check_diagonal)
# in cirq, the lightcone circuit may only contain a subset of the original qubits
# It's imperative to use qubits=self.qubits to generate the input tensors
input_mode_labels, input_operands, qubits_frontier = circ_utils.parse_inputs(self.qubits, gates, gates_are_diagonal, self.dtype, self.backend_name)
else:
circuit = self.circuit
input_mode_labels, input_operands, qubits_frontier = self._get_inputs()
# avoid inplace modification on metadata
qubits_frontier = qubits_frontier.copy()

next_frontier = max(qubits_frontier.values()) + 1
# inverse circuit
inverse_circuit = parser.get_inverse_circuit(circuit)
_, inverse_gates, inverse_gates_diagonals = parser.unfold_circuit(inverse_circuit, self.backend_name, self.dtype, decompose_gates=self.decompose_gates, check_diagonal=self.check_diagonal)
return input_mode_labels, input_operands, qubits_frontier, next_frontier, inverse_gates, inverse_gates_diagonals
coned_qubits = tuple(coned_qubits)
coned_qubits_set = set(coned_qubits)
coned_qubits_key = tuple(i for i, qubit in enumerate(self.qubits) if qubit in coned_qubits_set)
cache_key = (lightcone, coned_qubits_key, self.decompose_gates, self.check_diagonal)

cached_metadata = self._forward_inverse_metadata_cache.get(cache_key)
if cached_metadata is None:
parser = self.parser
if lightcone:
circuit = parser.get_lightcone_circuit(self.circuit, coned_qubits)
_, gates, gates_are_diagonal = parser.unfold_circuit(
circuit,
self.backend_name,
self.dtype,
decompose_gates=self.decompose_gates,
check_diagonal=self.check_diagonal,
)
# in cirq, the lightcone circuit may only contain a subset of the original qubits
# It's imperative to use qubits=self.qubits to generate the input tensors
input_mode_labels, input_operands, qubits_frontier = circ_utils.parse_inputs(
self.qubits,
gates,
gates_are_diagonal,
self.dtype,
self.backend_name,
)
else:
circuit = self.circuit
input_mode_labels, input_operands, qubits_frontier = self._get_inputs()

next_frontier = max(qubits_frontier.values()) + 1
# inverse circuit
inverse_circuit = parser.get_inverse_circuit(circuit)
_, inverse_gates, inverse_gates_diagonals = parser.unfold_circuit(
inverse_circuit,
self.backend_name,
self.dtype,
decompose_gates=self.decompose_gates,
check_diagonal=self.check_diagonal,
)
cached_metadata = (
input_mode_labels,
input_operands,
qubits_frontier.copy(),
next_frontier,
inverse_gates,
inverse_gates_diagonals,
)
self._forward_inverse_metadata_cache[cache_key] = cached_metadata

input_mode_labels, input_operands, qubits_frontier, next_frontier, inverse_gates, inverse_gates_diagonals = cached_metadata
return input_mode_labels, input_operands, qubits_frontier.copy(), next_frontier, inverse_gates, inverse_gates_diagonals
120 changes: 120 additions & 0 deletions python/samples/tensornet/circuit_to_einsum_cache_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES
#
# SPDX-License-Identifier: BSD-3-Clause

import argparse
import statistics
import time

from cuquantum.tensornet import CircuitToEinsum


def build_qiskit_circuit(num_qubits, depth):
    """Build a layered Qiskit circuit used as benchmark input.

    Args:
        num_qubits: Number of qubits in the circuit.
        depth: Number of layers to generate.

    Returns:
        The assembled ``QuantumCircuit``.
    """
    # Imported lazily so qiskit is only needed when this builder is selected.
    from qiskit import QuantumCircuit

    circuit = QuantumCircuit(num_qubits)
    for layer in range(depth):
        # Same left-to-right product as 0.1 * (layer + 1) * (index + 1).
        layer_scale = 0.1 * (layer + 1)
        for index in range(num_qubits):
            theta = layer_scale * (index + 1)
            circuit.ry(theta, index)
            circuit.rz(theta / 2, index)
        # NOTE(review): the CX chain is assumed to be part of every layer;
        # confirm against the intended nesting.
        for control, target in zip(range(num_qubits - 1), range(1, num_qubits)):
            circuit.cx(control, target)
    return circuit


def build_cirq_circuit(num_qubits, depth):
    """Build a layered Cirq circuit used as benchmark input.

    Args:
        num_qubits: Number of ``LineQubit`` objects to allocate.
        depth: Number of layers to generate.

    Returns:
        The assembled ``cirq.Circuit``.
    """
    # Imported lazily so cirq is only needed when this builder is selected.
    import cirq

    line = cirq.LineQubit.range(num_qubits)
    result = cirq.Circuit()
    for layer in range(depth):
        # Same left-to-right product as 0.1 * (layer + 1) * (q.x + 1).
        layer_scale = 0.1 * (layer + 1)
        for q in line:
            theta = layer_scale * (q.x + 1)
            result.append(cirq.ry(theta)(q))
            result.append(cirq.rz(theta / 2)(q))
        # NOTE(review): the CNOT chain is assumed to be part of every layer;
        # confirm against the intended nesting.
        for control, target in zip(line[:-1], line[1:]):
            result.append(cirq.CNOT(control, target))
    return result


def build_circuit(framework, num_qubits, depth):
    """Dispatch to the circuit builder for the requested framework.

    Args:
        framework: Either ``"qiskit"`` or ``"cirq"``.
        num_qubits: Number of qubits forwarded to the builder.
        depth: Number of layers forwarded to the builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``framework`` is not one of the supported names.
    """
    # Reject unknown frameworks up front so the happy path stays flat.
    if framework != "qiskit" and framework != "cirq":
        raise ValueError(f"unsupported framework: {framework}")

    if framework == "qiskit":
        circuit = build_qiskit_circuit(num_qubits, depth)
    else:
        circuit = build_cirq_circuit(num_qubits, depth)
    return circuit


def benchmark_expectation(converter, repetitions, lightcone):
    """Time repeated ``converter.expectation`` calls on a fixed Pauli string.

    The Pauli string acts on at most the first three qubits of the converter,
    alternating ``Z`` and ``X`` starting with ``Z``.

    Args:
        converter: Object exposing ``qubits`` and ``expectation``.
        repetitions: Number of timed calls.
        lightcone: Forwarded as the ``lightcone`` keyword of ``expectation``.

    Returns:
        A list with the elapsed wall-clock seconds of every call, in order.
    """
    # Slicing already clamps to the sequence length.
    targets = converter.qubits[:3]
    pauli_map = {}
    for position, qubit in enumerate(targets):
        pauli_map[qubit] = "X" if position % 2 else "Z"

    def timed_call():
        began = time.perf_counter()
        converter.expectation(pauli_map, lightcone=lightcone)
        return time.perf_counter() - began

    return [timed_call() for _ in range(repetitions)]


def benchmark_rdm(converter, repetitions, lightcone):
    """Time repeated ``converter.reduced_density_matrix`` calls.

    The reduced density matrix is requested on at most the first three qubits
    of the converter.

    Args:
        converter: Object exposing ``qubits`` and ``reduced_density_matrix``.
        repetitions: Number of timed calls.
        lightcone: Forwarded as the ``lightcone`` keyword.

    Returns:
        A list with the elapsed wall-clock seconds of every call, in order.
    """
    # Slicing already clamps to the sequence length.
    where = converter.qubits[:3]

    def timed_call():
        began = time.perf_counter()
        converter.reduced_density_matrix(where, lightcone=lightcone)
        return time.perf_counter() - began

    return [timed_call() for _ in range(repetitions)]


def summarize_timings(name, timings):
    """Print cold/warm statistics for one benchmark run.

    The first sample is reported as the "cold" call and the remaining samples
    as "warm" calls. When only one sample exists, the warm statistics fall
    back to that single sample.

    Args:
        name: Label printed as the section header.
        timings: Elapsed times in seconds, in call order. May be empty.
    """
    print(f"{name}:")
    if not timings:
        # Nothing was measured (e.g. zero repetitions); indexing the first
        # sample would raise IndexError, so report it explicitly instead.
        print(" no timings recorded")
        return

    cold = timings[0]
    warm = timings[1:] if len(timings) > 1 else timings
    warm_mean = statistics.mean(warm)
    # Guard against a zero mean produced by timer resolution.
    speedup = cold / warm_mean if warm_mean else float("inf")

    print(f" cold call: {cold:.6f} s")
    print(f" warm mean: {warm_mean:.6f} s")
    print(f" warm min: {min(warm):.6f} s")
    print(f" warm max: {max(warm):.6f} s")
    print(f" speedup: {speedup:.2f}x")


def main():
    """Parse CLI options, run the four cache benchmarks and print a summary.

    Each benchmark gets its own ``CircuitToEinsum`` instance. The expectation
    and reduced-density-matrix benchmarks target the same leading qubits, so
    sharing one converter would let a later benchmark start with cache entries
    filled by an earlier one, and its "cold call" would not be cold.
    """
    parser = argparse.ArgumentParser(description="Benchmark CircuitToEinsum metadata cache reuse.")
    parser.add_argument("--framework", choices=("qiskit", "cirq"), default="qiskit")
    parser.add_argument("--qubits", type=int, default=12)
    parser.add_argument("--depth", type=int, default=20)
    parser.add_argument("--repetitions", type=int, default=20)
    parser.add_argument("--dtype", default="complex128")
    parser.add_argument("--backend", default="numpy")
    args = parser.parse_args()
    if args.repetitions < 1:
        # The summary needs at least one sample to report a cold call.
        parser.error("--repetitions must be at least 1")

    circuit = build_circuit(args.framework, args.qubits, args.depth)

    def fresh_converter():
        # A new converter starts with empty caches, keeping the first timed
        # call of every benchmark genuinely cold.
        return CircuitToEinsum(circuit, dtype=args.dtype, backend=args.backend)

    print("CircuitToEinsum metadata cache benchmark")
    print(f"framework: {args.framework}")
    print(f"qubits: {args.qubits}")
    print(f"depth: {args.depth}")
    print(f"repetitions: {args.repetitions}")
    print(f"dtype: {args.dtype}")
    print(f"backend: {args.backend}")
    print()

    expectation_lightcone = benchmark_expectation(fresh_converter(), args.repetitions, lightcone=True)
    expectation_no_lightcone = benchmark_expectation(fresh_converter(), args.repetitions, lightcone=False)
    rdm_lightcone = benchmark_rdm(fresh_converter(), args.repetitions, lightcone=True)
    rdm_no_lightcone = benchmark_rdm(fresh_converter(), args.repetitions, lightcone=False)

    summarize_timings("expectation(lightcone=True)", expectation_lightcone)
    summarize_timings("expectation(lightcone=False)", expectation_no_lightcone)
    summarize_timings("reduced_density_matrix(lightcone=True)", rdm_lightcone)
    summarize_timings("reduced_density_matrix(lightcone=False)", rdm_no_lightcone)


# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
53 changes: 53 additions & 0 deletions python/tests/cuquantum_tests/tensornet/test_circuit_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,59 @@ def test_batched_amplitudes_marginal_cases(self, circuit):
assert expr == expr1
for o1, o2 in zip(operands, operands1):
assert np.allclose(o1, o2)

@pytest.mark.parametrize("circuit", CircuitMatrix.L1())
def test_forward_inverse_metadata_cache_reuses_identical_call(self, circuit):
    """Repeating an identical expectation call reuses one cache entry.

    Besides the cache size, the second (cache-hit) call must return the same
    expression and operands as the first, so that a stale or mutated cache
    entry is detected.
    """
    converter = CircuitToEinsum(circuit, dtype='complex64', backend="numpy")
    qubits = converter.qubits
    if len(qubits) < 2:
        pytest.skip("requires at least two qubits")

    pauli_map = {qubits[0]: 'Z', qubits[1]: 'X'}
    expr, operands = converter.expectation(pauli_map, lightcone=True)
    assert len(converter._forward_inverse_metadata_cache) == 1

    expr1, operands1 = converter.expectation(pauli_map, lightcone=True)
    assert len(converter._forward_inverse_metadata_cache) == 1

    # The cache hit must describe exactly the same tensor network.
    assert expr == expr1
    assert len(operands) == len(operands1)
    for o1, o2 in zip(operands, operands1):
        assert np.allclose(o1, o2)

@pytest.mark.parametrize("circuit", CircuitMatrix.L1())
def test_forward_inverse_metadata_cache_normalizes_where_order(self, circuit):
    """Permuting the ``where`` qubits must not create a second cache entry."""
    converter = CircuitToEinsum(circuit, dtype='complex64', backend="numpy")
    all_qubits = converter.qubits
    if len(all_qubits) < 2:
        pytest.skip("requires at least two qubits")

    first, second = all_qubits[0], all_qubits[1]
    for where in ((first, second), (second, first)):
        converter.reduced_density_matrix(where, lightcone=True)
        # Both orderings are expected to share a single cached entry.
        assert len(converter._forward_inverse_metadata_cache) == 1

@pytest.mark.parametrize("circuit", CircuitMatrix.L1())
def test_forward_inverse_metadata_cache_separates_distinct_supports(self, circuit):
    """Pauli strings on different qubit sets get separate cache entries."""
    converter = CircuitToEinsum(circuit, dtype='complex64', backend="numpy")
    all_qubits = converter.qubits
    if len(all_qubits) < 2:
        pytest.skip("requires at least two qubits")

    first, second = all_qubits[0], all_qubits[1]
    pauli_maps = ({first: 'Z'}, {first: 'Z', second: 'Z'})
    for expected_entries, pauli_map in enumerate(pauli_maps, start=1):
        converter.expectation(pauli_map, lightcone=True)
        # Every new support adds exactly one entry.
        assert len(converter._forward_inverse_metadata_cache) == expected_entries

@pytest.mark.parametrize("circuit", CircuitMatrix.L1())
def test_forward_inverse_metadata_cache_separates_lightcone_modes(self, circuit):
    """The same Pauli string is cached separately per ``lightcone`` setting."""
    converter = CircuitToEinsum(circuit, dtype='complex64', backend="numpy")
    all_qubits = converter.qubits
    if len(all_qubits) < 1:
        pytest.skip("requires at least one qubit")

    target = all_qubits[0]
    for expected_entries, use_lightcone in enumerate((True, False), start=1):
        converter.expectation({target: 'Z'}, lightcone=use_lightcone)
        # Switching the lightcone flag adds exactly one entry.
        assert len(converter._forward_inverse_metadata_cache) == expected_entries

@pytest.mark.parametrize("backend", ARRAY_BACKENDS)
@pytest.mark.parametrize("dtype", ("float32", "float64", "complex64", "complex128"))
Expand Down