Skip to content

Commit 53f52cb

Browse files
author
jax authors
committed
Merge pull request #9942 from jakevdp:glu-fix
PiperOrigin-RevId: 435453524
2 parents ae631e9 + c762e07 commit 53f52cb

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

jax/_src/nn/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def gelu(x: Array, approximate: bool = True) -> Array:
258258
else:
259259
return jnp.array(x * (lax.erf(x / np.sqrt(2)) + 1) / 2, dtype=x.dtype)
260260

261-
@partial(jax.jit, static_argnames=("glu",))
261+
@partial(jax.jit, static_argnames=("axis",))
262262
def glu(x: Array, axis: int = -1) -> Array:
263263
"""Gated linear unit activation function.
264264

tests/nn_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def testEluValue(self):
8181
self.assertAllClose(val, 1e4, check_dtypes=False)
8282

8383
def testGluValue(self):
84-
val = nn.glu(jnp.array([1.0, 0.0]))
84+
val = nn.glu(jnp.array([1.0, 0.0]), axis=0)
8585
self.assertAllClose(val, jnp.array([0.5]))
8686

8787
@parameterized.parameters(False, True)

0 commit comments

Comments
 (0)