@@ -17,35 +17,37 @@ void aldr_free (struct aldr_s x) {
1717 free (x .leaves_flat );
1818}
1919
20- struct aldr_s aldr_preprocess (int * a , int n ) {
21- // assume k <= 31
22- int m = 0 ;
23- for (int i = 0 ; i < n ; ++ i ) {
20+ struct aldr_s aldr_preprocess (uint32_t * a , uint32_t n ) {
21+ // this algorithm requires that
22+ // 0 < n < (1 << 32) - 1
23+ // 0 < sum(a) < (1 << 31) - 1
24+ uint32_t m = 0 ;
25+ for (uint32_t i = 0 ; i < n ; ++ i ) {
2426 m += a [i ];
2527 }
26- int k = 32 - __builtin_clz (m ) - ( 1 == __builtin_popcount ( m ) );
27- int K = k << 1 ; // depth
28- long long c = (1ll << K ) / m ; // amplification factor
29- long long r = (1ll << K ) % m ; // reject weight
28+ uint32_t k = 32 - __builtin_clz (m - 1 );
29+ uint32_t K = k << 1 ; // depth
30+ uint64_t c = (1ll << K ) / m ; // amplification factor
31+ uint64_t r = (1ll << K ) % m ; // reject weight
3032
31- int num_leaves = __builtin_popcountll (r );
32- for (int i = 0 ; i < n ; ++ i ) {
33+ uint32_t num_leaves = __builtin_popcountll (r );
34+ for (uint32_t i = 0 ; i < n ; ++ i ) {
3335 num_leaves += __builtin_popcountll (c * a [i ]);
3436 }
3537
36- int * breadths = calloc (K + 1 , sizeof (int ));
37- int * leaves_flat = calloc (num_leaves , sizeof (int ));
38+ uint32_t * breadths = calloc (K + 1 , sizeof (* breadths ));
39+ uint32_t * leaves_flat = calloc (num_leaves , sizeof (* leaves_flat ));
3840
39- int location = 0 ;
40- for (int j = 0 ; j <= K ; j ++ ) {
41- long long bit = (1ll << (K - j ));
41+ uint32_t location = 0 ;
42+ for (uint32_t j = 0 ; j <= K ; j ++ ) {
43+ uint64_t bit = (1ll << (K - j ));
4244 if (r & bit ) {
4345 leaves_flat [location ] = 0 ;
4446 ++ breadths [j ];
4547 ++ location ;
4648 }
47- for (int i = 0 ; i < n ; ++ i ) {
48- long long Qi = c * a [i ];
49+ for (uint32_t i = 0 ; i < n ; ++ i ) {
50+ uint64_t Qi = c * a [i ];
4951 if (Qi & bit ) {
5052 leaves_flat [location ] = i + 1 ;
5153 ++ breadths [j ];
@@ -62,14 +64,14 @@ struct aldr_s aldr_preprocess(int* a, int n) {
6264 };
6365}
6466
65- int aldr_sample (struct aldr_s * f ) {
67+ uint32_t aldr_sample (struct aldr_s * f ) {
6668 for (;;) {
67- int depth = 0 ;
68- int location = 0 ;
69- int val = 0 ;
69+ uint32_t depth = 0 ;
70+ uint32_t location = 0 ;
71+ uint32_t val = 0 ;
7072 for (;;) {
7173 if (val < f -> breadths [depth ]) {
72- int ans = f -> leaves_flat [location + val ];
74+ uint32_t ans = f -> leaves_flat [location + val ];
7375 if (ans ) return ans - 1 ;
7476 else break ;
7577 }
0 commit comments