I have been trying to convert a single-processor ViT training script to data parallel on Kaggle TPUs with (8, 1) sharding.
I followed the approach from the "Train a MiniGPT" notebook in jax-ai-stack, but I'm running into issues when sharding the batches (and possibly also in how the batch sharding interacts with the model's sharding).
Here is a gist of the notebook and the errors (note that I skip the weight conversion from HF):
https://gist.github.com/heydaari/854e00c28f57806f0f7ac0818f013bbd
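For context, this is the batch-sharding pattern I'm trying to follow, reduced to a minimal sketch (shapes and axis names are illustrative, not from the gist; the CPU device-count flag is only there so the sketch runs without a TPU):

```python
import os
# Simulate 8 devices on CPU so this sketch runs anywhere (testing assumption only;
# on Kaggle TPUs jax.devices() already returns the 8 TPU cores).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# (8, 1) mesh: 8-way data parallelism, no model parallelism.
devices = np.array(jax.devices()[:8]).reshape(8, 1)
mesh = Mesh(devices, axis_names=("batch", "model"))

# Shard the batch along its leading dimension across the "batch" axis;
# model parameters would be replicated (P() / the "model" axis of size 1).
batch_sharding = NamedSharding(mesh, P("batch"))
batch = jnp.ones((32, 196, 768))  # (batch, tokens, embed) - illustrative ViT shapes
sharded_batch = jax.device_put(batch, batch_sharding)

# Each of the 8 devices should now hold a (4, 196, 768) slice of the batch.
print([s.data.shape for s in sharded_batch.addressable_shards])
```

My understanding is that the global batch size must be divisible by the 8-way "batch" axis for this to work, which is one of the things I suspect is going wrong in my case.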