Skip to content

Commit 2fc674f

Browse files
arsalan-motamediapchytrziofil
authored
Wigner function for states (#581)
**Context:** Our current implementation of ``BtoPS`` is in the complex domain, and has extra factors one should be careful about. This unconventional description makes things less intuitive. **Description of the Change:** This PR creates a wigner ansatz for states that can be called by using ordinary phase space (x,p) conventions. **Benefits:** Making Wigner function calculations more intuitive and simple. **Possible Drawbacks:** Could be unstable for ``ArraysAnsatz``. **Related GitHub Issues:** None. --------- Co-authored-by: Anthony <[email protected]> Co-authored-by: Filippo Miatto <[email protected]>
1 parent 8abcda1 commit 2fc674f

File tree

6 files changed

+148
-3
lines changed

6 files changed

+148
-3
lines changed

mrmustard/lab/states/base.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
)
4444

4545
from ..circuit_components import CircuitComponent
46-
from ..circuit_components_utils import BtoChar, BtoQ
46+
from ..circuit_components_utils import BtoChar, BtoQ, BtoPS
4747
from ..transformations import Transformation
4848

4949
__all__ = ["State"]
@@ -758,3 +758,25 @@ def visualize_dm(
758758
if return_fig:
759759
return fig
760760
display(fig)
761+
762+
@property
763+
def wigner(self):
764+
r"""
765+
Returns the Wigner function of this state in phase space as an ``Ansatz``.
766+
767+
.. code-block::
768+
769+
>>> import numpy as np
770+
>>> from mrmustard.lab import Ket
771+
772+
>>> state = Ket.random([0])
773+
>>> x = np.linspace(-5, 5, 100)
774+
775+
>>> assert np.all(state.wigner(x,0).real >= 0)
776+
"""
777+
if isinstance(self.ansatz, PolyExpAnsatz):
778+
return (self >> BtoPS(self.modes, s=0)).ansatz.PS
779+
else:
780+
raise ValueError(
781+
"Wigner ansatz not implemented for Fock states. Consider calling ``.to_bargmann()`` first."
782+
)

mrmustard/physics/ansatz/polyexp_ansatz.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
join_Abc,
4545
)
4646
from mrmustard.physics.utils import generate_batch_str, verify_batch_triple
47+
from mrmustard.physics.fock_utils import c_in_PS
4748

4849
from mrmustard import math, widgets, settings
4950
from mrmustard.math.parameters import Variable
@@ -208,6 +209,47 @@ def num_derived_vars(self) -> int:
208209
def num_vars(self):
209210
return self.A.shape[-1]
210211

