Skip to content

Commit da98ac8

Browse files
committed
Update unsafe_indices examples to use KI directly
1 parent 42f17d6 commit da98ac8

File tree

2 files changed

+39
-42
lines changed

2 files changed

+39
-42
lines changed

examples/histogram.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@ function create_histogram(input)
1212
return histogram_output
1313
end
1414

15-
# This a 1D histogram kernel where the histogramming happens on shmem
16-
@kernel unsafe_indices = true function histogram_kernel!(histogram_output, input)
17-
gid = @index(Group, Linear)
18-
lid = @index(Local, Linear)
15+
# This a 1D histogram kernel where the histogramming happens on static shmem
16+
function histogram_kernel!(histogram_output, input, ::Val{gs}) where gs
17+
gid = KI.get_group_id().x
18+
lid = KI.get_local_id().x
1919

20-
@uniform gs = prod(@groupsize())
2120
tid = (gid - 1) * gs + lid
22-
@uniform N = length(histogram_output)
21+
N = length(histogram_output)
2322

24-
shared_histogram = @localmem eltype(input) (gs)
23+
shared_histogram = KI.localmemory(eltype(input), gs)
2524

2625
# This will go through all input elements and assign them to a location in
2726
# shmem. Note that if there is not enough shem, we create different shmem
@@ -32,7 +31,7 @@ end
3231

3332
# Setting shared_histogram to 0
3433
@inbounds shared_histogram[lid] = 0
35-
@synchronize()
34+
KI.barrier()
3635

3736
max_element = min_element + gs
3837
if max_element > N
@@ -46,7 +45,7 @@ end
4645
@atomic shared_histogram[bin] += 1
4746
end
4847

49-
@synchronize()
48+
KI.barrier()
5049

5150
if ((lid + min_element - 1) <= N)
5251
@atomic histogram_output[lid + min_element - 1] += shared_histogram[lid]
@@ -59,8 +58,7 @@ end
5958
function histogram!(histogram_output, input, groupsize = 256)
6059
backend = get_backend(histogram_output)
6160
# Need static block size
62-
kernel! = histogram_kernel!(backend, (groupsize,))
63-
kernel!(histogram_output, input, ndrange = size(input))
61+
KI.@kernel backend workgroupsize=groupsize numworkgroups=cld(length(input), groupsize) histogram_kernel!(histogram_output, input, Val(groupsize))
6462
return
6563
end
6664

examples/performant_matmul.jl

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,70 +9,68 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
99
# Metal sometimes supports fewer.
1010
const TILE_DIM = 16
1111

12-
@kernel unsafe_indices = true function coalesced_matmul_kernel!(
13-
output, @Const(input1), @Const(input2), N, R, M,
14-
::Val{BANK} = Val(1),
15-
) where {BANK}
16-
gi, gj = @index(Group, NTuple)
17-
i, j = @index(Local, NTuple)
18-
19-
TILE_DIM = @uniform @groupsize()[1]
12+
function coalesced_matmul_kernel!(
13+
output, input1, input2, N, R, M,
14+
::Val{TDIM}, ::Val{BANK} = Val(1)
15+
) where {TDIM, BANK}
16+
gi, gj, _ = KI.get_group_id()
17+
i, j, _ = KI.get_local_id()
2018

2119
# +1 to avoid bank conflicts on shared memory
22-
tile1 = @localmem eltype(output) (TILE_DIM + BANK, TILE_DIM)
23-
tile2 = @localmem eltype(output) (TILE_DIM + BANK, TILE_DIM)
20+
tile1 = KI.localmemory(eltype(output), (TDIM + BANK, TDIM))
21+
tile2 = KI.localmemory(eltype(output), (TDIM + BANK, TDIM))
2422

25-
# private variable for tile output
26-
outval = @private eltype(output) 1
27-
@inbounds outval[1] = -zero(eltype(output))
23+
# variable for tile output
24+
outval = -zero(eltype(output))
2825

29-
@uniform N = size(output, 1)
26+
N = size(output, 1)
3027
# number of tiles depends on inner dimension
31-
@uniform NUM_TILES = div(R + TILE_DIM - 1, TILE_DIM)
28+
NUM_TILES = div(R + TDIM - 1, TDIM)
3229

3330
# loop over all tiles needed for this calculation
3431
for t in 0:(NUM_TILES - 1)
3532
# Can't use @index(Global), because we use a smaller ndrange
36-
I = (gi - 1) * TILE_DIM + i
37-
J = (gj - 1) * TILE_DIM + j
33+
I = (gi - 1) * TDIM + i
34+
J = (gj - 1) * TDIM + j
3835

3936
# load inputs into tiles, with bounds checking for non-square matrices
40-
if I <= N && t * TILE_DIM + j <= R
41-
@inbounds tile1[i, j] = input1[I, t * TILE_DIM + j]
37+
if I <= N && t * TDIM + j <= R
38+
@inbounds tile1[i, j] = input1[I, t * TDIM + j]
4239
else
4340
@inbounds tile1[i, j] = 0.0
4441
end
4542
if t * TILE_DIM + i <= R && J <= M
46-
@inbounds tile2[i, j] = input2[t * TILE_DIM + i, J]
43+
@inbounds tile2[i, j] = input2[t * TDIM + i, J]
4744
else
4845
@inbounds tile2[i, j] = 0.0
4946
end
5047

5148
# wait for all tiles to be loaded
52-
@synchronize
49+
KI.barrier()
5350

5451
# get global values again
55-
I = (gi - 1) * TILE_DIM + i
56-
J = (gj - 1) * TILE_DIM + j
52+
I = (gi - 1) * TDIM + i
53+
J = (gj - 1) * TDIM + j
5754

5855
# calculate value of spot in output, use temporary value to allow for vectorization
5956
out = zero(eltype(output))
60-
@simd for k in 1:TILE_DIM
57+
@simd for k in 1:TDIM
6158
@inbounds out += tile1[i, k] * tile2[k, j]
6259
end
63-
outval[1] += out
60+
outval += out
6461

65-
@synchronize
62+
KI.barrier()
6663
end
6764

6865
# get global indices again
69-
I = (gi - 1) * TILE_DIM + i
70-
J = (gj - 1) * TILE_DIM + j
66+
I = (gi - 1) * TDIM + i
67+
J = (gj - 1) * TDIM + j
7168

7269
# save if inbounds
7370
if I <= N && J <= M
74-
@inbounds output[I, J] = outval[1]
71+
@inbounds output[I, J] = outval
7572
end
73+
return nothing
7674
end
7775

7876
N = 1024
@@ -82,9 +80,10 @@ A = rand!(allocate(backend, Float32, N, R))
8280
B = rand!(allocate(backend, Float32, R, M))
8381
C = KernelAbstractions.zeros(backend, Float32, N, M)
8482

85-
kern = coalesced_matmul_kernel!(backend, (TILE_DIM, TILE_DIM))
83+
workgroupsize=(TILE_DIM, TILE_DIM)
84+
numworkgroups=(cld(size(C,1), TILE_DIM), cld(size(C,2), TILE_DIM))
8685

87-
kern(C, A, B, N, R, M, ndrange = size(C))
86+
KI.@kernel backend workgroupsize numworkgroups coalesced_matmul_kernel!(C, A, B, N, R, M, Val(TILE_DIM))
8887
KernelAbstractions.synchronize(backend)
8988

9089
@test isapprox(A * B, C)

0 commit comments

Comments
 (0)