Skip to content

Commit c3cec24

Browse files
committed
make everything easier.
1 parent e56f7e8 commit c3cec24

File tree

2 files changed

+3
-25
lines changed

2 files changed

+3
-25
lines changed

tensorcircuit/circuit.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ def unitary_kraus(
391391
prob: Optional[Sequence[float]] = None,
392392
status: Optional[float] = None,
393393
name: Optional[str] = None,
394-
return_gate: bool = False,
395394
) -> Tensor:
396395
"""
397396
Apply unitary gates in ``kraus`` randomly based on corresponding ``prob``.
@@ -423,7 +422,6 @@ def index2gate(r: Tensor, kraus: Sequence[Tensor]) -> Tensor:
423422
status=status,
424423
get_gate_from_index=index2gate,
425424
name=name,
426-
return_gate=return_gate,
427425
)
428426

429427
def _unitary_kraus_template(
@@ -436,7 +434,6 @@ def _unitary_kraus_template(
436434
Callable[[Tensor, Sequence[Tensor]], Tensor]
437435
] = None,
438436
name: Optional[str] = None,
439-
return_gate: bool = False,
440437
) -> Tensor: # DRY
441438
sites = len(index)
442439
kraus = [k.tensor if isinstance(k, tn.Node) else k for k in kraus]
@@ -478,9 +475,7 @@ def step_function(x: Tensor) -> Tensor:
478475
raise ValueError("no `get_gate_from_index` implementation is provided")
479476
g = get_gate_from_index(r, kraus)
480477
g = backend.reshape(g, [self._d for _ in range(sites * 2)])
481-
if return_gate:
482-
return r, g
483-
self.any(*index, unitary=g, name=name) # type: ignore
478+
self.any(*index, unitary=g, name=name, dim=self._d) # type: ignore
484479
return r
485480

486481
def _general_kraus_tf(
@@ -558,7 +553,6 @@ def _general_kraus_2(
558553
status: Optional[float] = None,
559554
with_prob: bool = False,
560555
name: Optional[str] = None,
561-
return_gate: bool = False,
562556
) -> Tensor:
563557
# the graph building time is frustratingly slow, several minutes
564558
# though running time is in terms of ms
@@ -611,7 +605,6 @@ def calculate_kraus_p(i: int) -> Tensor:
611605
prob=prob,
612606
status=status,
613607
name=name,
614-
return_gate=return_gate,
615608
)
616609
if not with_prob:
617610
return pick
@@ -625,7 +618,6 @@ def general_kraus(
625618
status: Optional[float] = None,
626619
with_prob: bool = False,
627620
name: Optional[str] = None,
628-
return_gate: bool = False,
629621
) -> Tensor:
630622
"""
631623
Monte Carlo trajectory simulation of general Kraus channel whose Kraus operators cannot be
@@ -650,7 +642,6 @@ def general_kraus(
650642
status=status,
651643
with_prob=with_prob,
652644
name=name,
653-
return_gate=return_gate,
654645
)
655646

656647
apply_general_kraus = general_kraus

tensorcircuit/quditcircuit.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -695,24 +695,14 @@ def general_kraus(
695695
when the random number will be generated automatically
696696
:type status: Optional[float], optional
697697
"""
698-
res = self._circ.general_kraus(
698+
return self._circ.general_kraus(
699699
kraus,
700700
*index,
701701
status=status,
702702
with_prob=with_prob,
703703
name=name,
704-
return_gate=True,
705704
)
706705

707-
if with_prob:
708-
(pick, gate), prob = res
709-
self.any(*index, unitary=gate, name=name)
710-
return pick, prob
711-
else:
712-
pick, gate = res
713-
self.any(*index, unitary=gate, name=name)
714-
return pick
715-
716706
def unitary_kraus(
717707
self,
718708
kraus: Sequence[Gate],
@@ -734,13 +724,10 @@ def unitary_kraus(
734724
:return: shape [] int dtype tensor indicates which kraus gate is actually applied
735725
:rtype: Tensor
736726
"""
737-
r, g = self._circ.unitary_kraus(
727+
return self._circ.unitary_kraus(
738728
kraus,
739729
*index,
740730
prob=prob,
741731
status=status,
742732
name=name,
743-
return_gate=True,
744733
)
745-
self.any(*index, unitary=g, name=name)
746-
return r

0 commit comments

Comments
 (0)