Skip to content

Commit e16ebe0

Browse files
committed
fix: changing the dtype in random_utils to uint32
1 parent adc5ea9 commit e16ebe0

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

algorithmic_efficiency/random_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
2020
# unsigned int), while RandomState.randint only accepts and returns signed ints.
21-
MAX_INT32 = 2**31
22-
MIN_INT32 = -MAX_INT32
21+
MAX_UINT32 = 2**31
22+
MIN_UINT32 = 0
2323

2424
SeedType = Union[int, list, np.ndarray]
2525

@@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:
3535

3636
def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
3737
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
38-
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
38+
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
3939
return [new_seed, data]
4040

4141

4242
def _split(seed: SeedType, num: int = 2) -> SeedType:
4343
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
44-
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
44+
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])
4545

4646

4747
def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name

0 commit comments

Comments
 (0)