> I think that making bfloat16 the default dtype for floating point numbers would achieve most of what I want.

Good to know – that's something we've been thinking about and may add soon.
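Until such an option exists, one workaround is to pass the dtype explicitly wherever arrays are created. The snippet below is a minimal sketch of that idea; `DTYPE` is a hypothetical module-level constant introduced here for illustration, not a JAX setting.

```python
import jax.numpy as jnp

# Hypothetical project-wide constant; JAX itself does not read this.
DTYPE = jnp.bfloat16

# Request bfloat16 explicitly at array creation time.
x = jnp.ones((4, 4), dtype=DTYPE)
w = jnp.linspace(0.0, 1.0, 4, dtype=DTYPE)

# Both operands are bfloat16, so the result stays bfloat16.
y = x @ w
print(x.dtype, y.dtype, y.shape)
```

Keeping the dtype in one constant makes it easy to switch back to `jnp.float32` when comparing accuracy or speed.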

> and `config.update("jax_default_dtype_bits", 16)` which apparently doesn't work anymore.

This would never have worked – it's just an LLM hallucination.

> once I switched it to bfloat16, the runtime on the A100 was 2-3x faster than the 5080, which is the kind of speedup I expected.

This makes sense, but in general it's only certain operations (typically matmuls) that will benefit from computation in bfloat16, and the precise speedups will depend on the hardware you're using. You may be able to achieve the …
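As an illustration of limiting bfloat16 to the matmul itself, here is a minimal sketch that casts only the matmul operands and asks for a float32 result. The helper name `matmul_bf16` and the array sizes are assumptions made for this example, and whether it is actually faster depends on the hardware.

```python
import jax
import jax.numpy as jnp
from jax import lax


def matmul_bf16(a, b):
    """Multiply two 2D arrays with bfloat16 inputs and a float32 result."""
    a16 = a.astype(jnp.bfloat16)
    b16 = b.astype(jnp.bfloat16)
    # preferred_element_type requests float32 for the accumulated output.
    return lax.dot(a16, b16, preferred_element_type=jnp.float32)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 256), dtype=jnp.float32)
b = jax.random.normal(key_b, (256, 256), dtype=jnp.float32)

out = jax.jit(matmul_bf16)(a, b)  # bfloat16 operands, float32 output
ref = a @ b                       # plain float32 matmul for comparison

# The gap reflects rounding of the operands to bfloat16.
print(out.dtype, float(jnp.max(jnp.abs(out - ref))))
```

The rest of the computation can stay in float32 this way, so only the operation most likely to benefit pays the precision cost.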

Answer selected by cool-RR