|
44 | 44 | join_Abc, |
45 | 45 | ) |
46 | 46 | from mrmustard.physics.utils import generate_batch_str, verify_batch_triple |
| 47 | +from mrmustard.physics.fock_utils import c_in_PS |
47 | 48 |
|
48 | 49 | from mrmustard import math, widgets, settings |
49 | 50 | from mrmustard.math.parameters import Variable |
@@ -208,6 +209,47 @@ def num_derived_vars(self) -> int: |
208 | 209 | def num_vars(self): |
209 | 210 | return self.A.shape[-1] |
210 | 211 |
|
| 212 | + @property |
| 213 | + def PS(self) -> PolyExpAnsatz: |
| 214 | + r""" |
| 215 | + The ansatz defined using real (i.e., phase-space) variables. |
| 216 | + """ |
| 217 | + n = self.A.shape[-1] |
| 218 | + if n % 2: |
| 219 | + raise ValueError( |
| 220 | + f"A phase space ansatz must have even number of indices. (n={n} is odd)" |
| 221 | + ) |
| 222 | + |
| 223 | + if self.num_derived_vars == 0: |
| 224 | + W = math.conj(math.rotmat(n // 2)) / math.sqrt(settings.HBAR, dtype=math.complex128) |
| 225 | + |
| 226 | + A = math.einsum("ji, ...jk, kl-> ...il", W, self.A, W) |
| 227 | + b = math.einsum("ij, ...j-> ...i", W, self.b) |
| 228 | + c = self.c / (2 * settings.HBAR) ** (n // 2) |
| 229 | + return PolyExpAnsatz(A, b, c, lin_sup=self._lin_sup) |
| 230 | + |
| 231 | + else: |
| 232 | + if self.num_derived_vars != 2: |
| 233 | + raise ValueError( |
| 234 | + f"This transformation supports 2 core and 0 or 2 derived variables" |
| 235 | + ) |
| 236 | + A_tmp = self.A |
| 237 | + |
| 238 | + A_tmp = A_tmp[..., [0, 2, 1, 3], :][..., [0, 2, 1, 3]] |
| 239 | + b = self.b[..., [0, 2, 1, 3]] |
| 240 | + c = c_in_PS(self.c) # implements PS transformations on ``c`` |
| 241 | + |
| 242 | + W = math.conj(math.rotmat(n // 2)) / math.sqrt(settings.HBAR, dtype=math.complex128) |
| 243 | + |
| 244 | + A = math.einsum("ji, ...jk, kl-> ...il", W, A_tmp, W) |
| 245 | + b = math.einsum("ij, ...j-> ...i", W, b) |
| 246 | + c = c / (2 * settings.HBAR) |
| 247 | + |
| 248 | + A_final = A[..., [0, 2, 1, 3], :][..., :, [0, 2, 1, 3]] |
| 249 | + b_final = b[..., [0, 2, 1, 3]] |
| 250 | + |
| 251 | + return PolyExpAnsatz(A_final, b_final, c, lin_sup=self._lin_sup) |
| 252 | + |
211 | 253 | @property |
212 | 254 | def scalar(self) -> Scalar: |
213 | 255 | r""" |
@@ -863,10 +905,10 @@ def __call__(self: PolyExpAnsatz, *z_inputs: ArrayLike | None) -> Batch[ComplexT |
863 | 905 | - *b are the batch dimensions of the combined inputs. |
864 | 906 | - *L is the batch shape of the ansatz. |
865 | 907 | """ |
866 | | - z_only = [arr for arr in z_inputs if arr is not None] |
| 908 | + z_only = [math.cast(arr, dtype=math.complex128) for arr in z_inputs if arr is not None] |
867 | 909 | broadcasted_z = math.broadcast_arrays(*z_only) |
868 | 910 | z = ( |
869 | | - math.cast(math.stack(broadcasted_z, axis=-1), dtype=math.complex128) |
| 911 | + math.stack(broadcasted_z, axis=-1) |
870 | 912 | if broadcasted_z |
871 | 913 | else math.astensor([], dtype=math.complex128) |
872 | 914 | ) |
|
0 commit comments