#include "common.hlsl"
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
#include "nbl/builtin/hlsl/workgroup/bitonic_sort.hlsl"

[[vk::push_constant]] PushConstantData pushConstants;

using namespace nbl::hlsl;

// Sort configuration: uint32_t keys and uint32_t values, ascending order via less<uint32_t>.
// ElementsPerThreadLog2 / WorkgroupSizeLog2 are presumably defined in common.hlsl — confirm.
using BitonicSortConfig = workgroup::bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, WorkgroupSizeLog2, uint32_t, uint32_t, less<uint32_t> >;

NBL_CONSTEXPR uint32_t WorkgroupSize = BitonicSortConfig::WorkgroupSize;

// Workgroup-shared scratch for the cross-subgroup exchange steps: 2 elements per invocation.
groupshared uint32_t sharedmem[2 * WorkgroupSize];

// Nabla hook reporting the compile-time workgroup dimensions to the framework.
uint32_t3 glsl::gl_WorkGroupSize() { return uint32_t3(uint32_t(BitonicSortConfig::WorkgroupSize), 1, 1); }
// Accessor adapter exposing the file-scope groupshared scratch (`sharedmem`)
// to the workgroup bitonic sort, which exchanges elements through it when
// partners live in different subgroups.
struct SharedMemoryAccessor
{
	// Store `value` into shared memory at `idx`.
	template<typename AccessType, typename IndexType>
	void set(IndexType idx, AccessType value)
	{
		sharedmem[idx] = value;
	}

	// Load the shared-memory element at `idx` into `value`.
	template<typename AccessType, typename IndexType>
	void get(IndexType idx, NBL_REF_ARG(AccessType) value)
	{
		value = sharedmem[idx];
	}

	// Execution + memory barrier across the whole workgroup, so writes made
	// via set() are visible to subsequent get() calls from other invocations.
	void workgroupExecutionAndMemoryBarrier()
	{
		glsl::barrier();
	}
};
// Device-memory accessor: element-wise loads/stores at a raw buffer device
// address (BDA). `index` is scaled by sizeof(AccessType), so the buffer is
// treated as a flat array of AccessType.
struct Accessor
{
	// Factory: bind the accessor to a buffer device address.
	static Accessor create(const uint64_t address)
	{
		Accessor accessor;
		accessor.address = address;
		return accessor;
	}

	// Load element `index` from the bound buffer into `value`.
	template<typename AccessType, typename IndexType>
	void get(const IndexType index, NBL_REF_ARG(AccessType) value)
	{
		value = vk::RawBufferLoad<AccessType>(address + index * sizeof(AccessType));
	}

	// Store `value` as element `index` of the bound buffer.
	template<typename AccessType, typename IndexType>
	void set(const IndexType index, const AccessType value)
	{
		vk::RawBufferStore<AccessType>(address + index * sizeof(AccessType), value);
	}

	uint64_t address;
};
[numthreads(BitonicSortConfig::WorkgroupSize, 1, 1)]
[shader("compute")]
void main()
{
	// Device buffer holding the keys to sort; address comes from push constants.
	Accessor accessor = Accessor::create(pushConstants.deviceBufferAddress);
	SharedMemoryAccessor sharedmemAccessor;

	const uint32_t threadID = glsl::gl_LocalInvocationID().x;

	// Each thread handles 2 ADJACENT elements: lo and hi.
	// Following the bitonic sort pattern: thread i handles elements [2*i] and [2*i + 1].
	const uint32_t loIdx = threadID * 2;
	const uint32_t hiIdx = threadID * 2 + 1;

	uint32_t loKey, hiKey;
	accessor.get(loIdx, loKey);
	accessor.get(hiIdx, hiKey);

	// Values are seeded with the original element indices, so after the sort
	// they record the permutation the keys underwent.
	uint32_t loVal = loIdx;
	uint32_t hiVal = hiIdx;

	// Workgroup-wide key+value bitonic sort; uses shared memory (via
	// sharedmemAccessor) for exchanges that cross subgroup boundaries.
	workgroup::BitonicSort<BitonicSortConfig>::template __call<Accessor, SharedMemoryAccessor>(accessor, sharedmemAccessor, loKey, hiKey, loVal, hiVal);

	// NOTE(review): only the sorted keys are written back; loVal/hiVal (the
	// permutation) are discarded — confirm this is intended.
	accessor.set(loIdx, loKey);
	accessor.set(hiIdx, hiKey);
}
0 commit comments