Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions glue/cirq/stimcirq/_obs_annotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Tuple
from typing import Any, Dict, Iterable, List, Mapping, Tuple

import cirq
import stim
Expand All @@ -16,7 +16,7 @@ def __init__(
*,
parity_keys: Iterable[str] = (),
relative_keys: Iterable[int] = (),
pauli_keys: Iterable[str] = (),
pauli_keys: Iterable[tuple[cirq.Qid, str]] | Iterable[str] = (),
observable_index: int,
):
"""
Expand All @@ -29,15 +29,31 @@ def __init__(
"""
self.parity_keys = frozenset(parity_keys)
self.relative_keys = frozenset(relative_keys)
self.pauli_keys = frozenset(pauli_keys)
_pauli_keys = []
for k in pauli_keys:
if isinstance(k, str):
# For backward compatibility
_pauli_keys.append((cirq.LineQubit(int(k[1:])), k[0]))
else:
_pauli_keys.append(tuple(k))
self.pauli_keys = frozenset(_pauli_keys)
Copy link
Collaborator

@Strilanc Strilanc Aug 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if a user has code for k in op.pauli_keys: if k.startswith("..."): ... ? This would break that.

Copy link
Collaborator Author

@AlexBourassa AlexBourassa Aug 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you mean that changing the type of pauli_keys from str to tuple[cirq.QiD, str] is an issue? The problem I'm trying to solve is that there are no reference to the "stim qubit id" when we are in cirq, and there is no guarantee that these ids will mean anything useful when we convert the circuit back from stim to cirq.

We could keep the type str and do something like X(1,1) and then interpret k[1:] as a tuple/Qid. That would not impact the snippet you showed, but we could certainly come up with another example snippet that would break (for example if the user has code that does int(k[1:]) it would break...). I think this kind of string interpretation might be worse, but I'm game if you prefer it? I don't really see a clean way to achieve this change... But if you have some preferred format, I'm happy to implement it. Fwiw, since this is a pretty recent addition to stimcirq, I don't think it's gotten too much adoption yet, so making a change now is unlikely to impact anyone at this point.

Anyway, again I'm open to whatever you think is the best solution here. Let me know!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main thing is that you must make the changes in a way that doesn't break existing code. Existing code is likely to depend on the type of public fields, so you cannot change their type. You can add other fields, though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I tried a version of this. Let me know what you think!

self.observable_index = observable_index

@property
def qubits(self) -> Tuple[cirq.Qid, ...]:
return ()
return tuple(sorted(q for q, _ in self.pauli_keys))

def with_qubits(self, *new_qubits) -> 'CumulativeObservableAnnotation':
return self
if len(self.qubits) == len(new_qubits):
pauli_map = dict(self.pauli_keys)
return CumulativeObservableAnnotation(
parity_keys=self.parity_keys,
relative_keys=self.relative_keys,
pauli_keys=tuple((new_q, pauli_map[q]) for new_q, q in zip(new_qubits, self.qubits)),
observable_index=self.observable_index,
)

raise ValueError("Number of qubits does not match")

def _value_equality_values_(self) -> Any:
return self.parity_keys, self.relative_keys, self.pauli_keys, self.observable_index
Expand Down Expand Up @@ -85,6 +101,7 @@ def _stim_conversion_(
edit_measurement_key_lengths: List[Tuple[str, int]],
have_seen_loop: bool = False,
tag: str,
targets: list[int],
**kwargs,
):
# Ideally these references would all be resolved ahead of time, to avoid the redundant
Expand All @@ -109,10 +126,11 @@ def _stim_conversion_(
rec_targets.append(stim.target_rec(-1 - offset))
if not remaining:
break
pauli_map = dict(self.pauli_keys)
rec_targets.extend(
[
stim.target_pauli(qubit_index=int(k[1:]), pauli=k[0])
for k in sorted(self.pauli_keys)
stim.target_pauli(qubit_index=tid, pauli=pauli_map[q])
for q, tid in zip(self.qubits, targets)
]
)
if remaining:
Expand Down
23 changes: 15 additions & 8 deletions glue/cirq/stimcirq/_obs_annotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,14 @@ def test_json_serialization():
assert c == c2

def test_json_serialization_with_pauli_keys():
pauli_keys = [(cirq.LineQubit(0), "X"), (cirq.LineQubit(1), "Y"), (cirq.LineQubit(2), "Z")]
c = cirq.Circuit(
stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=pauli_keys),
stimcirq.CumulativeObservableAnnotation(
parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]
parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=pauli_keys
),
stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=pauli_keys),
stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=pauli_keys),
)
json = cirq.to_json(c)
c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER])
Expand All @@ -208,13 +209,19 @@ def test_json_serialization_with_pauli_keys():
def test_json_backwards_compat_exact():
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5)
packed_v1 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "relative_keys": [\n -2\n ]\n}'
packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
packed_v2 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
packed_v3 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
assert cirq.read_json(json_text=packed_v1, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v2
assert cirq.read_json(json_text=packed_v3, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v3

# With pauli_keys
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=["X0", "Y1", "Z2"])
pauli_keys = [(cirq.LineQubit(0), "X"), (cirq.LineQubit(1), "Y"), (cirq.LineQubit(2), "Z")]
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=pauli_keys)
packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n "X0",\n "Y1",\n "Z2"\n ],\n "relative_keys": [\n -2\n ]\n}'
packed_v3 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n [\n {\n "cirq_type": "LineQubit",\n "x": 0\n },\n "X"\n ],\n [\n {\n "cirq_type": "LineQubit",\n "x": 1\n },\n "Y"\n ],\n [\n {\n "cirq_type": "LineQubit",\n "x": 2\n },\n "Z"\n ]\n ],\n "relative_keys": [\n -2\n ]\n}'

assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v2
assert cirq.read_json(json_text=packed_v3, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v3
4 changes: 2 additions & 2 deletions glue/cirq/stimcirq/_stim_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,13 @@ def coords_after_offset(

def resolve_measurement_record_keys(
self, targets: Iterable[stim.GateTarget]
) -> Tuple[List[str], List[int], List[str]]:
) -> Tuple[List[str], List[int], List[tuple[cirq.Qid, str]]]:
pauli_targets, meas_targets = [], []
for t in targets:
if t.is_measurement_record_target:
meas_targets.append(t)
else:
pauli_targets.append(f'{t.pauli_type}{t.value}')
pauli_targets.append((cirq.LineQubit(t.value), t.pauli_type))

if self.have_seen_loop:
return [], [t.value for t in meas_targets], pauli_targets
Expand Down
1 change: 1 addition & 0 deletions glue/cirq/stimcirq/_stim_to_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ def test_round_trip_with_pauli_obs():
stim_circuit = stim.Circuit("""
QUBIT_COORDS(5, 5) 0
R 0
TICK
OBSERVABLE_INCLUDE(0) X0
TICK
H 0
Expand Down
Loading