@@ -9,70 +9,68 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
99# Metal sometimes supports fewer.
1010const TILE_DIM = 16
1111
"""
    coalesced_matmul_kernel!(output, input1, input2, N, R, M, ::Val{TDIM}, ::Val{BANK} = Val(1))

Tiled matrix-multiply kernel body: each work-item accumulates one entry
`output[I, J]` as the sum of `input1[I, k] * input2[k, J]`, streaming the inner
dimension through two workgroup-local tiles of edge length `TDIM`.

# Arguments
- `output`: destination matrix; its element type is used for the tiles and the accumulator.
- `input1`, `input2`: left and right operands, read as `input1[I, k]` and `input2[k, J]`.
- `N`: ignored on entry; it is overwritten with `size(output, 1)` inside the kernel.
- `R`: extent of the inner (contracted) dimension.
- `M`: number of columns of `output` that may be written.
- `Val{TDIM}`: tile edge length as a compile-time constant. Presumably this must equal the
  launch workgroup size in both dimensions; verify against the launch site.
- `Val{BANK}`: extra rows added to the first dimension of each tile (default `1`).

Returns `nothing`.
"""
function coalesced_matmul_kernel!(
        output, input1, input2, N, R, M,
        ::Val{TDIM}, ::Val{BANK} = Val(1)
    ) where {TDIM, BANK}
    gi, gj, _ = KI.get_group_id()
    i, j, _ = KI.get_local_id()

    T = eltype(output)

    # +BANK rows of padding to avoid bank conflicts on shared memory
    tile1 = KI.localmemory(T, (TDIM + BANK, TDIM))
    tile2 = KI.localmemory(T, (TDIM + BANK, TDIM))

    # running value of this work-item's output entry
    outval = -zero(T)

    N = size(output, 1)
    # number of tiles depends on inner dimension
    NUM_TILES = cld(R, TDIM)

    # Global indices of the entry owned by this work-item. They depend only on the
    # group and local ids, so they are computed once instead of once per tile.
    I = (gi - 1) * TDIM + i
    J = (gj - 1) * TDIM + j

    # loop over all tiles needed for this calculation
    for t in 0:(NUM_TILES - 1)
        # positions along the inner dimension handled by this work-item for tile `t`
        col = t * TDIM + j
        row = t * TDIM + i

        # load inputs into tiles, zero-padding out-of-range entries so that
        # non-square / non-multiple-of-TDIM matrices are handled
        if I <= N && col <= R
            @inbounds tile1[i, j] = input1[I, col]
        else
            @inbounds tile1[i, j] = zero(T)
        end
        # The guard must use TDIM, the same tile size as the index it protects,
        # not the file-level TILE_DIM constant.
        if row <= R && J <= M
            @inbounds tile2[i, j] = input2[row, J]
        else
            @inbounds tile2[i, j] = zero(T)
        end

        # wait for all tiles to be loaded
        KI.barrier()

        # calculate value of spot in output, use temporary value to allow for vectorization
        out = zero(T)
        @simd for k in 1:TDIM
            @inbounds out += tile1[i, k] * tile2[k, j]
        end
        outval += out

        # every work-item must be done reading the tiles before the next
        # iteration overwrites them
        KI.barrier()
    end

    # save if inbounds
    if I <= N && J <= M
        @inbounds output[I, J] = outval
    end
    return nothing
end
7775
7876N = 1024
@@ -82,9 +80,10 @@ A = rand!(allocate(backend, Float32, N, R))
# Right-hand operand (R × M, random) and zero-initialised result (N × M) on `backend`
B = rand!(allocate(backend, Float32, R, M))
C = KernelAbstractions.zeros(backend, Float32, N, M)

# One workgroup per TILE_DIM × TILE_DIM tile of C; `cld` rounds up so that
# partially filled tiles at the edges are still covered.
workgroupsize = (TILE_DIM, TILE_DIM)
numworkgroups = cld.(size(C), TILE_DIM)

KI.@kernel backend workgroupsize numworkgroups coalesced_matmul_kernel!(
    C, A, B, N, R, M, Val(TILE_DIM)
)
KernelAbstractions.synchronize(backend)

# compare against the reference product
@test A * B ≈ C
0 commit comments