Conversation

Contributor

@rka97 rka97 commented Apr 3, 2025

This is for the LM workload.

…JAX and PyTorch, also unify initialization to be the same in both
Contributor

@priyakasimbeg priyakasimbeg left a comment


Second round of small requested changes.

Perhaps something we should discuss: we need a more descriptive name for the workload, e.g. fineweb_edu_lm. What do you all think? @Niccolo-Ajroldi @rka97

@rka97 rka97 force-pushed the lm_workload branch 2 times, most recently from 21d9b3d to ae4fc8d on November 30, 2025 05:48
rka97 and others added 6 commits December 1, 2025 06:49
- Introduced DTYPE enum to standardize data types (FLOAT32, FLOAT16, BFLOAT16) for JAX and PyTorch.
- Updated input pipelines and model definitions in CIFAR and ImageNet workloads to utilize mixed precision.
- Implemented casting policies for parameters and inputs using jmp and torch.autocast.
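The DTYPE enum described in the commits above could be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the PR's actual code: the names `to_jax_dtype` and `to_torch_dtype` are hypothetical, and the frameworks are imported lazily here so the enum itself has no hard dependency on JAX or PyTorch.

```python
import enum


class DTYPE(enum.Enum):
    """Framework-agnostic dtype names shared by the JAX and PyTorch workloads.

    Hypothetical sketch; the real PR's enum may differ in names and location.
    """
    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"


def to_jax_dtype(dtype: DTYPE):
    """Resolve a DTYPE to the matching jax.numpy dtype object.

    The lazy import keeps the enum usable in PyTorch-only environments.
    """
    import jax.numpy as jnp
    return getattr(jnp, dtype.value)  # e.g. jnp.bfloat16


def to_torch_dtype(dtype: DTYPE):
    """Resolve a DTYPE to the matching torch dtype object."""
    import torch
    return getattr(torch, dtype.value)  # e.g. torch.bfloat16
```

With a shared enum like this, the casting policies mentioned above (jmp on the JAX side, torch.autocast on the PyTorch side) can both be configured from the same workload-level dtype setting, which is presumably the point of standardizing the three precisions across frameworks.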
5 participants