Skip to content

Commit 547e518

Browse files
authored
Update bitonic_sort_shader.comp.hlsl
1 parent fd346a0 commit 547e518

File tree

1 file changed

+77
-102
lines changed

1 file changed

+77
-102
lines changed
Lines changed: 77 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,112 +1,87 @@
1-
#include "nbl/builtin/hlsl/bda/bda_accessor.hlsl"
1+
#include "common.hlsl"
2+
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
3+
#include "nbl/builtin/hlsl/workgroup/bitonic_sort.hlsl"
24

3-
struct BitonicPushData
5+
// Push constants for this shader; main() reads `deviceBufferAddress` from them.
[[vk::push_constant]] PushConstantData pushConstants;
6+
7+
using namespace nbl::hlsl;
8+
9+
// Compile-time configuration of the workgroup bitonic sort, with less<uint32_t> as the comparator.
// NOTE(review): the two uint32_t arguments are presumably the key and value types — confirm against bitonic_sort.hlsl.
using BitonicSortConfig = workgroup::bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, WorkgroupSizeLog2, uint32_t, uint32_t, less<uint32_t> >;
10+
11+
// Workgroup size taken from the sort configuration.
NBL_CONSTEXPR uint32_t WorkgroupSize = BitonicSortConfig::WorkgroupSize;
12+
13+
// Groupshared scratch array of 2 * WorkgroupSize uint32_t, accessed through SharedMemoryAccessor.
groupshared uint32_t sharedmem[2 * WorkgroupSize];
14+
15+
// Defines glsl::gl_WorkGroupSize() to return (BitonicSortConfig::WorkgroupSize, 1, 1).
uint32_t3 glsl::gl_WorkGroupSize() { return uint32_t3(uint32_t(BitonicSortConfig::WorkgroupSize), 1, 1); }
16+
17+
// Accessor over the `sharedmem` groupshared array; main() passes an instance of it
// to workgroup::BitonicSort as the shared-memory accessor.
struct SharedMemoryAccessor
418
{
5-
uint64_t inputKeyAddress;
6-
uint64_t inputValueAddress;
7-
uint64_t outputKeyAddress;
8-
uint64_t outputValueAddress;
9-
uint32_t dataElementCount;
19+
// Stores `value` into sharedmem[idx].
template <typename AccessType, typename IndexType>
20+
void set(IndexType idx, AccessType value)
21+
{
22+
sharedmem[idx] = value;
23+
}
24+
25+
// Reads sharedmem[idx] into the reference argument `value`.
template <typename AccessType, typename IndexType>
26+
void get(IndexType idx, NBL_REF_ARG(AccessType) value)
27+
{
28+
value = sharedmem[idx];
29+
}
30+
31+
// Barrier hook; implemented by calling glsl::barrier().
void workgroupExecutionAndMemoryBarrier()
32+
{
33+
glsl::barrier();
34+
}
35+
1036
};
1137

12-
using namespace nbl::hlsl;
38+
// Accessor for a buffer identified by a 64-bit address; elements are read and written
// with vk::RawBufferLoad / vk::RawBufferStore at `address + index * sizeof(AccessType)`.
struct Accessor
39+
{
40+
// Builds an Accessor whose `address` member is set to the given address.
static Accessor create(const uint64_t address)
41+
{
42+
Accessor accessor;
43+
accessor.address = address;
44+
return accessor;
45+
}
1346

14-
[[vk::push_constant]] BitonicPushData pushData;
47+
// Loads the `index`-th AccessType element into the reference argument `value`.
// NOTE(review): `index * sizeof(AccessType)` is evaluated before the addition to the
// 64-bit address — confirm the product cannot overflow for the index types used.
template <typename AccessType, typename IndexType>
48+
void get(const IndexType index, NBL_REF_ARG(AccessType) value)
49+
{
50+
value = vk::RawBufferLoad<AccessType>(address + index * sizeof(AccessType));
51+
}
1552

16-
using DataPtr = bda::__ptr<uint32_t>;
17-
using DataAccessor = BdaAccessor<uint32_t>;
53+
// Stores `value` as the `index`-th AccessType element (same offset computation as get()).
template <typename AccessType, typename IndexType>
54+
void set(const IndexType index, const AccessType value)
55+
{
56+
vk::RawBufferStore<AccessType>(address + index * sizeof(AccessType), value);
57+
}
1858

19-
groupshared uint32_t sharedKeys[ElementCount];
20-
groupshared uint32_t sharedValues[ElementCount];
59+
// Base address that element offsets are added to.
uint64_t address;
60+
};
2161

