Roadmap for Quantization in NNX #4655
liamclarkza asked this question in Q&A
Hi Flax team,
I'd like to start a discussion about potential quantization features for Flax NNX (including, but not limited to, FP8-based quantization). As models grow larger, quantization could offer significant benefits for both training and inference. Looking at the quantization landscape, several features might be valuable additions to NNX:
- FP8 GEMM quantization, like Linen provides in `fp8_ops`, would be useful to bring into NNX (see the first sketch after this list).
- Activation and gradient quantization would reduce the memory used for stored activations, which could be particularly effective when combined with rematerialization. I have run into some JAX limitations here: it doesn't currently seem possible to easily use different primal and tangent dtypes with transforms like `scan`, which you would want in order to keep gradients in higher precision, or to use a quantized gradient format with larger dynamic range such as FP8 E5M2. There is an experimental API that allows different primal/tangent dtypes; however, it isn't yet suitable for dropping into Flax models without rewriting a lot of code. Hopefully a simpler API will arrive in the future (a `custom_vjp` workaround is sketched at the end of this post).
- Optimizer state quantization could reduce the memory footprint of optimizers like Adam/AdamW, enabling training of larger models under the same hardware constraints (second sketch below).
- Weight quantization would be useful for both training and model deployment scenarios (third sketch below).
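To make the first item concrete, here is a minimal simulated-FP8 matmul sketch with per-tensor scaling, loosely in the spirit of Linen's `fp8_ops`. The names (`quantize_fp8`, `fp8_matmul`) are hypothetical rather than an existing NNX API, and the accumulation is upcast to float32; on FP8-capable hardware a real kernel would consume the FP8 operands directly:

```python
import jax
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def quantize_fp8(x):
    # Per-tensor scale that maps the largest |x| onto the FP8 range.
    scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-12) / E4M3_MAX
    return (x / scale).astype(jnp.float8_e4m3fn), scale

def fp8_matmul(x, w):
    x_q, sx = quantize_fp8(x)
    w_q, sw = quantize_fp8(w)
    # Upcast for the accumulation; a real FP8 GEMM would take x_q/w_q directly.
    y = jnp.matmul(x_q.astype(jnp.float32), w_q.astype(jnp.float32))
    return y * (sx * sw)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
w = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
y = fp8_matmul(x, w)  # approximates x @ w
```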
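For optimizer state quantization, one minimal approach (sketched here with hypothetical names, not an existing Optax/NNX API) is to store each moment tensor as int8 plus a per-tensor float32 scale, round-tripping through that representation on every update:

```python
import jax.numpy as jnp

def quantize_state(m):
    # Compress a moment tensor to int8 values plus a float32 scale.
    scale = jnp.maximum(jnp.max(jnp.abs(m)), 1e-12) / 127.0
    q = jnp.clip(jnp.round(m / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize_state(q, scale):
    return q.astype(jnp.float32) * scale

def momentum_update(q, scale, grad, beta1=0.9):
    # One Adam-style first-moment update through the int8 state.
    m = beta1 * dequantize_state(q, scale) + (1.0 - beta1) * grad
    return quantize_state(m)
```

In practice, per-block scales (as used in 8-bit Adam implementations) tend to behave much better than a single per-tensor scale, but the storage pattern is the same.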
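And for weight quantization, a per-output-channel int8 scheme with on-the-fly dequantization might look like the following (again a hypothetical sketch, not an NNX API):

```python
import jax
import jax.numpy as jnp

def quantize_weights(w):
    # One scale per output channel (column) of an (in, out) weight matrix.
    scale = jnp.maximum(jnp.max(jnp.abs(w), axis=0), 1e-12) / 127.0
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale

def int8_linear(x, q, scale):
    # Dequantize lazily; a fused kernel could keep the GEMM in int8.
    return x @ (q.astype(x.dtype) * scale)

w = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
q, scale = quantize_weights(w)
x = jax.random.normal(jax.random.PRNGKey(1), (4, 16))
y = int8_linear(x, q, scale)  # approximates x @ w
```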
### Quantization Approaches

Different quantization techniques offer various tradeoffs that might be worth exploring:

### Quantization Paradigms

Different use cases might benefit from different quantization approaches:
I'd like to know whether there is currently a roadmap for quantization in NNX, and which features might be supported.
Over the last few days, I have looked into activation/gradient quantization with Flax NNX and have repeatedly run into the issues described above, particularly the primal/tangent dtype limitation. It would be interesting to hear if anyone has proposed solutions for these.
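For reference, here is a minimal `jax.custom_vjp` sketch of the kind of gradient quantization I have been experimenting with: the forward pass is the identity, and the backward pass rounds the cotangent through FP8 E5M2. Note this only simulates the quantization in the original dtype; it does not actually store FP8 tangents, which is exactly the limitation mentioned above:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def quantize_grad_e5m2(x):
    return x  # identity on the forward pass

def _fwd(x):
    return x, None

def _bwd(_, g):
    # Round the incoming cotangent through FP8 E5M2 (wide dynamic range),
    # then return it in the original dtype so downstream ops are unchanged.
    return (g.astype(jnp.float8_e5m2).astype(g.dtype),)

quantize_grad_e5m2.defvjp(_fwd, _bwd)

def loss(w, x):
    h = quantize_grad_e5m2(x @ w)  # gradients through here are FP8-rounded
    return jnp.sum(jax.nn.relu(h))

w = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
x = jax.random.normal(jax.random.PRNGKey(1), (4, 16))
g = jax.grad(loss)(w, x)
```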