I have a large JAX codebase that I usually run on my RTX 5080. Recently I ran it on an A100, and it ran slowly. I consulted ChatGPT, and it told me I should use … Is there a way to get my JAX project to use …
A similar question was asked here: #30106. I don't know of any way to force JAX to default to bfloat16, short of explicitly adding … My suggestion would be to write your functions so that they always respect the dtype of their inputs. Most JAX APIs already do this, and it lets you switch between bfloat16 and float32 at runtime without changing the source code.
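Here is a minimal sketch of that suggestion, assuming a simple dense layer. The hypothetical `mlp_layer` casts its parameters to the dtype of its input `x`, so the caller picks the precision just by picking the input dtype. The function name and shapes are illustrative and do not come from the original thread.

```python
import jax
import jax.numpy as jnp

def mlp_layer(x, w, b):
    # Hypothetical layer: cast parameters to the input's dtype so the
    # caller controls precision (float32 or bfloat16) through `x` alone.
    w = w.astype(x.dtype)
    b = b.astype(x.dtype)
    return jax.nn.relu(x @ w + b)

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x32 = jax.random.normal(key_x, (4, 8), dtype=jnp.float32)
w = jax.random.normal(key_w, (8, 16), dtype=jnp.float32)
b = jnp.zeros((16,), dtype=jnp.float32)

layer = jax.jit(mlp_layer)
out32 = layer(x32, w, b)                      # computed in float32
out16 = layer(x32.astype(jnp.bfloat16), w, b)  # same code, computed in bfloat16
print(out32.dtype, out16.dtype)
```

Because `jax.jit` traces separately for each input dtype, the same source yields a float32 and a bfloat16 compiled version, with no global flag involved.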
Good to know – that's something we've been thinking about and may add soon.
This would never have worked – it's just an LLM hallucination.
This makes sense, but in general only certain operations (typically matmuls) benefit from computation in bfloat16, and the precise speedup depends on the hardware you're using. You may be able to achieve the …
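One common mixed-precision pattern fits this reply's point, though the message above is truncated, so this is not necessarily what it goes on to suggest. The idea is to keep parameters and elementwise work in float32 and cast only the matmul operands to bfloat16, requesting a float32 result through the `preferred_element_type` argument of `jnp.dot`. `dense` and `layer` are hypothetical names. Whether this is actually faster depends on the hardware and problem sizes.

```python
import jax
import jax.numpy as jnp

def dense(x, w):
    # Hypothetical helper: only the matmul operands are cast to bfloat16;
    # the product is requested in float32 via preferred_element_type.
    return jnp.dot(
        x.astype(jnp.bfloat16),
        w.astype(jnp.bfloat16),
        preferred_element_type=jnp.float32,
    )

@jax.jit
def layer(x, w, b):
    # Bias add and activation stay in float32.
    return jax.nn.gelu(dense(x, w) + b)

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 8), dtype=jnp.float32)
w = jax.random.normal(key_w, (8, 16), dtype=jnp.float32)
b = jnp.zeros((16,), dtype=jnp.float32)

y = layer(x, w, b)
y_ref = jax.nn.gelu(x @ w + b)  # all-float32 reference for comparison
print(y.dtype, float(jnp.max(jnp.abs(y - y_ref))))
```

The result stays float32, so downstream code is unaffected. The small deviation from `y_ref` comes from rounding the matmul inputs to bfloat16.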