I wonder if the examples here might point you in the right direction: https://github.com/jax-ml/jax-llm-examples. These are minimal implementations of several well-known LLM architectures written in pure JAX, without any framework on top.
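To give a rough idea of what the RL part can look like without a framework, here is a minimal sketch of a REINFORCE-style policy-gradient loop in pure JAX. It is not taken from the repository above. Everything in it is an assumption made for illustration: the "language model" is a toy bigram logit table, `reward_fn` is a hypothetical stand-in for a learned reward model, and the sizes and learning rate are arbitrary. A real RLHF setup would swap in a transformer policy, a trained reward model, and usually a KL penalty against a reference model.

```python
import jax
import jax.numpy as jnp

# Toy sizes and hyperparameters (all illustrative assumptions).
VOCAB = 8
SEQ_LEN = 6
BATCH = 64
LR = 1.0
STEPS = 300

def init_params(key):
    # Bigram policy: row i holds the next-token logits after token i.
    return {"logits": 0.01 * jax.random.normal(key, (VOCAB, VOCAB))}

def sample_sequences(params, key):
    """Sample BATCH sequences autoregressively, starting from token 0."""
    def step(prev, k):
        nxt = jax.random.categorical(k, params["logits"][prev])
        return nxt, nxt
    start = jnp.zeros((BATCH,), dtype=jnp.int32)
    keys = jax.random.split(key, SEQ_LEN)
    _, toks = jax.lax.scan(step, start, keys)
    return toks.T  # shape (BATCH, SEQ_LEN)

def reward_fn(tokens):
    # Hypothetical reward: fraction of tokens equal to 3.
    # In RLHF this would be the output of a reward model.
    return jnp.mean((tokens == 3).astype(jnp.float32), axis=1)

def sequence_log_prob(params, tokens):
    """Log-probability of each sampled sequence under the current policy."""
    start = jnp.zeros((tokens.shape[0], 1), dtype=tokens.dtype)
    prev = jnp.concatenate([start, tokens[:, :-1]], axis=1)
    logp = jax.nn.log_softmax(params["logits"][prev], axis=-1)
    picked = jnp.take_along_axis(logp, tokens[..., None], axis=-1)[..., 0]
    return picked.sum(axis=1)

def loss_fn(params, tokens, rewards):
    # REINFORCE with a batch-mean baseline to reduce variance.
    advantages = rewards - rewards.mean()
    return -jnp.mean(advantages * sequence_log_prob(params, tokens))

@jax.jit
def train_step(params, key):
    tokens = sample_sequences(params, key)
    rewards = reward_fn(tokens)
    grads = jax.grad(loss_fn)(params, tokens, rewards)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, rewards.mean()

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
params = init_params(init_key)

initial_reward = None
for i in range(STEPS):
    key, sub = jax.random.split(key)
    params, mean_reward = train_step(params, sub)
    if i == 0:
        initial_reward = float(mean_reward)
final_reward = float(mean_reward)

print("devices:", jax.devices())
print("mean reward, first step:", initial_reward)
print("mean reward, last step:", final_reward)
```

The same script runs unchanged on CPU or on a local GPU, depending on which JAX build is installed; `jax.devices()` shows which backend was picked up.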
Hi everyone!
I'm working on a research project on RLHF training for LLMs, and I'm looking for a JAX-based library that lets me do that. I found around 10 of them, the most promising being MaxText and AXLearn; however, both are quite complicated to set up. They are meant for large-scale training of big LLMs on cloud hardware, while I'd prefer a dead-simple library that runs everything locally on my NVIDIA GPU. I'm okay with the language model being much smaller and more modest.
Can anyone recommend such a library?
Thanks,
Ram.