Skip to content

Commit d53e625

Browse files
authored
Update tensors.py
1 parent 987feac commit d53e625

File tree

1 file changed

+41
-28
lines changed

1 file changed

+41
-28
lines changed

mrmustard/math/tensor_networks/tensors.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ class Wire:
5555
is_input: bool
5656
is_ket: bool
5757

58-
_contraction_id: int = field(default_factory=random_int, init=False)
59-
_dim: int | None = field(default=None, init=False)
60-
_is_connected: bool = field(default=False, init=False)
58+
def __post_init__(self):
59+
self._contraction_id: int = random_int()
60+
self._dim = None
61+
self._is_connected = False
6162

6263
@property
6364
def contraction_id(self) -> int:
@@ -79,7 +80,7 @@ def dim(self):
7980

8081
@dim.setter
8182
def dim(self, value: int):
82-
if self._dim is not None:
83+
if self._dim:
8384
raise ValueError("Cannot change the dimension of wire with specified dimension.")
8485
self._dim = value
8586

@@ -186,16 +187,18 @@ def _update_modes(
186187
self._modes_in_bra = modes_in_bra if modes_in_bra else []
187188
self._modes_out_bra = modes_out_bra if modes_out_bra else []
188189

189-
# initialize ket and bra wire dicts using dictionary comprehensions for better performance
190-
self._input = WireGroup(
191-
ket={mode: Wire(random_int(), mode, True, True) for mode in self._modes_in_ket},
192-
bra={mode: Wire(random_int(), mode, True, False) for mode in self._modes_in_bra},
193-
)
190+
# initialize ket and bra wire dicts
191+
self._input = WireGroup()
192+
for mode in self._modes_in_ket:
193+
self._input.ket |= {mode: Wire(random_int(), mode, True, True)}
194+
for mode in self._modes_in_bra:
195+
self._input.bra |= {mode: Wire(random_int(), mode, True, False)}
194196

195-
self._output = WireGroup(
196-
ket={mode: Wire(random_int(), mode, False, True) for mode in self._modes_out_ket},
197-
bra={mode: Wire(random_int(), mode, False, False) for mode in self._modes_out_bra},
198-
)
197+
self._output = WireGroup()
198+
for mode in self._modes_out_ket:
199+
self._output.ket |= {mode: Wire(random_int(), mode, False, True)}
200+
for mode in self._modes_out_bra:
201+
self._output.bra |= {mode: Wire(random_int(), mode, False, False)}
199202

200203
@property
201204
def adjoint(self) -> AdjointView:
@@ -351,25 +354,30 @@ def shape(self, default_dim: int | None = None, out_in=False):
351354
Returns the shape of the underlying tensor, as inferred from the dimensions of the individual
352355
wires.
353356
354-
If ``out_in`` is ``False``, the shape returned is in the order ``(in_ket, in_bra, out_ket, out_bra)``
357+
If ``out_in`` is ``False``, the shape returned is in the order ``(in_ket, in_bra, out_ket, out_bra)``.
355358
Otherwise, it is in the order ``(out_ket, out_bra, in_ket, in_bra)``.
356359
357360
Args:
358361
default_dim: The default dimension of wires with unspecified dimension.
359362
out_in: Whether to return output shapes followed by input shapes or viceversa.
360363
"""
361364

365+
def _sort_shapes(*args):
366+
for arg in args:
367+
if arg:
368+
yield arg
369+
362370
shape_in_ket = [w.dim if w.dim else default_dim for w in self.input.ket.values()]
363371
shape_out_ket = [w.dim if w.dim else default_dim for w in self.output.ket.values()]
364372
shape_in_bra = [w.dim if w.dim else default_dim for w in self.input.bra.values()]
365373
shape_out_bra = [w.dim if w.dim else default_dim for w in self.output.bra.values()]
366374

367375
if out_in:
368-
combined_shape = shape_out_ket + shape_out_bra + shape_in_ket + shape_in_bra
369-
else:
370-
combined_shape = shape_in_ket + shape_in_bra + shape_out_ket + shape_out_bra
376+
ret = _sort_shapes(shape_out_ket, shape_out_bra, shape_in_ket, shape_in_bra)
377+
ret = _sort_shapes(shape_in_ket, shape_in_bra, shape_out_ket, shape_out_bra)
371378

372-
return tuple(combined_shape)
379+
# pylint: disable=consider-using-generator
380+
return tuple([item for sublist in ret for item in sublist])
373381

374382

375383
class AdjointView(Tensor):
@@ -381,10 +389,10 @@ def __init__(self, tensor):
381389
self._original = tensor
382390
super().__init__(
383391
name=self._original.name,
384-
modes_in_ket=list(self._original.input.bra.keys()),
385-
modes_out_ket=list(self._original.output.bra.keys()),
386-
modes_in_bra=list(self._original.input.ket.keys()),
387-
modes_out_bra=list(self._original.output.ket.keys()),
392+
modes_in_ket=self._original.input.bra.keys(),
393+
modes_out_ket=self._original.output.bra.keys(),
394+
modes_in_bra=self._original.input.ket.keys(),
395+
modes_out_bra=self._original.output.ket.keys(),
388396
)
389397

390398
def value(self, shape: tuple[int]):
@@ -397,7 +405,12 @@ def value(self, shape: tuple[int]):
397405
ComplexTensor: the unitary matrix in Fock representation
398406
"""
399407
# converting the given shape into a shape for the original tensor
400-
shape_in_ket, shape_out_ket, shape_in_bra, shape_out_bra = self._original.unpack_shape(shape)
408+
(
409+
shape_in_ket,
410+
shape_out_ket,
411+
shape_in_bra,
412+
shape_out_bra,
413+
) = self._original.unpack_shape(shape)
401414
shape_ret = shape_in_bra + shape_out_bra + shape_in_ket + shape_out_ket
402415

403416
ret = math.conj(math.astensor(self._original.value(shape_ret)))
@@ -413,10 +426,10 @@ def __init__(self, tensor):
413426
self._original = tensor
414427
super().__init__(
415428
name=self._original.name,
416-
modes_in_ket=list(self._original.output.ket.keys()),
417-
modes_out_ket=list(self._original.input.ket.keys()),
418-
modes_in_bra=list(self._original.output.bra.keys()),
419-
modes_out_bra=list(self._original.input.bra.keys()),
429+
modes_in_ket=self._original.output.ket.keys(),
430+
modes_out_ket=self._original.input.ket.keys(),
431+
modes_in_bra=self._original.output.bra.keys(),
432+
modes_out_bra=self._original.input.bra.keys(),
420433
)
421434

422435
def value(self, shape: tuple[int]):
@@ -430,6 +443,6 @@ def value(self, shape: tuple[int]):
430443
"""
431444
# converting the given shape into a shape for the original tensor
432445
shape_in_ket, shape_out_ket, shape_in_bra, shape_out_bra = self.unpack_shape(shape)
433-
shape_ret = shape_out_ket + shape_in_ket + shape_out_bra + shape_in_bra
446+
shape_ret = shape_out_ket + shape_in_ket + shape_out_bra, shape_in_bra
434447

435448
return math.conj(self._original.value(shape_ret))

0 commit comments

Comments
 (0)