Skip to content

Commit 446d487

Browse files
authored
Update bitonic_sort_shader.comp.hlsl
1 parent 800802b commit 446d487

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

13_BitonicSort/app_resources/bitonic_sort_shader.comp.hlsl

Lines changed: 13 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ using BitonicSortConfig = workgroup::bitonic_sort::bitonic_sort_config<ElementsP
1010

1111
NBL_CONSTEXPR uint32_t WorkgroupSize = BitonicSortConfig::WorkgroupSize;
1212

13-
groupshared uint32_t sharedmem[2 * WorkgroupSize];
13+
groupshared uint32_t sharedmem[BitonicSortConfig::SharedmemDWORDs];
1414

1515
uint32_t3 glsl::gl_WorkGroupSize() { return uint32_t3(uint32_t(BitonicSortConfig::WorkgroupSize), 1, 1); }
1616

@@ -19,20 +19,21 @@ struct SharedMemoryAccessor
1919
template <typename AccessType, typename IndexType>
2020
void set(IndexType idx, AccessType value)
2121
{
22-
sharedmem[idx] = value;
22+
sharedmem[idx * 2] = value.first;
23+
sharedmem[idx * 2 + 1] = value.second;
2324
}
2425

2526
template <typename AccessType, typename IndexType>
2627
void get(IndexType idx, NBL_REF_ARG(AccessType) value)
2728
{
28-
value = sharedmem[idx];
29+
value.first = sharedmem[idx * 2];
30+
value.second = sharedmem[idx * 2 + 1];
2931
}
3032

3133
void workgroupExecutionAndMemoryBarrier()
3234
{
3335
glsl::barrier();
3436
}
35-
3637
};
3738

3839
struct Accessor
@@ -47,13 +48,17 @@ struct Accessor
4748
template <typename AccessType, typename IndexType>
4849
void get(const IndexType index, NBL_REF_ARG(AccessType) value)
4950
{
50-
value = vk::RawBufferLoad<AccessType>(address + index * sizeof(AccessType));
51+
const uint64_t offset = address + index * sizeof(AccessType);
52+
value.first = vk::RawBufferLoad<uint32_t>(offset);
53+
value.second = vk::RawBufferLoad<uint32_t>(offset + sizeof(uint32_t));
5154
}
5255

5356
template <typename AccessType, typename IndexType>
5457
void set(const IndexType index, const AccessType value)
5558
{
56-
vk::RawBufferStore<AccessType>(address + index * sizeof(AccessType), value);
59+
const uint64_t offset = address + index * sizeof(AccessType);
60+
vk::RawBufferStore<uint32_t>(offset, value.first);
61+
vk::RawBufferStore<uint32_t>(offset + sizeof(uint32_t), value.second);
5762
}
5863

5964
uint64_t address;
@@ -66,22 +71,6 @@ void main()
6671
Accessor accessor = Accessor::create(pushConstants.deviceBufferAddress);
6772
SharedMemoryAccessor sharedmemAccessor;
6873

69-
const uint32_t threadID = glsl::gl_LocalInvocationID().x;
70-
71-
// Each thread handles 2 ADJACENT elements: lo and hi
72-
// Following bitonic sort pattern: thread i handles elements [2*i] and [2*i + 1]
73-
const uint32_t loIdx = threadID * 2;
74-
const uint32_t hiIdx = threadID * 2 + 1;
75-
76-
uint32_t loKey, hiKey;
77-
accessor.get(loIdx, loKey);
78-
accessor.get(hiIdx, hiKey);
79-
80-
uint32_t loVal = loIdx;
81-
uint32_t hiVal = hiIdx;
82-
83-
workgroup::BitonicSort<BitonicSortConfig>::template __call<Accessor, SharedMemoryAccessor>(accessor, sharedmemAccessor, loKey, hiKey, loVal, hiVal);
84-
85-
accessor.set(loIdx, loKey);
86-
accessor.set(hiIdx, hiKey);
74+
// The sort handles load/store internally
75+
workgroup::BitonicSort<BitonicSortConfig>::template __call<Accessor, SharedMemoryAccessor>(accessor, sharedmemAccessor);
8776
}

0 commit comments

Comments
 (0)