Skip to content

Commit 5d7f639

Browse files
zhangqiaorjcjax authors
authored andcommitted
Add small and big matmul to api_benchmarks.
name cpu/op jit_small_matmul 2.96µs ± 2% jit_big_matmul 22.1µs ±21% name time/op jit_small_matmul 2.96µs ± 2% jit_big_matmul 22.7µs ±21% PiperOrigin-RevId: 435453853
1 parent 53f52cb commit 5d7f639

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

benchmarks/api_benchmark.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,30 @@ def jit_simple(state):
118118
f(a, b).block_until_ready()
119119

120120

121+
@google_benchmark.register
122+
def jit_small_matmul(state):
123+
x = np.random.uniform(size=(2, 2)).astype(np.float32)
124+
x = jax.device_put(x)
125+
126+
f = jax.jit(lambda x: jnp.dot(x, x))
127+
f(x).block_until_ready()
128+
129+
while state:
130+
f(x).block_until_ready()
131+
132+
133+
@google_benchmark.register
134+
def jit_big_matmul(state):
135+
x = np.random.uniform(size=(100, 100)).astype(np.float32)
136+
x = jax.device_put(x)
137+
138+
f = jax.jit(lambda x: jnp.dot(x, x))
139+
f(x).block_until_ready()
140+
141+
while state:
142+
f(x).block_until_ready()
143+
144+
121145
def jit_simple_many_args_dispatch(n, state):
122146
args = [jax.device_put(i) for i in range(n)]
123147
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))

0 commit comments

Comments
 (0)