@@ -10,7 +10,7 @@ using BitonicSortConfig = workgroup::bitonic_sort::bitonic_sort_config<ElementsP
1010
1111NBL_CONSTEXPR uint32_t WorkgroupSize = BitonicSortConfig::WorkgroupSize;
1212
13- groupshared uint32_t sharedmem[2 * WorkgroupSize ];
13+ groupshared uint32_t sharedmem[BitonicSortConfig::SharedmemDWORDs ];
1414
1515uint32_t3 glsl::gl_WorkGroupSize () { return uint32_t3 (uint32_t (BitonicSortConfig::WorkgroupSize), 1 , 1 ); }
1616
@@ -19,20 +19,21 @@ struct SharedMemoryAccessor
1919 template <typename AccessType, typename IndexType>
2020 void set (IndexType idx, AccessType value)
2121 {
22- sharedmem[idx] = value;
22+ sharedmem[idx * 2 ] = value.first;
23+ sharedmem[idx * 2 + 1 ] = value.second;
2324 }
2425
2526 template <typename AccessType, typename IndexType>
2627 void get (IndexType idx, NBL_REF_ARG (AccessType) value)
2728 {
28- value = sharedmem[idx];
29+ value.first = sharedmem[idx * 2 ];
30+ value.second = sharedmem[idx * 2 + 1 ];
2931 }
3032
3133 void workgroupExecutionAndMemoryBarrier ()
3234 {
3335 glsl::barrier ();
3436 }
35-
3637};
3738
3839struct Accessor
@@ -47,13 +48,17 @@ struct Accessor
4748 template <typename AccessType, typename IndexType>
4849 void get (const IndexType index, NBL_REF_ARG (AccessType) value)
4950 {
50- value = vk::RawBufferLoad<AccessType>(address + index * sizeof (AccessType));
51+ const uint64_t offset = address + index * sizeof (AccessType);
52+ value.first = vk::RawBufferLoad<uint32_t>(offset);
53+ value.second = vk::RawBufferLoad<uint32_t>(offset + sizeof (uint32_t));
5154 }
5255
5356 template <typename AccessType, typename IndexType>
5457 void set (const IndexType index, const AccessType value)
5558 {
56- vk::RawBufferStore<AccessType>(address + index * sizeof (AccessType), value);
59+ const uint64_t offset = address + index * sizeof (AccessType);
60+ vk::RawBufferStore<uint32_t>(offset, value.first);
61+ vk::RawBufferStore<uint32_t>(offset + sizeof (uint32_t), value.second);
5762 }
5863
5964 uint64_t address;
@@ -66,22 +71,6 @@ void main()
6671 Accessor accessor = Accessor::create (pushConstants.deviceBufferAddress);
6772 SharedMemoryAccessor sharedmemAccessor;
6873
69- const uint32_t threadID = glsl::gl_LocalInvocationID ().x;
70-
71- // Each thread handles 2 ADJACENT elements: lo and hi
72- // Following bitonic sort pattern: thread i handles elements [2*i] and [2*i + 1]
73- const uint32_t loIdx = threadID * 2 ;
74- const uint32_t hiIdx = threadID * 2 + 1 ;
75-
76- uint32_t loKey, hiKey;
77- accessor.get (loIdx, loKey);
78- accessor.get (hiIdx, hiKey);
79-
80- uint32_t loVal = loIdx;
81- uint32_t hiVal = hiIdx;
82-
83- workgroup::BitonicSort<BitonicSortConfig>::template __call<Accessor, SharedMemoryAccessor>(accessor, sharedmemAccessor, loKey, hiKey, loVal, hiVal);
84-
85- accessor.set (loIdx, loKey);
86- accessor.set (hiIdx, hiKey);
74+ // The sort handles load/store internally
75+ workgroup::BitonicSort<BitonicSortConfig>::template __call<Accessor, SharedMemoryAccessor>(accessor, sharedmemAccessor);
8776}
0 commit comments