@@ -55,9 +55,10 @@ class Wire:
5555    is_input : bool 
5656    is_ket : bool 
5757
58-     _contraction_id : int  =  field (default_factory = random_int , init = False )
59-     _dim : int  |  None  =  field (default = None , init = False )
60-     _is_connected : bool  =  field (default = False , init = False )
58+     def  __post_init__ (self ):
59+         self ._contraction_id : int  =  random_int ()
60+         self ._dim  =  None 
61+         self ._is_connected  =  False 
6162
6263    @property  
6364    def  contraction_id (self ) ->  int :
@@ -79,7 +80,7 @@ def dim(self):
7980
8081    @dim .setter  
8182    def  dim (self , value : int ):
82-         if  self ._dim   is   not   None :
83+         if  self ._dim :
8384            raise  ValueError ("Cannot change the dimension of wire with specified dimension." )
8485        self ._dim  =  value 
8586
@@ -186,16 +187,18 @@ def _update_modes(
186187        self ._modes_in_bra  =  modes_in_bra  if  modes_in_bra  else  []
187188        self ._modes_out_bra  =  modes_out_bra  if  modes_out_bra  else  []
188189
189-         # initialize ket and bra wire dicts using dictionary comprehensions for better performance 
190-         self ._input  =  WireGroup (
191-             ket = {mode : Wire (random_int (), mode , True , True ) for  mode  in  self ._modes_in_ket },
192-             bra = {mode : Wire (random_int (), mode , True , False ) for  mode  in  self ._modes_in_bra },
193-         )
190+         # initialize ket and bra wire dicts 
191+         self ._input  =  WireGroup ()
192+         for  mode  in  self ._modes_in_ket :
193+             self ._input .ket  |=  {mode : Wire (random_int (), mode , True , True )}
194+         for  mode  in  self ._modes_in_bra :
195+             self ._input .bra  |=  {mode : Wire (random_int (), mode , True , False )}
194196
195-         self ._output  =  WireGroup (
196-             ket = {mode : Wire (random_int (), mode , False , True ) for  mode  in  self ._modes_out_ket },
197-             bra = {mode : Wire (random_int (), mode , False , False ) for  mode  in  self ._modes_out_bra },
198-         )
197+         self ._output  =  WireGroup ()
198+         for  mode  in  self ._modes_out_ket :
199+             self ._output .ket  |=  {mode : Wire (random_int (), mode , False , True )}
200+         for  mode  in  self ._modes_out_bra :
201+             self ._output .bra  |=  {mode : Wire (random_int (), mode , False , False )}
199202
200203    @property  
201204    def  adjoint (self ) ->  AdjointView :
@@ -351,25 +354,30 @@ def shape(self, default_dim: int | None = None, out_in=False):
351354        Returns the shape of the underlying tensor, as inferred from the dimensions of the individual 
352355        wires. 
353356
354-         If ``out_in`` is ``False``, the shape returned is in the order ``(in_ket, in_bra, out_ket, out_bra)`` 
357+         If ``out_in`` is ``False``, the shape returned is in the order ``(in_ket, in_bra, out_ket, out_bra)``.  
355358        Otherwise, it is in the order ``(out_ket, out_bra, in_ket, in_bra)``. 
356359
357360        Args: 
358361            default_dim: The default dimension of wires with unspecified dimension. 
359362            out_in: Whether to return output shapes followed by input shapes or viceversa. 
360363        """ 
361364
365+         def  _sort_shapes (* args ):
366+             for  arg  in  args :
367+                 if  arg :
368+                     yield  arg 
369+ 
362370        shape_in_ket  =  [w .dim  if  w .dim  else  default_dim  for  w  in  self .input .ket .values ()]
363371        shape_out_ket  =  [w .dim  if  w .dim  else  default_dim  for  w  in  self .output .ket .values ()]
364372        shape_in_bra  =  [w .dim  if  w .dim  else  default_dim  for  w  in  self .input .bra .values ()]
365373        shape_out_bra  =  [w .dim  if  w .dim  else  default_dim  for  w  in  self .output .bra .values ()]
366374
367375        if  out_in :
368-             combined_shape  =  shape_out_ket  +  shape_out_bra  +  shape_in_ket  +  shape_in_bra 
369-         else :
370-             combined_shape  =  shape_in_ket  +  shape_in_bra  +  shape_out_ket  +  shape_out_bra 
376+             ret  =  _sort_shapes (shape_out_ket , shape_out_bra , shape_in_ket , shape_in_bra )
377+         ret  =  _sort_shapes (shape_in_ket , shape_in_bra , shape_out_ket , shape_out_bra )
371378
372-         return  tuple (combined_shape )
379+         # pylint: disable=consider-using-generator 
380+         return  tuple ([item  for  sublist  in  ret  for  item  in  sublist ])
373381
374382
375383class  AdjointView (Tensor ):
@@ -381,10 +389,10 @@ def __init__(self, tensor):
381389        self ._original  =  tensor 
382390        super ().__init__ (
383391            name = self ._original .name ,
384-             modes_in_ket = list ( self ._original .input .bra .keys () ),
385-             modes_out_ket = list ( self ._original .output .bra .keys () ),
386-             modes_in_bra = list ( self ._original .input .ket .keys () ),
387-             modes_out_bra = list ( self ._original .output .ket .keys () ),
392+             modes_in_ket = self ._original .input .bra .keys (),
393+             modes_out_ket = self ._original .output .bra .keys (),
394+             modes_in_bra = self ._original .input .ket .keys (),
395+             modes_out_bra = self ._original .output .ket .keys (),
388396        )
389397
390398    def  value (self , shape : tuple [int ]):
@@ -397,7 +405,12 @@ def value(self, shape: tuple[int]):
397405            ComplexTensor: the unitary matrix in Fock representation 
398406        """ 
399407        # converting the given shape into a shape for the original tensor 
400-         shape_in_ket , shape_out_ket , shape_in_bra , shape_out_bra  =  self ._original .unpack_shape (shape )
408+         (
409+             shape_in_ket ,
410+             shape_out_ket ,
411+             shape_in_bra ,
412+             shape_out_bra ,
413+         ) =  self ._original .unpack_shape (shape )
401414        shape_ret  =  shape_in_bra  +  shape_out_bra  +  shape_in_ket  +  shape_out_ket 
402415
403416        ret  =  math .conj (math .astensor (self ._original .value (shape_ret )))
@@ -413,10 +426,10 @@ def __init__(self, tensor):
413426        self ._original  =  tensor 
414427        super ().__init__ (
415428            name = self ._original .name ,
416-             modes_in_ket = list ( self ._original .output .ket .keys () ),
417-             modes_out_ket = list ( self ._original .input .ket .keys () ),
418-             modes_in_bra = list ( self ._original .output .bra .keys () ),
419-             modes_out_bra = list ( self ._original .input .bra .keys () ),
429+             modes_in_ket = self ._original .output .ket .keys (),
430+             modes_out_ket = self ._original .input .ket .keys (),
431+             modes_in_bra = self ._original .output .bra .keys (),
432+             modes_out_bra = self ._original .input .bra .keys (),
420433        )
421434
422435    def  value (self , shape : tuple [int ]):
@@ -430,6 +443,6 @@ def value(self, shape: tuple[int]):
430443        """ 
431444        # converting the given shape into a shape for the original tensor 
432445        shape_in_ket , shape_out_ket , shape_in_bra , shape_out_bra  =  self .unpack_shape (shape )
433-         shape_ret  =  shape_out_ket  +  shape_in_ket  +  shape_out_bra   +  shape_in_bra 
446+         shape_ret  =  shape_out_ket  +  shape_in_ket  +  shape_out_bra ,  shape_in_bra 
434447
435448        return  math .conj (self ._original .value (shape_ret ))
0 commit comments