212+
@property
213+
def PS(self) -> PolyExpAnsatz:
214+
r"""
215+
The ansatz defined using real (i.e., phase-space) variables.
216+
"""
217+
n = self.A.shape[-1]
218+
if n % 2:
219+
raise ValueError(
220+
f"A phase space ansatz must have even number of indices. (n={n} is odd)"
221+
)
222+
223+
if self.num_derived_vars == 0:
224+
W = math.conj(math.rotmat(n // 2)) / math.sqrt(settings.HBAR, dtype=math.complex128)
225+
226+
A = math.einsum("ji, ...jk, kl-> ...il", W, self.A, W)
227+
b = math.einsum("ij, ...j-> ...i", W, self.b)
228+
c = self.c / (2 * settings.HBAR) ** (n // 2)
229+
return PolyExpAnsatz(A, b, c, lin_sup=self._lin_sup)
230+
231+
else:
232+
if self.num_derived_vars != 2:
233+
raise ValueError(
234+
f"This transformation supports 2 core and 0 or 2 derived variables"
235+
)
236+
A_tmp = self.A
237+
238+
A_tmp = A_tmp[..., [0, 2, 1, 3], :][..., [0, 2, 1, 3]]
239+
b = self.b[..., [0, 2, 1, 3]]
240+
c = c_in_PS(self.c) # implements PS transformations on ``c``
241+
242+
W = math.conj(math.rotmat(n // 2)) / math.sqrt(settings.HBAR, dtype=math.complex128)
243+
244+
A = math.einsum("ji, ...jk, kl-> ...il", W, A_tmp, W)
245+
b = math.einsum("ij, ...j-> ...i", W, b)
246+
c = c / (2 * settings.HBAR)
247+
248+
A_final = A[..., [0, 2, 1, 3], :][..., :, [0, 2, 1, 3]]
249+
b_final = b[..., [0, 2, 1, 3]]
250+
251+
return PolyExpAnsatz(A_final, b_final, c, lin_sup=self._lin_sup)
252+
211253
@property
212254
def scalar(self) -> Scalar:
213255
r"""
@@ -863,10 +905,10 @@ def __call__(self: PolyExpAnsatz, *z_inputs: ArrayLike | None) -> Batch[ComplexT
863905
- *b are the batch dimensions of the combined inputs.
864906
- *L is the batch shape of the ansatz.
865907
"""
866-
z_only = [arr for arr in z_inputs if arr is not None]
908+
z_only = [math.cast(arr, dtype=math.complex128) for arr in z_inputs if arr is not None]
867909
broadcasted_z = math.broadcast_arrays(*z_only)
868910
z = (
869-
math.cast(math.stack(broadcasted_z, axis=-1), dtype=math.complex128)
911+
math.stack(broadcasted_z, axis=-1)
870912
if broadcasted_z
871913
else math.astensor([], dtype=math.complex128)
872914
)

mrmustard/physics/fock_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Sequence, Iterable
2525

2626
import numpy as np
27+
from scipy.special import comb, factorial
2728

2829
from mrmustard import math, settings
2930
from mrmustard.math.lattice import strategies
@@ -512,3 +513,46 @@ def vjp(dLdGc):
512513
return math.astensor(dr, dtype=r.dtype), math.astensor(dphi, phi.dtype)
513514

514515
return ret, vjp
516+
517+
518+
def c_ps_matrix(m, n, alpha):
519+
"""
520+
helper function for ``c_in_PS``.
521+
"""
522+
mu_range = range(max(0, alpha - n), min(m, alpha) + 1)
523+
tmp = [comb(m, mu) * comb(n, alpha - mu) * (1j) ** (m - n - 2 * mu + alpha) for mu in mu_range]
524+
return np.sum(tmp)
525+
526+
527+
def gamma_matrix(c):
528+
"""
529+
helper function for ``c_in_PS`.
530+
constructs the matrix transformation that helps transforming ``c``.
531+
``c`` here must be 2-dimensional.
532+
"""
533+
M = c.shape[0] + c.shape[1] - 1
534+
Gamma = np.zeros((M**2, c.shape[0] * c.shape[1]), dtype=np.complex128)
535+
536+
for m in range(c.shape[0]):
537+
for n in range(c.shape[1]):
538+
for alpha in range(m + n + 1):
539+
factor = math.sqrt(
540+
factorial(m) * factorial(n) / (factorial(alpha) * factorial(m + n - alpha))
541+
)
542+
value = c_ps_matrix(m, n, alpha) * math.sqrt(settings.HBAR / 2) ** (m + n)
543+
row = alpha * M + (m + n - alpha)
544+
col = m * c.shape[0] + n
545+
Gamma[row, col] = value / factor
546+
return Gamma
547+
548+
549+
def c_in_PS(c):
550+
"""
551+
Transforms the ``c`` matrix of a ``DM`` object from bargmann to phase-space.
552+
It is a helper function used in
553+
554+
Args:
555+
c (Tensor): the 2-dimensional ``c`` matrix of the ``DM`` object
556+
"""
557+
M = c.shape[0] + c.shape[1] - 1
558+
return np.reshape(gamma_matrix(c) @ np.reshape(c, (c.shape[0] * c.shape[1], 1)), (M, M))

tests/test_lab/test_states/test_dm.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,12 @@ def test_stellar_decomposition_mixed(self):
612612

613613
assert math.allclose(core.dm().contract(phi, mode="zip").ansatz.A, sigma.ansatz.A)
614614
assert math.allclose(core.dm().contract(phi, mode="zip").ansatz.b, sigma.ansatz.b)
615+
616+
def test_wigner(self):
617+
618+
ans = Vacuum(0).dm().wigner
619+
x = np.linspace(0, 1, 100)
620+
621+
solution = np.exp(-(x**2)) / np.pi
622+
623+
assert math.allclose(ans(x, 0), solution)

tests/test_lab/test_states/test_ket.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
CircuitComponent,
3030
Coherent,
3131
Dgate,
32+
Ggate,
3233
Identity,
3334
Ket,
3435
Number,
@@ -39,6 +40,7 @@
3940
)
4041
from mrmustard.physics.representations import Representation
4142
from mrmustard.physics.triples import coherent_state_Abc
43+
from mrmustard.physics.wigner import wigner_discretized
4244
from mrmustard.physics.wires import Wires
4345
from mrmustard.widgets import state as state_widget
4446

@@ -618,3 +620,20 @@ def test_formal_stellar_decomposition(self):
618620
core, U = sigma.formal_stellar_decomposition([0])
619621

620622
assert sigma == core.contract(U, mode="zip")
623+
624+
def test_wigner(self):
625+
626+
ans = Vacuum(0).wigner
627+
x = np.linspace(0, 1, 100)
628+
solution = np.exp(-(x**2)) / np.pi
629+
630+
assert math.allclose(ans(x, 0), solution)
631+
632+
@pytest.mark.parametrize("n", [1, 2, 3])
633+
def test_wigner_poly_exp(self, n):
634+
635+
psi = (Number(0, n).dm().to_bargmann()) >> Ggate(0)
636+
xs = np.linspace(-5, 5, 100)
637+
poly_exp_wig = math.real(psi.wigner(xs, 0))
638+
wig = wigner_discretized(psi.fock_array(), xs, 0)
639+
assert math.allclose(poly_exp_wig[:, None], wig[0], atol=3e-3)

tests/test_physics/test_ansatz/test_polyexp_ansatz.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
complex_gaussian_integral_1,
3030
complex_gaussian_integral_2,
3131
)
32+
from mrmustard.lab.transformations import Identity
3233

3334
from ...random import Abc_triple
3435

@@ -624,3 +625,11 @@ def test_and_with_lin_sup_other(self):
624625
assert ansatz_and.A.shape == (2, 5, 3, 4, 3, 3)
625626
assert ansatz_and.b.shape == (2, 5, 3, 4, 3)
626627
assert ansatz_and.c.shape == (2, 5, 3, 4)
628+
629+
def test_PS(self):
630+
ans = Identity(0).ansatz
631+
632+
x = np.linspace(0, 1, 10)
633+
gaussian = np.exp((x**2) / 2) / 2
634+
635+
assert math.allclose(ans.PS(x, 0), gaussian)

0 commit comments

Comments
 (0)