This repository hosts LLM tutorials based on JAX.
- 01.miniGPT: build a miniGPT model from scratch and pretrain it on the TinyStories dataset
- 02.GPT2 pretraining: pretrain 124M and 354M GPT2 models on the OpenWebText dataset (inspired by nanoGPT)
- 03.GPT2 instruction tuning: instruction-tune the 124M GPT2 pretrained above, as well as the checkpoint from Hugging Face
- 04.GPT2_LoRA: use LoRA to instruction-tune the 124M pretrained GPT2 (see the sketch after this list)
- 05.GPT2 DPO: use Direct Preference Optimization to align the 124M pretrained GPT2
- 06.Loading the Llama 3.2 1B model: load the pretrained Llama 3.2 1B model from Hugging Face and run inference
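
For orientation, here is a minimal sketch of the LoRA idea used in tutorial 04, written in plain JAX. This is not the tutorial's actual code: the function names, shapes, and hyperparameters (`rank`, `alpha`) are illustrative assumptions. The key point is that the pretrained weight stays frozen and only the low-rank factors are trained.

```python
# Illustrative LoRA sketch (hypothetical names/shapes, not the tutorial's code):
# the pretrained weight W is frozen and a low-rank update B @ A is learned instead.
import jax
import jax.numpy as jnp

def init_lora_params(key, d_in, d_out, rank=8):
    # Standard LoRA init: A is small random, B starts at zero so the
    # adapted layer initially matches the frozen pretrained layer.
    return {
        "A": jax.random.normal(key, (rank, d_in)) * 0.01,  # trainable
        "B": jnp.zeros((d_out, rank)),                      # trainable
    }

def lora_dense(frozen_W, lora_params, x, alpha=16.0, rank=8):
    # y = x @ W^T + (alpha / rank) * x @ A^T @ B^T
    base = x @ frozen_W.T
    update = (x @ lora_params["A"].T) @ lora_params["B"].T
    return base + (alpha / rank) * update

# Usage example: only `lora_params` would be passed to the optimizer;
# `frozen_W` stays fixed, so the trainable parameter count drops sharply.
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (256, 128))            # pretrained, frozen
params = init_lora_params(key, d_in=128, d_out=256, rank=8)
x = jnp.ones((4, 128))
y = lora_dense(W, params, x)
print(y.shape)  # (4, 256)
```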