Skip to content

Commit 9457b17

Browse files
committed
Outer works generically for CPU and GPU (I also verified that Complex works on GPU, though the diag tests only cover real-valued tensors)
1 parent 0084657 commit 9457b17

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

NDTensors/src/diag/tensoralgebra/outer.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
function outer!(
22
R::DenseTensor{<:Number, NR}, T1::DiagTensor{<:Number, N1}, T2::DiagTensor{<:Number, N2}
33
) where {NR, N1, N2}
4-
for i1 in 1:diaglength(T1), i2 in 1:diaglength(T2)
5-
indsR = CartesianIndex{NR}(ntuple(r -> r ≤ N1 ? i1 : i2, Val(NR)))
6-
R[indsR] = getdiagindex(T1, i1) * getdiagindex(T2, i2)
7-
end
4+
t1 = T1.storage.data
5+
t2 = T2.storage.data
6+
array(R) .= t1 .* t2'
87
return R
98
end
109

NDTensors/test/test_diag.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ end
123123

124124
## Test dot on GPU
125125
@test dot(t, A) ≈ dot(dev(array(t)), array(A)) rtol = sqrt(eps(elt))
126+
127+
NDTensors.outer!(A, t,t);
128+
for i in NDTensors.cpu(A)
129+
@test i == one(elt)
130+
end
126131
end
127132
nothing
128133
end

0 commit comments

Comments
 (0)