22-
[numthreads(WorkgroupSize, 1, 1)]
62+
// Compute entry point. In the new (added) version each thread loads two adjacent keys
// from the buffer at pushConstants.deviceBufferAddress, hands them to
// workgroup::BitonicSort together with two index values, and stores the two keys back.
[numthreads(BitonicSortConfig::WorkgroupSize, 1, 1)]
2363
[shader("compute")]
24-
void main(uint32_t3 dispatchId : SV_DispatchThreadID, uint32_t3 localId : SV_GroupThreadID)
64+
void main()
2565
{
26-
const uint32_t threadId = localId.x;
27-
const uint32_t dataSize = pushData.dataElementCount;
28-
29-
DataAccessor inputKeys = DataAccessor::create(DataPtr::create(pushData.inputKeyAddress));
30-
DataAccessor inputValues = DataAccessor::create(DataPtr::create(pushData.inputValueAddress));
31-
32-
for (uint32_t i = threadId; i < dataSize; i += WorkgroupSize)
33-
{
34-
inputKeys.get(i, sharedKeys[i]);
35-
inputValues.get(i, sharedValues[i]);
36-
}
37-
38-
// Synchronize all threads after loading
39-
GroupMemoryBarrierWithGroupSync();
40-
41-
42-
for (uint32_t stage = 0; stage < Log2ElementCount; stage++)
43-
{
44-
for (uint32_t pass = 0; pass <= stage; pass++)
45-
{
46-
const uint32_t compareDistance = 1 << (stage - pass);
47-
48-
for (uint32_t i = threadId; i < dataSize; i += WorkgroupSize)
49-
{
50-
const uint32_t partnerId = i ^ compareDistance;
51-
52-
if (partnerId >= dataSize)
53-
continue;
54-
55-
const uint32_t waveSize = WaveGetLaneCount();
56-
const uint32_t myWaveId = i / waveSize;
57-
const uint32_t partnerWaveId = partnerId / waveSize;
58-
const bool sameWave = (myWaveId == partnerWaveId);
59-
60-
uint32_t myKey, myValue, partnerKey, partnerValue;
61-
[branch]
62-
if (sameWave && compareDistance < waveSize)
63-
{
64-
// WAVE INTRINSIC
65-
myKey = sharedKeys[i];
66-
myValue = sharedValues[i];
67-
68-
const uint32_t partnerLane = partnerId % waveSize;
69-
partnerKey = WaveReadLaneAt(myKey, partnerLane);
70-
partnerValue = WaveReadLaneAt(myValue, partnerLane);
71-
}
72-
else
73-
{
74-
// SHARED MEM
75-
myKey = sharedKeys[i];
76-
myValue = sharedValues[i];
77-
partnerKey = sharedKeys[partnerId];
78-
partnerValue = sharedValues[partnerId];
79-
}
80-
81-
const uint32_t sequenceSize = 1 << (stage + 1);
82-
const uint32_t sequenceIndex = i / sequenceSize;
83-
const bool sequenceAscending = (sequenceIndex % 2) == 0;
84-
const bool ascending = true;
85-
const bool finalDirection = sequenceAscending == ascending;
86-
87-
const bool swap = (myKey > partnerKey) == finalDirection;
88-
89-
// WORKGROUP COORDINATION: Only lower-indexed element writes both
90-
if (i < partnerId && swap)
91-
{
92-
sharedKeys[i] = partnerKey;
93-
sharedKeys[partnerId] = myKey;
94-
sharedValues[i] = partnerValue;
95-
sharedValues[partnerId] = myValue;
96-
}
97-
}
98-
99-
GroupMemoryBarrierWithGroupSync();
100-
}
101-
}
102-
103-
104-
DataAccessor outputKeys = DataAccessor::create(DataPtr::create(pushData.outputKeyAddress));
105-
DataAccessor outputValues = DataAccessor::create(DataPtr::create(pushData.outputValueAddress));
106-
107-
for (uint32_t i = threadId; i < dataSize; i += WorkgroupSize)
108-
{
109-
outputKeys.set(i, sharedKeys[i]);
110-
outputValues.set(i, sharedValues[i]);
111-
}
112-
}
66+
// Wrap the buffer address from the push constants in an Accessor.
Accessor accessor = Accessor::create(pushConstants.deviceBufferAddress);
67+
SharedMemoryAccessor sharedmemAccessor;
68+
69+
const uint32_t threadID = glsl::gl_LocalInvocationID().x;
70+
71+
// Each thread handles two adjacent elements, lo and hi.
72+
// Following the bitonic sort pattern, thread i handles elements [2*i] and [2*i + 1].
73+
const uint32_t loIdx = threadID * 2;
74+
const uint32_t hiIdx = threadID * 2 + 1;
75+
76+
uint32_t loKey, hiKey;
77+
accessor.get(loIdx, loKey);
78+
accessor.get(hiIdx, hiKey);
79+
80+
// Each value starts out as the index its key was loaded from.
uint32_t loVal = loIdx;
81+
uint32_t hiVal = hiIdx;
82+
83+
// Invoke the workgroup bitonic sort with both accessors and this thread's two keys and two values.
// NOTE(review): presumably updates loKey/hiKey (and loVal/hiVal) in place — confirm against bitonic_sort.hlsl.
workgroup::BitonicSort<BitonicSortConfig>::template __call<Accessor, SharedMemoryAccessor>(accessor, sharedmemAccessor, loKey, hiKey, loVal, hiVal);
84+
85+
// Only the keys are stored back; loVal/hiVal are not written anywhere in this function.
accessor.set(loIdx, loKey);
86+
accessor.set(hiIdx, hiKey);
87+
}

0 commit comments

Comments
 (0)