Skip to content
This repository was archived by the owner on May 10, 2025. It is now read-only.

Commit d6ff74a

Browse files
GiggleLiuRoger-luo
andauthored
update CUDA (#69)
* update CUDA * cleanup * update version Co-authored-by: Roger-Luo <[email protected]>
1 parent 7532cbf commit d6ff74a

File tree

6 files changed

+9
-34
lines changed

6 files changed

+9
-34
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "CuYao"
22
uuid = "b48ca7a8-dd42-11e8-2b8e-1b7706800275"
3-
version = "0.2.9"
3+
version = "0.3.0"
44

55
[deps]
66
BitBasis = "50ba71b6-fa0f-514d-ae9a-0916efc90dcf"
@@ -16,14 +16,14 @@ Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c"
1616

1717
[compat]
1818
BitBasis = "0.7"
19-
CUDA = "2.0, 3.0"
19+
CUDA = "3.1"
2020
LuxurySparse = "0.6"
2121
Reexport = "0.2, 1.0"
2222
StaticArrays = "0.12, 1.0"
2323
StatsBase = "0.33"
2424
TupleTools = "1"
2525
Yao = "0.6, 0.7"
26-
julia = "1"
26+
julia = "1.6"
2727

2828
[extras]
2929
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

src/CUDApatch.jl

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,3 @@
1-
#import CUDA: _cuview, ViewIndex, NonContiguous
2-
#using CUDA: genperm
3-
# fallback to SubArray when the view is not contiguous
4-
5-
#=
6-
function LinearAlgebra.permutedims!(dest::GPUArray, src::GPUArray, perm) where N
7-
perm isa Tuple || (perm = Tuple(perm))
8-
gpu_call(dest, (dest, src, perm)) do state, dest, src, perm
9-
I = @cartesianidx src state
10-
@inbounds dest[genperm(I, perm)...] = src[I...]
11-
return
12-
end
13-
return dest
14-
end
15-
=#
16-
17-
import CUDA: pow, abs, angle
18-
for (RT, CT) in [(:Float64, :ComplexF64), (:Float32, :ComplexF32)]
19-
@eval cp2c(d::$RT, a::$RT) = CUDA.Complex(d*CUDA.cos(a), d*CUDA.sin(a))
20-
for NT in [RT, :Int32]
21-
@eval CUDA.pow(z::$CT, n::$NT) = CUDA.Complex((CUDA.pow(CUDA.abs(z), n)*CUDA.cos(n*CUDA.angle(z))), (CUDA.pow(CUDA.abs(z), n)*CUDA.sin(n*CUDA.angle(z))))
22-
end
23-
end
24-
251
@inline function bit_count(x::UInt32)
262
x = ((x >> 1) & 0b01010101010101010101010101010101) + (x & 0b01010101010101010101010101010101)
273
x = ((x >> 2) & 0b00110011001100110011001100110011) + (x & 0b00110011001100110011001100110011)
@@ -95,7 +71,7 @@ end
9571
Computes Kronecker products in-place on the GPU.
9672
The results are stored in 'C', overwriting the existing values of 'C'.
9773
"""
98-
function kron!(C::CuArray{T3}, A::DenseCuArray{T1}, B::DenseCuArray{T2}) where {T1, T2, T3}
74+
function Yao.YaoBase.kron!(C::CuArray{T3}, A::DenseCuArray{T1}, B::DenseCuArray{T2}) where {T1, T2, T3}
9975
@boundscheck (size(C) == (size(A,1)*size(B,1), size(A,2)*size(B,2))) || throw(DimensionMismatch())
10076
CI = Base.CartesianIndices(C)
10177
@inline function kernel(C, A, B)

src/CuYao.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using Random
88

99
using Yao.YaoArrayRegister
1010
using CUDA
11-
import Yao: kron!
1211
@reexport using Yao
1312

1413
const Ints = NTuple{<:Any, Int}

src/kernels.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ end
106106
mask = bmask(Int32, bits...)
107107
1<<nbit,@inline function kernel(state, inds)
108108
i = inds[1]
109-
piecewise(state, inds)[i] *= CUDA.pow(d, bit_count(Int32(i-1)&mask))
109+
piecewise(state, inds)[i] *= d ^ bit_count(Int32(i-1)&mask)
110110
return
111111
end
112112
end

test/CUDApatch.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ end
2525
@testset "Complex pow" begin
2626
for T in [ComplexF64, ComplexF32]
2727
a = CuArray(randn(T, 4, 4))
28-
@test Array(CUDA.pow.(a, Int32(3))) Array(a).^3
29-
@test Array(CUDA.pow.(a, real(T)(3))) Array(a).^3
28+
@test Array(a .^ Int32(3)) Array(a).^3
29+
@test Array(a .^ real(T)(3)) Array(a).^3
3030
end
3131
end
3232

test/GPUReg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ end
108108
c = zeros(12,8)
109109
ca, cb, cc = cu(a), cu(b), cu(c)
110110
@test kron(ca, cb) |> Array kron(a, b)
111-
@test kron!(cc, ca, cb) |> Array kron(a,b)
111+
@test Yao.YaoBase.kron!(cc, ca, cb) |> Array kron(a,b)
112112

113-
kron!(c,a,b)
113+
Yao.YaoBase.kron!(c,a,b)
114114
@test cc |> Array c
115115

116116
v = randn(100) |> cu

0 commit comments

Comments
 (0)