Skip to content

Commit 07dad8e

Browse files
justinjfuGoogle-ML-Automation
authored andcommitted
[Pallas] Fix 64-bit multiplication emulation in Philox kernel
PiperOrigin-RevId: 819807662
1 parent 4f06e7a commit 07dad8e

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

jax/experimental/pallas/ops/tpu/random/philox.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def mul32_hi_lo(x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
4646
cross_xy = xhi * ylo
4747
cross_yx = xlo * yhi
4848
carry = (cross_xy & 0xffff) + (cross_yx & 0xffff) + (xy_lo >> 16)
49-
return xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16), xy_lo
49+
result_hi = xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16)
50+
result_lo = (carry << 16) + (xy_lo & 0xffff)
51+
return result_hi, result_lo
5052

5153

5254
def philox_4x32(hi0, lo0, hi1, lo1, k_hi, k_lo, rounds = 10):

tests/pallas/tpu_pallas_random_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,19 @@ def setUp(self):
376376
self.skipTest("Need TPU devices")
377377
super().setUp()
378378

379+
@parameterized.product(
380+
x=[0x1, 0x10000, 0xabcdef],
381+
y=[0x1, 0x10000, 0xabcdef],
382+
)
383+
def test_mul_hi_lo(self, x, y):
384+
x = jnp.uint32(x)
385+
y = jnp.uint32(y)
386+
hi, lo = philox.mul32_hi_lo(x, y)
387+
with jax.enable_x64():
388+
result = (hi.astype(jnp.uint64) << 32) + lo.astype(jnp.uint64)
389+
ref = x.astype(jnp.uint64) * y.astype(jnp.uint64)
390+
self.assertEqual(result, ref)
391+
379392
@parameterized.parameters(
380393
((512, 512),),
381394
((137, 275),), # Non block-aligned shape

0 commit comments

Comments
 (0)