diff --git a/glue/cirq/stimcirq/_obs_annotation.py b/glue/cirq/stimcirq/_obs_annotation.py index 9709bc7cb..56b8d5389 100644 --- a/glue/cirq/stimcirq/_obs_annotation.py +++ b/glue/cirq/stimcirq/_obs_annotation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple, Union import cirq import stim @@ -16,7 +16,7 @@ def __init__( *, parity_keys: Iterable[str] = (), relative_keys: Iterable[int] = (), - pauli_keys: Iterable[str] = (), + pauli_keys: Union[Iterable[Tuple[cirq.Qid, str]], Iterable[str]] = (), observable_index: int, ): """ @@ -29,32 +29,64 @@ def __init__( """ self.parity_keys = frozenset(parity_keys) self.relative_keys = frozenset(relative_keys) - self.pauli_keys = frozenset(pauli_keys) + _pauli_keys = [] + _qubits_to_pauli_keys = [] + for k in pauli_keys: + if isinstance(k, str): + # For backward compatibility + _pauli_keys.append(k) + _qubits_to_pauli_keys.append((cirq.LineQubit(int(k[1:])), k)) + else: + qubit, basis_and_id = k + _pauli_keys.append(basis_and_id) + _qubits_to_pauli_keys.append((qubit, basis_and_id)) + self._qubits_to_pauli_keys = tuple(_qubits_to_pauli_keys) + self.pauli_keys = frozenset(_pauli_keys) self.observable_index = observable_index @property def qubits(self) -> Tuple[cirq.Qid, ...]: - return () + return tuple(sorted(q for q, _ in self._qubits_to_pauli_keys)) def with_qubits(self, *new_qubits) -> 'CumulativeObservableAnnotation': - return self + if len(self.qubits) == len(new_qubits): + qubits_to_pauli_keys = dict(self._qubits_to_pauli_keys) + return CumulativeObservableAnnotation( + parity_keys=self.parity_keys, + relative_keys=self.relative_keys, + pauli_keys=tuple( + (newq, qubits_to_pauli_keys[q]) for newq, q in zip(new_qubits, self.qubits) + ), + observable_index=self.observable_index, + ) + + raise ValueError("Number of qubits does not match") def _value_equality_values_(self) -> Any: - return self.parity_keys, self.relative_keys, self.pauli_keys, self.observable_index + return self.parity_keys, self.relative_keys, self._qubits_to_pauli_keys, self.observable_index - def _circuit_diagram_info_(self, args: Any) -> str: + def _circuit_diagram_info_(self, args: Any) -> Union[str, Tuple[str]]: items: List[str] = [repr(e) for e in sorted(self.parity_keys)] items += [f'rec[{e}]' for e in sorted(self.relative_keys)] - items += sorted(self.pauli_keys) - k = ",".join(str(e) for e in items) - return f"Obs{self.observable_index}({k})" + + if len(self._qubits_to_pauli_keys): + pauli_map = dict(self._qubits_to_pauli_keys) + out = [] + for q in self.qubits: + k = ",".join([str(e) for e in items] + [f'{str(q)}{pauli_map[q][0]}']) + out.append(f"Obs{self.observable_index}({k})") + return tuple(out) + else: + k = ",".join(str(e) for e in items) + return f"Obs{self.observable_index}({k})" + def __repr__(self) -> str: return ( f'stimcirq.CumulativeObservableAnnotation(' f'parity_keys={sorted(self.parity_keys)}, ' f'relative_keys={sorted(self.relative_keys)}, ' - f'pauli_keys={sorted(self.pauli_keys)}, ' + f'pauli_keys={sorted(self._qubits_to_pauli_keys)}, ' f'observable_index={self.observable_index!r})' ) @@ -66,7 +98,7 @@ def _json_dict_(self) -> Dict[str, Any]: result = { 'parity_keys': sorted(self.parity_keys), 'observable_index': self.observable_index, - 'pauli_keys': sorted(self.pauli_keys), + 'pauli_keys': sorted(self._qubits_to_pauli_keys), } if self.relative_keys: result['relative_keys'] = sorted(self.relative_keys) @@ -85,6 +117,7 @@ def _stim_conversion_( edit_measurement_key_lengths: List[Tuple[str, int]], have_seen_loop: bool = False, tag: str, + targets: list[int], **kwargs, ): # Ideally these references would all be resolved ahead of time, to avoid the redundant @@ -109,10 +142,13 @@ def _stim_conversion_( rec_targets.append(stim.target_rec(-1 - offset)) if not remaining: break + + qubit_to_basis = dict([(q,k[0]) for q, k in self._qubits_to_pauli_keys]) + rec_targets.extend( [ - stim.target_pauli(qubit_index=int(k[1:]), pauli=k[0]) - for k in sorted(self.pauli_keys) + stim.target_pauli(qubit_index=tid, pauli=qubit_to_basis[q]) + for q, tid in zip(self.qubits, targets) ] ) if remaining: diff --git a/glue/cirq/stimcirq/_obs_annotation_test.py b/glue/cirq/stimcirq/_obs_annotation_test.py index 6937c9fd1..d7b684886 100644 --- a/glue/cirq/stimcirq/_obs_annotation_test.py +++ b/glue/cirq/stimcirq/_obs_annotation_test.py @@ -192,13 +192,14 @@ def test_json_serialization(): assert c == c2 def test_json_serialization_with_pauli_keys(): + pauli_keys = [(cirq.LineQubit(0), "X"), (cirq.LineQubit(1), "Y"), (cirq.LineQubit(2), "Z")] c = cirq.Circuit( - stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]), + stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=pauli_keys), stimcirq.CumulativeObservableAnnotation( - parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=["X0", "Y1", "Z2"] + parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=pauli_keys ), - stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=["X0", "Y1", "Z2"]), - stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]), + stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=pauli_keys), + stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=pauli_keys), ) json = cirq.to_json(c) c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) @@ -208,13 +209,19 @@ def test_json_serialization_with_pauli_keys(): def test_json_backwards_compat_exact(): raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5) packed_v1 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "relative_keys": [\n -2\n ]\n}' - packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}' + packed_v2 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}' + packed_v3 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}' assert cirq.read_json(json_text=packed_v1, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw - assert cirq.to_json(raw) == packed_v2 + assert cirq.read_json(json_text=packed_v3, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw + assert cirq.to_json(raw) == packed_v3 # With pauli_keys - raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]) + pauli_keys = [(cirq.LineQubit(0), "X0"), (cirq.LineQubit(1), "Y1"), (cirq.LineQubit(2), "Z2")] + raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=pauli_keys) packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n "X0",\n "Y1",\n "Z2"\n ],\n "relative_keys": [\n -2\n ]\n}' + packed_v3 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n [\n {\n "cirq_type": "LineQubit",\n "x": 0\n },\n "X0"\n ],\n [\n {\n "cirq_type": "LineQubit",\n "x": 1\n },\n "Y1"\n ],\n [\n {\n "cirq_type": "LineQubit",\n "x": 2\n },\n "Z2"\n ]\n ],\n "relative_keys": [\n -2\n ]\n}' + assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw - assert cirq.to_json(raw) == packed_v2 \ No newline at end of file + assert cirq.read_json(json_text=packed_v3, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw + assert cirq.to_json(raw) == packed_v3 \ No newline at end of file diff --git a/glue/cirq/stimcirq/_stim_to_cirq.py b/glue/cirq/stimcirq/_stim_to_cirq.py index 593bf4797..d88e5f449 100644 --- a/glue/cirq/stimcirq/_stim_to_cirq.py +++ b/glue/cirq/stimcirq/_stim_to_cirq.py @@ -340,13 +340,13 @@ def coords_after_offset( def resolve_measurement_record_keys( self, targets: Iterable[stim.GateTarget] - ) -> Tuple[List[str], List[int], List[str]]: + ) -> Tuple[List[str], List[int], List[Tuple[cirq.Qid, str]]]: pauli_targets, meas_targets = [], [] for t in targets: if t.is_measurement_record_target: meas_targets.append(t) else: - pauli_targets.append(f'{t.pauli_type}{t.value}') + pauli_targets.append((cirq.LineQubit(t.value), f'{t.pauli_type}{t.value}')) if self.have_seen_loop: return [], [t.value for t in meas_targets], pauli_targets diff --git a/glue/cirq/stimcirq/_stim_to_cirq_test.py b/glue/cirq/stimcirq/_stim_to_cirq_test.py index facd79dd2..1a0d1c12d 100644 --- a/glue/cirq/stimcirq/_stim_to_cirq_test.py +++ b/glue/cirq/stimcirq/_stim_to_cirq_test.py @@ -768,6 +768,7 @@ def test_round_trip_with_pauli_obs(): stim_circuit = stim.Circuit(""" QUBIT_COORDS(5, 5) 0 R 0 + TICK OBSERVABLE_INCLUDE(0) X0 TICK H 0