
Issue with sharding the model and data batches #222

@heydaari

Description


I have tried to change the single-processor ViT training to data-parallel training on Kaggle TPUs with a (8, 1) sharding mesh.
I used the mechanism introduced in the Train a MiniGPT notebook from jax-ai-stack, but I'm facing some issues in sharding the batches (and possibly further issues in combining the sharded batches with the model).

This is the gist of the notebook and the errors (note that I skip the weight conversion from HF):
https://gist.github.com/heydaari/854e00c28f57806f0f7ac0818f013bbd
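For context, here is a minimal sketch of the data-parallel setup described above: an (8, 1) device mesh, batches sharded along the leading axis, and parameters replicated. The batch shape and the `"data"`/`"model"` axis names are assumptions for illustration, not taken from the gist; the `XLA_FLAGS` line only simulates 8 devices on CPU so the sketch runs anywhere (on Kaggle TPUs `jax.devices()` already returns 8 cores).

```python
import os
# Simulate 8 devices on a CPU host so this sketch is runnable anywhere.
# Unnecessary on an actual 8-core TPU runtime.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# (8, 1) mesh: 8-way data parallelism, no model parallelism.
mesh = Mesh(mesh_utils.create_device_mesh((8, 1)),
            axis_names=("data", "model"))

# Hypothetical ViT batch: (batch, tokens, embed_dim).
batch = jnp.ones((32, 197, 768))

# Split the leading batch axis across the "data" mesh axis;
# the remaining axes stay unsharded.
data_sharding = NamedSharding(mesh, P("data"))
sharded_batch = jax.device_put(batch, data_sharding)

# Parameters use an empty PartitionSpec, i.e. fully replicated,
# which is the usual pairing for pure data parallelism.
replicated = NamedSharding(mesh, P())
params = jax.device_put(jnp.ones((768, 768)), replicated)

# Each of the 8 devices now holds a (4, 197, 768) slice of the batch.
print(sharded_batch.addressable_shards[0].data.shape)  # (4, 197, 768)
```

With this layout, a `jax.jit`-compiled train step that receives `sharded_batch` and `params` will run data-parallel automatically, with gradients reduced across the `"data"` axis.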
