Skip to content

Commit 3a1d8bb

Browse files
Channel from XY (#514)
**Context:** We would like to have the ability to initialize a Channel from its X,Y matrices (and potentially a d vector). **Description of the Change:** Added the method `.from_XY` which allows this computation. **Benefits:** Provides a simpler way for a user to initialize their desired channel. **Possible Drawbacks:** None. **Related GitHub Issues:** None.
1 parent e5c02ce commit 3a1d8bb

File tree

10 files changed

+144
-30
lines changed

10 files changed

+144
-30
lines changed

mrmustard/lab_dev/transformations/base.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz
3333
from mrmustard.physics.representations import Representation
3434
from mrmustard.physics.wires import Wires
35-
from mrmustard.utils.typing import ComplexMatrix
35+
from mrmustard.utils.typing import ComplexMatrix, RealMatrix, Vector
36+
from mrmustard.physics.triples import XY_to_channel_Abc
3637
from mrmustard.physics.bargmann_utils import au2Symplectic, symplectic2Au, XY_of_channel
3738
from ..circuit_components import CircuitComponent
3839

@@ -446,6 +447,37 @@ def from_quadrature(
446447
BB = QtoB_in >> QQ >> QtoB_out
447448
return Channel.from_ansatz(modes_out, modes_in, BB.ansatz, name)
448449

450+
@classmethod
451+
def from_XY(
452+
cls,
453+
modes_out: Sequence[int],
454+
modes_in: Sequence[int],
455+
X: RealMatrix,
456+
Y: RealMatrix,
457+
d: Vector | None = None,
458+
) -> Channel:
459+
r"""
460+
Initialize a Channel from its XY representation.
461+
Args:
462+
modes: The modes the channel is defined on.
463+
X: The X matrix of the channel.
464+
Y: The Y matrix of the channel.
465+
d: The d vector of the channel.
466+
467+
.. details::
468+
Each Gaussian channel transforms a state with covarince matrix :math:`\Sigma` and mean :math:`\mu`
469+
into a state with covariance matrix :math:`X \Sigma X^T + Y` and vector of means :math:`X\mu + d`.
470+
This channel has a Bargmann triple that is computed in https://arxiv.org/pdf/2209.06069. We borrow
471+
the formulas from the paper to implement the corresponding channel.
472+
"""
473+
474+
if X.shape != (2 * len(modes_out), 2 * len(modes_in)):
475+
raise ValueError(
476+
f"The dimension of X matrix ({X.shape}) and number of modes ({len(modes_in), len(modes_out)}) don't match."
477+
)
478+
479+
return Channel.from_bargmann(modes_out, modes_in, XY_to_channel_Abc(X, Y, d))
480+
449481
@classmethod
450482
def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel:
451483
r"""

mrmustard/physics/triples.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323

2424
from mrmustard import math, settings
25-
from mrmustard.utils.typing import Matrix, Vector, Scalar, RealMatrix
25+
from mrmustard.utils.typing import Matrix, Vector, Scalar, RealMatrix, ComplexMatrix
2626
from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2
2727

2828

@@ -778,3 +778,64 @@ def attenuator_kraus_Abc(eta: float) -> Union[Matrix, Vector, Scalar]:
778778
b = _vacuum_B_vector(3)
779779
c = 1.0 + 0j
780780
return A, b, c
781+
782+
783+
def XY_to_channel_Abc(X: RealMatrix, Y: RealMatrix, d: Vector | None = None) -> ComplexMatrix:
784+
r"""
785+
The method to compute the A matrix of a channel based on its X, Y, and d.
786+
Args:
787+
X: The X matrix of the channel
788+
Y: The Y matrix of the channel
789+
d: The d (displacement) vector of the channel -- if None, we consider it as 0
790+
"""
791+
792+
m = Y.shape[-1] // 2
793+
# considering no displacement if d is None
794+
d = d if d else math.zeros(2 * m)
795+
796+
if X.shape != Y.shape:
797+
raise ValueError(
798+
"The dimension of X and Y matrices are not the same."
799+
f"X.shape = {X.shape}, Y.shape = {Y.shape}"
800+
)
801+
802+
xi = 1 / 2 * math.eye(2 * m, dtype=math.complex128) + 1 / 2 * X @ X.T + Y / settings.HBAR
803+
xi_inv = math.inv(xi)
804+
xi_inv_in_blocks = math.block(
805+
[[math.eye(2 * m) - xi_inv, xi_inv @ X], [X.T @ xi_inv, math.eye(2 * m) - X.T @ xi_inv @ X]]
806+
)
807+
R = (
808+
1
809+
/ math.sqrt(complex(2))
810+
* math.block(
811+
[
812+
[
813+
math.eye(m, dtype=math.complex128),
814+
1j * math.eye(m, dtype=math.complex128),
815+
math.zeros((m, 2 * m), dtype=math.complex128),
816+
],
817+
[
818+
math.zeros((m, 2 * m), dtype=math.complex128),
819+
math.eye(m, dtype=math.complex128),
820+
-1j * math.eye(m, dtype=math.complex128),
821+
],
822+
[
823+
math.eye(m, dtype=math.complex128),
824+
-1j * math.eye(m, dtype=math.complex128),
825+
math.zeros((m, 2 * m), dtype=math.complex128),
826+
],
827+
[
828+
math.zeros((m, 2 * m), dtype=math.complex128),
829+
math.eye(m, dtype=math.complex128),
830+
1j * math.eye(m, dtype=math.complex128),
831+
],
832+
]
833+
)
834+
)
835+
836+
A = math.Xmat(2 * m) @ R @ xi_inv_in_blocks @ math.conj(R).T
837+
temp = math.block([[(xi_inv @ d).reshape(2 * m, 1)], [(-X.T @ xi_inv @ d).reshape((2 * m, 1))]])
838+
b = 1 / math.sqrt(settings.HBAR) * math.conj(R) @ temp
839+
c = math.exp(-0.5 / settings.HBAR * d @ xi_inv @ d) / math.sqrt(math.det(xi))
840+
841+
return A, b, c

tests/test_lab_dev/test_circuit_components_utils/test_trace_out.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import pytest
2121

2222
from mrmustard import math
23-
from mrmustard.lab_dev.circuit_components_utils import TraceOut
2423
from mrmustard.lab_dev.circuit_components import CircuitComponent
2524
from mrmustard.lab_dev.circuit_components_utils import TraceOut
2625
from mrmustard.lab_dev.states import Coherent

tests/test_lab_dev/test_states/test_dm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,18 @@
1717
# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement
1818

1919
from itertools import product
20+
2021
import numpy as np
2122
import pytest
2223

2324
from mrmustard import math, settings
2425
from mrmustard.lab_dev.circuit_components import CircuitComponent
2526
from mrmustard.lab_dev.circuit_components_utils import TraceOut
26-
from mrmustard.physics.gaussian import vacuum_cov
27-
from mrmustard.lab_dev.states import Coherent, DM, Ket, Number, Vacuum
27+
from mrmustard.lab_dev.states import DM, Coherent, Ket, Number, Vacuum
2828
from mrmustard.lab_dev.transformations import Attenuator, Dgate
29-
from mrmustard.physics.wires import Wires
29+
from mrmustard.physics.gaussian import vacuum_cov
3030
from mrmustard.physics.representations import Representation
31+
from mrmustard.physics.wires import Wires
3132

3233

3334
def coherent_state_quad(q, x, y, phi=0):

tests/test_lab_dev/test_states/test_ket.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,21 @@
1717
# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement
1818

1919
from itertools import product
20+
2021
import numpy as np
2122
import pytest
22-
23-
from ipywidgets import Box, HBox, VBox, HTML
23+
from ipywidgets import HTML, Box, HBox, VBox
2424
from plotly.graph_objs import FigureWidget
2525

2626
from mrmustard import math, settings
2727
from mrmustard.lab_dev.circuit_components import CircuitComponent
28-
from mrmustard.math.parameters import Constant, Variable
29-
from mrmustard.physics.gaussian import vacuum_cov, vacuum_means, squeezed_vacuum_cov
30-
from mrmustard.physics.triples import coherent_state_Abc
3128
from mrmustard.lab_dev.circuit_components_utils import TraceOut
32-
from mrmustard.lab_dev.states import (
33-
Coherent,
34-
DisplacedSqueezed,
35-
DM,
36-
Ket,
37-
Number,
38-
Vacuum,
39-
)
29+
from mrmustard.lab_dev.states import DM, Coherent, DisplacedSqueezed, Ket, Number, Vacuum
4030
from mrmustard.lab_dev.transformations import Attenuator, Dgate, Sgate
31+
from mrmustard.math.parameters import Constant, Variable
32+
from mrmustard.physics.gaussian import squeezed_vacuum_cov, vacuum_cov, vacuum_means
4133
from mrmustard.physics.representations import Representation
34+
from mrmustard.physics.triples import coherent_state_Abc
4235
from mrmustard.physics.wires import Wires
4336
from mrmustard.widgets import state as state_widget
4437

tests/test_lab_dev/test_transformations/test_transformations_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,11 @@ def test_XY(self):
204204

205205
X, Y = Attenuator([0], 0.2).XY
206206
assert np.allclose(X, np.sqrt(0.2) * np.eye(2)) and np.allclose(Y, 0.4 * np.eye(2))
207+
208+
@pytest.mark.parametrize("nmodes", [1, 2, 3])
209+
def test_from_XY(self, nmodes):
210+
X = np.random.random((2 * nmodes, 2 * nmodes))
211+
Y = np.random.random((2 * nmodes, 2 * nmodes))
212+
x, y = Channel.from_XY(range(nmodes), range(nmodes), X, Y).XY
213+
assert math.allclose(x, X)
214+
assert math.allclose(y, Y)

tests/test_physics/test_ansatz/test_array_ansatz.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from unittest.mock import patch
2020

2121
import numpy as np
22-
from ipywidgets import HBox, VBox, HTML, Tab
23-
from plotly.graph_objs import FigureWidget
2422
import pytest
23+
from ipywidgets import HTML, HBox, Tab, VBox
24+
from plotly.graph_objs import FigureWidget
2525

2626
from mrmustard import math
2727
from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz

tests/test_physics/test_ansatz/test_polyexp_ansatz.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,17 @@
1919
from unittest.mock import patch
2020

2121
import numpy as np
22-
from ipywidgets import Box, VBox, HTML, IntText, Stack, IntSlider
23-
from plotly.graph_objs import FigureWidget
2422
import pytest
23+
from ipywidgets import HTML, Box, IntSlider, IntText, Stack, VBox
24+
from plotly.graph_objs import FigureWidget
2525

2626
from mrmustard import math
27+
from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz
28+
from mrmustard.physics.ansatz.polyexp_ansatz import PolyExpAnsatz
2729
from mrmustard.physics.gaussian_integrals import (
2830
complex_gaussian_integral_1,
2931
complex_gaussian_integral_2,
3032
)
31-
from mrmustard.physics.ansatz.polyexp_ansatz import PolyExpAnsatz
32-
from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz
3333

3434
from ...random import Abc_triple
3535

tests/test_physics/test_representations.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919
import pytest
2020

2121
from mrmustard import math
22-
23-
from mrmustard.physics.representations import Representation, RepEnum
24-
from mrmustard.physics.wires import Wires
2522
from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz
26-
from mrmustard.physics.triples import displacement_gate_Abc, bargmann_to_quadrature_Abc
23+
from mrmustard.physics.representations import RepEnum, Representation
24+
from mrmustard.physics.triples import bargmann_to_quadrature_Abc, displacement_gate_Abc
25+
from mrmustard.physics.wires import Wires
2726

2827
from ..random import Abc_triple
2928

tests/test_physics/test_triples.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818
import pytest
1919

20-
from mrmustard import math
20+
from mrmustard import math, settings
2121
from mrmustard.physics import triples
2222
from mrmustard.physics.ansatz import PolyExpAnsatz
2323

@@ -354,3 +354,24 @@ def test_gaussian_random_noise_Abc(self):
354354
assert math.allclose(A, A_by_hand)
355355
assert math.allclose(b, b_by_hand)
356356
assert math.allclose(c, c_by_hand)
357+
358+
def test_XY_to_channel_Abc(self):
359+
360+
# Creating an attenuator object and testing its Abc triple
361+
eta = np.random.random()
362+
X = np.sqrt(eta) * np.eye(2)
363+
Y = settings.HBAR / 2 * (1 - eta) * np.eye(2)
364+
365+
A, b, c = triples.XY_to_channel_Abc(X, Y)
366+
367+
A_by_hand = np.block(
368+
[
369+
[0, np.sqrt(eta), 0, 0],
370+
[np.sqrt(eta), 0, 0, 1 - eta],
371+
[0, 0, 0, np.sqrt(eta)],
372+
[0, 1 - eta, np.sqrt(eta), 0],
373+
]
374+
)
375+
assert np.allclose(A, A_by_hand)
376+
assert np.allclose(b, np.zeros((4, 1)))
377+
assert np.isclose(c, 1.0)

0 commit comments

Comments
 (0)