@balancap commented Aug 14, 2024

In this tutorial notebook, we investigate how the JAX + XLA ML stack handles the specifics of FP8 matmuls while still generating a single, optimally fused kernel call, including:

  • FP8 input scaling;
  • FP8 output scaling & clamping;
  • non-linearity & bias fusing;
  • abs-max output capture.

Note: some open questions remain on bias and GELU fusing. A sketch of the core pattern follows below.
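For orientation, here is a minimal sketch of the scaled-matmul pattern the notebook studies, written in plain JAX. It is illustrative only: the helper name `fp8_matmul`, the constant `E4M3_MAX`, and the per-tensor scale arguments are assumptions, not the notebook's actual API.

```python
import jax
import jax.numpy as jnp

# Max representable magnitude of float8_e4m3fn (assumed constant name).
E4M3_MAX = 448.0

def fp8_matmul(x_fp8, w_fp8, x_scale, w_scale, out_scale):
    """Hypothetical FP8 matmul pattern that XLA can rewrite into one fused GEMM."""
    # Dequantize inputs: upcast and apply per-tensor scales.
    x = x_fp8.astype(jnp.bfloat16) * x_scale
    w = w_fp8.astype(jnp.bfloat16) * w_scale
    # Higher-precision matmul; XLA pattern-matches this together with the
    # surrounding scale/clamp/convert ops when targeting FP8-capable GPUs.
    out = x @ w
    # Capture the abs-max of the output, e.g. for delayed-scaling updates.
    out_absmax = jnp.max(jnp.abs(out))
    # Rescale, saturate to the FP8 range, and downcast the output.
    out_fp8 = jnp.clip(out / out_scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    return out_fp8, out_absmax

# Example usage with dummy data and unit scales.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (128, 256)).astype(jnp.float8_e4m3fn)
w = jax.random.normal(key, (256, 512)).astype(jnp.float8_e4m3fn)
y_fp8, y_absmax = jax.jit(fp8_matmul)(x, w, 1.0, 1.0, 1.0)
```

On FP8-capable hardware, inspecting the compiled HLO (e.g. via `jax.jit(fp8_matmul).lower(x, w, 1.0, 1.0, 1.0).compile().as_text()`) should show this whole pattern collapsed into a single cuBLASLt custom call rather than separate convert/multiply/dot/clamp ops.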

@balancap added the `documentation` label on Aug 14, 2024
@balancap self-assigned this on Aug 14, 2024
@balancap marked this pull request as draft on August 14, 2024 at 15:31
@balancap force-pushed the fp8-matmul-fusion-tutorial branch from 4a84fcf to 1bbe9e6 on August 15, 2024 at 08:36
@lyprince force-pushed the fp8-matmul-fusion-tutorial branch 2 times, most recently from 9fef477 to d7713cb on September 24, 2024 at 10:41
@lyprince force-pushed the fp8-matmul-fusion-tutorial branch from d7713cb to a203df7 on September 25, 2024 at 15:55
@balancap marked this pull request as ready for review on September 25, 2024 at 15:55
@balancap merged commit 4dda0a3 into main on Sep 25, 2024 (2 checks passed)
@balancap deleted the fp8-matmul-fusion-tutorial branch on September 25, 2024 at 15:56