Skip to content

Commit 91b3424

Browse files
author
Jonas Rauber
committed
pow exponents can now be tensors
1 parent bc7490a commit 91b3424

File tree

4 files changed

+15
-5
lines changed

4 files changed

+15
-5
lines changed

eagerpy/framework.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def square(t: TensorType) -> TensorType:
3232
return t.square()
3333

3434

35-
def pow(t: TensorType, exponent: float) -> TensorType:
35+
def pow(t: TensorType, exponent: TensorOrScalar) -> TensorType:
3636
return t.pow(exponent)
3737

3838

eagerpy/tensor/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def __mod__(self: TensorType, other: TensorOrScalar) -> TensorType:
106106
return type(self)(self.raw.__mod__(unwrap1(other)))
107107

108108
@final
109-
def __pow__(self: TensorType, exponent: float) -> TensorType:
110-
return type(self)(self.raw.__pow__(exponent))
109+
def __pow__(self: TensorType, exponent: TensorOrScalar) -> TensorType:
110+
return type(self)(self.raw.__pow__(unwrap1(exponent)))
111111

112112
@final
113113
@property

eagerpy/tensor/tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType:
190190
...
191191

192192
@abstractmethod
193-
def __pow__(self: TensorType, exponent: float) -> TensorType:
193+
def __pow__(self: TensorType, exponent: TensorOrScalar) -> TensorType:
194194
...
195195

196196
@abstractmethod
@@ -527,7 +527,7 @@ def abs(self: TensorType) -> TensorType:
527527
return self.__abs__()
528528

529529
@final
530-
def pow(self: TensorType, exponent: float) -> TensorType:
530+
def pow(self: TensorType, exponent: TensorOrScalar) -> TensorType:
531531
return self.__pow__(exponent)
532532

533533
@final

tests/test_main.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,16 @@ def test_pow_op(t: Tensor) -> Tensor:
499499
return t ** 3
500500

501501

502+
@compare_allclose
503+
def test_pow_tensor(t: Tensor) -> Tensor:
504+
return ep.pow(t, (t + 0.5))
505+
506+
507+
@compare_allclose
508+
def test_pow_op_tensor(t: Tensor) -> Tensor:
509+
return t ** (t + 0.5)
510+
511+
502512
@compare_all
503513
def test_add(t1: Tensor, t2: Tensor) -> Tensor:
504514
return t1 + t2

0 commit comments

Comments
 (0)