66
77#include  "nbl/builtin/hlsl/cpp_compat.hlsl" 
88#include  "nbl/builtin/hlsl/tuple.hlsl" 
9+ #include  "nbl/builtin/hlsl/mpl.hlsl" 
910
1011namespace  nbl 
1112{
@@ -19,23 +20,37 @@ namespace impl
// Compile-time description of the virtual workgroup size for a given workgroup/subgroup size (both log2).
// Lines marked '-' are the previous hand-written constants (WorkgroupSizeLog2, SubgroupSizeLog2, levels, value);
// lines marked '+' replace them with an expansion of "impl/virtual_wg_size_def.hlsl" driven by the macros below.
// NOTE(review): the def file is not visible here; presumably it declares the same four members - confirm.
1920template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2>
2021struct  virtual_wg_size_log2
2122{
22-     NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
23-     NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2;
    // Macro bindings for the compile-time flavour of the def file:
    //  - DEFINE_ASSIGN declares a NBL_CONSTEXPR_STATIC_INLINE member initialised with the variadic arguments
    //  - DEFINE_VIRTUAL_WG_T refers to a member of this struct, hence the bare identifier
    //  - DEFINE_MPL_MAX_V / DEFINE_COND_VAL forward to mpl::max_v and conditional_value<...>::value
    //    (DEFINE_COND_VAL takes TYPE first, conditional_value takes COND first)
23+     #define  DEFINE_ASSIGN (TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
24+     #define  DEFINE_VIRTUAL_WG_T (ID) ID
25+     #define  DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
26+     #define  DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
27+     #include  "impl/virtual_wg_size_def.hlsl" 
    // the helper macros are undefined right after the include so the next user can rebind them
28+     #undef  DEFINE_COND_VAL
29+     #undef  DEFINE_MPL_MAX_V
30+     #undef  DEFINE_VIRTUAL_WG_T
31+     #undef  DEFINE_ASSIGN
32+     
33+     // must have at least enough level 0 outputs to feed a single subgroup 
    // Both checks are in log2 space: the upper bound SubgroupSizeLog2*3+4 corresponds to (SubgroupSize^3)*16.
    // With the '+' lines applied, WorkgroupSizeLog2 and SubgroupSizeLog2 must be declared by the include above.
2434    static_assert (WorkgroupSizeLog2>=SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize" );
2535    static_assert (WorkgroupSizeLog2<=SubgroupSizeLog2*3 +4 , "WorkgroupSize cannot be larger than (SubgroupSize^3)*16" );
26- 
    // Removed formulas, kept for reference: `levels` is 1 when the workgroup is not larger than a subgroup,
    // 3 when WorkgroupSizeLog2 > SubgroupSizeLog2*2+2, otherwise 2; `value` is max(SubgroupSizeLog2*levels, WorkgroupSizeLog2).
27-     NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2 +2 ),uint16_t,3 ,2 >::value,1 >::value;
28-     NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, SubgroupSizeLog2*levels, WorkgroupSizeLog2>;
29-     // must have at least enough level 0 outputs to feed a single subgroup 
3036};
3137
// Compile-time items-per-invocation for each level, derived from a VirtualWorkgroup type and a base count.
// Lines marked '-' are the previous hand-written constants; lines marked '+' replace them with an expansion
// of "impl/items_per_invoc_def.hlsl" driven by the macros below.
// NOTE(review): the def file is not visible here; presumably it reproduces the removed formulas - confirm.
3238template<class  VirtualWorkgroup, uint16_t BaseItemsPerInvocation>
3339struct  items_per_invocation
3440{
    // Removed formulas, kept for reference:
    //  - ItemsPerInvocationProductLog2 = max(WorkgroupSizeLog2 - SubgroupSizeLog2*levels, 0)
    //  - value0 = BaseItemsPerInvocation
    //  - value1 = 1 << (levels==3 ? min(ItemsPerInvocationProductLog2,2) : ItemsPerInvocationProductLog2)
    //  - value2 = 1 << max(ItemsPerInvocationProductLog2-2, 0)
35-     NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v<int16_t,VirtualWorkgroup::WorkgroupSizeLog2-VirtualWorkgroup::SubgroupSizeLog2*VirtualWorkgroup::levels,0 >;
36-     NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation;
37-     NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t (0x1u) << conditional_value<VirtualWorkgroup::levels==3 , uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2 >, ItemsPerInvocationProductLog2>::value;
38-     NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t (0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2 ,0 >;
    // Macro bindings for the compile-time flavour of the def file:
    //  - DEFINE_ASSIGN declares a NBL_CONSTEXPR_STATIC_INLINE member initialised with the variadic arguments
    //  - DEFINE_VIRTUAL_WG_T reads a member of the VirtualWorkgroup template argument
    //  - DEFINE_ITEMS_INVOC_T reads a member of this struct, hence the bare identifier
    //  - DEFINE_MPL_MIN_V / DEFINE_MPL_MAX_V / DEFINE_COND_VAL forward to mpl::min_v, mpl::max_v and conditional_value<...>::value
41+     #define  DEFINE_ASSIGN (TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
42+     #define  DEFINE_VIRTUAL_WG_T (ID) VirtualWorkgroup::ID
43+     #define  DEFINE_ITEMS_INVOC_T (ID) ID
44+     #define  DEFINE_MPL_MIN_V (TYPE,ARG1,ARG2) mpl::min_v<TYPE, ARG1, ARG2>
45+     #define  DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
46+     #define  DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
47+     #include  "impl/items_per_invoc_def.hlsl" 
    // the helper macros are undefined right after the include so the next user can rebind them
48+     #undef  DEFINE_COND_VAL
49+     #undef  DEFINE_MPL_MAX_V
50+     #undef  DEFINE_MPL_MIN_V
51+     #undef  DEFINE_ITEMS_INVOC_T
52+     #undef  DEFINE_VIRTUAL_WG_T
53+     #undef  DEFINE_ASSIGN
3954
    // Packs the three per-level counts into a tuple of integral_constants.
    // With the '+' lines applied, value0/value1/value2 must be declared by the include above.
4055    using ItemsPerInvocation = tuple<integral_constant<uint16_t,value0>,integral_constant<uint16_t,value1>,integral_constant<uint16_t,value2> >;
4156};
@@ -44,47 +59,35 @@ struct items_per_invocation
// Compile-time configuration built from workgroup size (log2), subgroup size (log2) and base items per invocation.
// Lines marked '-' are the previous hand-written constants; lines marked '+' replace them with an expansion
// of "impl/arithmetic_config_def.hlsl" driven by the macros below.
// NOTE(review): part of this struct is elided in this view (see the hunk header further down) and the def file
// is not visible; comments here only describe what the visible lines show.
4459template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2, uint16_t _ItemsPerInvocation>
4560struct  ArithmeticConfiguration
4661{
47-     NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
48-     NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t (0x1u) << WorkgroupSizeLog2;
49-     NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2;
50-     NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t (0x1u) << SubgroupSizeLog2;
51- 
52-     using virtual_wg_t = impl::virtual_wg_size_log2<WorkgroupSizeLog2, SubgroupSizeLog2>;
53-     NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels;
54-     NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t (0x1u) << virtual_wg_t::value;
55-     static_assert (VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
56- 
    // the alias now uses the template parameters directly, since the member constants are no longer declared before it
62+     using virtual_wg_t = impl::virtual_wg_size_log2<_WorkgroupSizeLog2, _SubgroupSizeLog2>;
5763    using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation>;
5864    using ItemsPerInvocation = typename items_per_invoc_t::ItemsPerInvocation;
59-     NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = tuple_element<0 ,ItemsPerInvocation>::type::value;
60-     NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = tuple_element<1 ,ItemsPerInvocation>::type::value;
61-     NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = tuple_element<2 ,ItemsPerInvocation>::type::value;
62-     static_assert (ItemsPerInvocation_2<=4 , "4 level scan would have been needed with this config!" );
6365
64-     NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_1 = conditional_value<LevelCount==3 ,uint16_t,
65-         mpl::max_v<uint16_t, (VirtualWorkgroupSize>>SubgroupSizeLog2), SubgroupSize>,
66-         SubgroupSize*ItemsPerInvocation_1>::value;
67-     NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_2 = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize*ItemsPerInvocation_2,0 >::value;
68-     NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualInvocationsAtLevel1 = LevelInputCount_1 / ItemsPerInvocation_1;
    // Macro bindings for the compile-time flavour of the def file:
    //  - DEFINE_ASSIGN declares a NBL_CONSTEXPR_STATIC_INLINE member initialised with the variadic arguments
    //  - DEFINE_VIRTUAL_WG_T / DEFINE_ITEMS_INVOC_T read members of the two helper types aliased above
    //  - DEFINE_CONFIG_T reads a member of this struct, hence the bare identifier
    //  - DEFINE_MPL_MAX_V / DEFINE_COND_VAL forward to mpl::max_v and conditional_value<...>::value
66+     #define  DEFINE_ASSIGN (TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
67+     #define  DEFINE_VIRTUAL_WG_T (ID) virtual_wg_t::ID
68+     #define  DEFINE_ITEMS_INVOC_T (ID) items_per_invoc_t::ID
69+     #define  DEFINE_CONFIG_T (ID) ID
70+     #define  DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
71+     #define  DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
72+     #include  "impl/arithmetic_config_def.hlsl" 
    // the helper macros are undefined right after the include so the next user can rebind them
73+     #undef  DEFINE_COND_VAL
74+     #undef  DEFINE_MPL_MAX_V
75+     #undef  DEFINE_CONFIG_T
76+     #undef  DEFINE_ITEMS_INVOC_T
77+     #undef  DEFINE_VIRTUAL_WG_T
78+     #undef  DEFINE_ASSIGN
6979
70-     NBL_CONSTEXPR_STATIC_INLINE uint16_t __padding = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize-1 ,0 >::value;
71-     NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_1 = conditional_value<LevelCount==3 ,uint16_t,VirtualInvocationsAtLevel1,SubgroupSize>::value + __padding;
72-     NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_2 = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize,0 >::value;
    // With the '+' lines applied, __padding and the two __channelStride_* constants must be declared by the include above.
7380    using ChannelStride = tuple<integral_constant<uint16_t,__padding>,integral_constant<uint16_t,__channelStride_1>,integral_constant<uint16_t,__channelStride_2> >; // we don't use stride 0 
7481
75-     // user specified the shared mem size of Scalars 
76-     NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1 ,uint16_t,
77-         0 ,
78-         conditional_value<LevelCount==3 ,uint16_t,
79-             LevelInputCount_2+(SubgroupSize*ItemsPerInvocation_1)-1 ,
80-             0 
81-             >::value + LevelInputCount_1
82-         >::value;
    // The two sanity checks previously interleaved with the constants now sit after the include,
    // because the names they test are only declared once the def file has been expanded.
82+     static_assert (VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
83+     static_assert (ItemsPerInvocation_2<=4 , "4 level scan would have been needed with this config!" );
8384
    // electLast is only compiled when __HLSL_VERSION is defined.
    // It returns true when the subgroup invocation ID equals SubgroupSize-1.
85+ #ifdef  __HLSL_VERSION
8486    static  bool  electLast ()
8587    {
8688        return  glsl::gl_SubgroupInvocationID ()==SubgroupSize-1 ;
8789    }
90+ #endif 
8891
8992    // gets a subgroupID as if each workgroup has (VirtualWorkgroupSize/SubgroupSize) subgroups 
9093    // each subgroup does work (VirtualWorkgroupSize/WorkgroupSize) times, the index denoted by workgroupInVirtualIndex 
@@ -140,6 +143,88 @@ struct ArithmeticConfiguration
140143    }
141144};
142145
146+ #ifndef  __HLSL_VERSION
147+ namespace  impl
148+ {
// Runtime counterpart of the virtual_wg_size_log2 template: the same def file is expanded twice,
// once inside create() to compute the values and once at struct scope to declare the fields.
// NOTE(review): the def file is not visible here; the field list is whatever it passes to DEFINE_ASSIGN.
149+ struct  SVirtualWGSizeLog2
150+ {
    // Builds the struct from runtime workgroup/subgroup sizes (both log2).
    // NOTE(review): no range validation is visible here - confirm whether callers must check the inputs.
151+     static  SVirtualWGSizeLog2 create (const  uint16_t _WorkgroupSizeLog2, const  uint16_t _SubgroupSizeLog2)
152+     {
153+         SVirtualWGSizeLog2 retval;
        // Macro bindings for the runtime flavour of the def file:
        //  - DEFINE_ASSIGN ignores TYPE and assigns the variadic expression to the matching field of retval
        //  - DEFINE_VIRTUAL_WG_T reads a field already written to retval, so entry order in the def file matters
        //  - DEFINE_MPL_MAX_V / DEFINE_COND_VAL become hlsl::max and a plain ternary
        // NOTE(review): DEFINE_COND_VAL expands its arguments unparenthesized; fine only while the def file
        // never passes an operand that binds looser than ?: - consider ((COND) ? (TRUE_VAL) : (FALSE_VAL)).
154+         #define  DEFINE_ASSIGN (TYPE,ID,...) retval.ID = __VA_ARGS__;
155+         #define  DEFINE_VIRTUAL_WG_T (ID) retval.ID
156+         #define  DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) hlsl::max <TYPE>(ARG1, ARG2)
157+         #define  DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
158+         #include  "impl/virtual_wg_size_def.hlsl" 
159+         #undef  DEFINE_COND_VAL
160+         #undef  DEFINE_MPL_MAX_V
161+         #undef  DEFINE_VIRTUAL_WG_T
162+         #undef  DEFINE_ASSIGN
163+         return  retval;
164+     }
165+ 
    // Field declarations: this DEFINE_ASSIGN keeps only TYPE and ID and drops the initializer arguments,
    // so (assuming the helper macros only appear inside those arguments) they need not be defined here.
166+     #define  DEFINE_ASSIGN (TYPE,ID,...) TYPE ID;
167+     #include  "impl/virtual_wg_size_def.hlsl" 
168+     #undef  DEFINE_ASSIGN
169+ };
170+ 
// Runtime counterpart of the items_per_invocation template: the same def file is expanded twice,
// once inside create() to compute the values and once at struct scope to declare the fields.
// NOTE(review): the def file is not visible here; the field list is whatever it passes to DEFINE_ASSIGN.
171+ struct  SItemsPerInvoc
172+ {
    // Builds the struct from an already computed SVirtualWGSizeLog2 and the base items-per-invocation count.
173+     static  SItemsPerInvoc create (const  SVirtualWGSizeLog2 virtualWgSizeLog2, const  uint16_t BaseItemsPerInvocation)
174+     {
175+         SItemsPerInvoc retval;
        // Macro bindings for the runtime flavour of the def file:
        //  - DEFINE_ASSIGN ignores TYPE and assigns the variadic expression to the matching field of retval
        //  - DEFINE_VIRTUAL_WG_T reads a field of the virtualWgSizeLog2 argument
        //  - DEFINE_ITEMS_INVOC_T reads a field already written to retval, so entry order in the def file matters
        //  - DEFINE_MPL_MIN_V / DEFINE_MPL_MAX_V / DEFINE_COND_VAL become hlsl::min, hlsl::max and a plain ternary
        // NOTE(review): DEFINE_COND_VAL expands its arguments unparenthesized; fine only while the def file
        // never passes an operand that binds looser than ?: - consider ((COND) ? (TRUE_VAL) : (FALSE_VAL)).
176+         #define  DEFINE_ASSIGN (TYPE,ID,...) retval.ID = __VA_ARGS__;
177+         #define  DEFINE_VIRTUAL_WG_T (ID) virtualWgSizeLog2.ID
178+         #define  DEFINE_ITEMS_INVOC_T (ID) retval.ID
179+         #define  DEFINE_MPL_MIN_V (TYPE,ARG1,ARG2) hlsl::min <TYPE>(ARG1, ARG2)
180+         #define  DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) hlsl::max <TYPE>(ARG1, ARG2)
181+         #define  DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
182+         #include  "impl/items_per_invoc_def.hlsl" 
183+         #undef  DEFINE_COND_VAL
184+         #undef  DEFINE_MPL_MAX_V
185+         #undef  DEFINE_MPL_MIN_V
186+         #undef  DEFINE_ITEMS_INVOC_T
187+         #undef  DEFINE_VIRTUAL_WG_T
188+         #undef  DEFINE_ASSIGN
189+         return  retval;
190+     }
191+ 
    // Field declarations: this DEFINE_ASSIGN keeps only TYPE and ID and drops the initializer arguments,
    // so (assuming the helper macros only appear inside those arguments) they need not be defined here.
192+     #define  DEFINE_ASSIGN (TYPE,ID,...) TYPE ID;
193+     #include  "impl/items_per_invoc_def.hlsl" 
194+     #undef  DEFINE_ASSIGN
195+ };
196+ }
197+ 
// Runtime counterpart of the ArithmeticConfiguration template: the same def file is expanded twice,
// once inside create() to compute the values and once at struct scope to declare the fields.
// NOTE(review): the def file is not visible here; the field list is whatever it passes to DEFINE_ASSIGN.
198+ struct  SArithmeticConfiguration
199+ {
    // Builds the configuration from runtime workgroup size (log2), subgroup size (log2) and base items per invocation.
    // NOTE(review): the static_asserts of the template version have no visible runtime equivalent here -
    // confirm whether the def file or the callers validate the inputs.
200+     static  SArithmeticConfiguration create (const  uint16_t _WorkgroupSizeLog2, const  uint16_t _SubgroupSizeLog2, const  uint16_t _ItemsPerInvocation)
201+     {
        // the two helper structs are computed first because the def file reads their fields through the macros below
202+         impl::SVirtualWGSizeLog2 virtualWgSizeLog2 = impl::SVirtualWGSizeLog2::create (_WorkgroupSizeLog2, _SubgroupSizeLog2);
203+         impl::SItemsPerInvoc itemsPerInvoc = impl::SItemsPerInvoc::create (virtualWgSizeLog2, _ItemsPerInvocation);
204+ 
205+         SArithmeticConfiguration retval;
        // Macro bindings for the runtime flavour of the def file:
        //  - DEFINE_ASSIGN ignores TYPE and assigns the variadic expression to the matching field of retval
        //  - DEFINE_VIRTUAL_WG_T / DEFINE_ITEMS_INVOC_T read fields of the two local helper structs
        //  - DEFINE_CONFIG_T reads a field already written to retval, so entry order in the def file matters
        //  - DEFINE_MPL_MAX_V / DEFINE_COND_VAL become hlsl::max and a plain ternary
        // NOTE(review): DEFINE_COND_VAL expands its arguments unparenthesized; fine only while the def file
        // never passes an operand that binds looser than ?: - consider ((COND) ? (TRUE_VAL) : (FALSE_VAL)).
206+         #define  DEFINE_ASSIGN (TYPE,ID,...) retval.ID = __VA_ARGS__;
207+         #define  DEFINE_VIRTUAL_WG_T (ID) virtualWgSizeLog2.ID
208+         #define  DEFINE_ITEMS_INVOC_T (ID) itemsPerInvoc.ID
209+         #define  DEFINE_CONFIG_T (ID) retval.ID
210+         #define  DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) hlsl::max <TYPE>(ARG1, ARG2)
211+         #define  DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
212+         #include  "impl/arithmetic_config_def.hlsl" 
213+         #undef  DEFINE_COND_VAL
214+         #undef  DEFINE_MPL_MAX_V
215+         #undef  DEFINE_CONFIG_T
216+         #undef  DEFINE_ITEMS_INVOC_T
217+         #undef  DEFINE_VIRTUAL_WG_T
218+         #undef  DEFINE_ASSIGN
219+         return  retval;
220+     }
221+ 
    // Field declarations: this DEFINE_ASSIGN keeps only TYPE and ID and drops the initializer arguments,
    // so (assuming the helper macros only appear inside those arguments) they need not be defined here.
222+     #define  DEFINE_ASSIGN (TYPE,ID,...) TYPE ID;
223+     #include  "impl/arithmetic_config_def.hlsl" 
224+     #undef  DEFINE_ASSIGN
225+ };
226+ #endif 
227+ 
143228template<class  T>
144229struct  is_configuration : bool_constant<false > {};
145230
0 commit comments