Skip to content

Commit f22da3d

Browse files
committed
oops, fix tests
1 parent b7e4451 commit f22da3d

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

test/gpu.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ end
4848
y = onehotbatch(ones(3), 1:2) |> cu;
4949
@test (repr("text/plain", y); true)
5050

51-
gA = rand(3, 2) |> cu;
52-
5351
#NOTE: this would require something that can compute gradient... we don't have that here?
5452
#@test gradient(A -> sum(A * y), gA)[1] isa CuArray
5553
end
5654

5755
@testset "LinearAlgebra" begin
56+
y = onehotbatch(ones(3), 1:2) |> cu;
57+
gA = rand(3, 2) |> cu;
58+
5859
# some specialized implementations call only mul! and not *, so we must ensure this works
5960
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) gA*y
6061
@test LinearAlgebra.mul!(similar(gA, 3, 1), gA, onehot(1, 1:2)) gA*onehot(1, 1:2)

0 commit comments

Comments
 (0)