Compare microbatch forward outputs and gradients #246
base: main
Conversation
stack-info: PR: #246, branch: xmfan/stack/20
Granted the RNG affects the grads, why does the diff show 'none' rather than a different hash?
if rng_seed is not None:
    numerics_logger = NumericsLogger(logs_dir)
with AutoParallel(
    model, input_fn, mesh, dynamic=True, numerics_logger=None
should this be numerics_logger = numerics_logger?
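For reference, a sketch of the change that comment seems to be suggesting, reusing the names from the quoted diff (the `autop` binding is purely illustrative, not the PR's actual variable):

# Sketch: pass the freshly constructed logger through instead of None,
# so AutoParallel actually records numerics when rng_seed is set.
numerics_logger = None
if rng_seed is not None:
    numerics_logger = NumericsLogger(logs_dir)
with AutoParallel(
    model, input_fn, mesh, dynamic=True, numerics_logger=numerics_logger
) as autop:  # illustrative binding
    ...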
return

rank = torch.distributed.get_rank()
if rank == 4:
Can you somehow not hardcode this?
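One possible way to avoid the hardcoded rank, as a rough sketch; it assumes `mesh` is a torch DeviceMesh with a "pp" dimension, and the names are illustrative rather than the PR's actual code:

import torch

# Sketch: compute the global rank that owns the last pipeline stage instead of
# hardcoding rank 4. Assumes `mesh` is a DeviceMesh with a "pp" dimension.
pp_mesh = mesh["pp"]
last_stage_rank = int(pp_mesh.mesh.flatten()[-1])
if torch.distributed.get_rank() == last_stage_rank:
    ...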
action: _Action,
ctx: _PipelineContext,
numerics_logs: Optional[list[str]] = None,
forward_hook: Callable | None = None,
nit: Optional[Callable]
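A sketch of the spelling the nit is asking for, matching the Optional[...] style already used by numerics_logs (the wrapper function name here is made up; only the parameter annotation is the point):

from typing import Callable, Optional

def _example(  # hypothetical wrapper; _Action/_PipelineContext come from the quoted diff
    numerics_logs: Optional[list[str]] = None,
    forward_hook: Optional[Callable] = None,
):
    ...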
if self.rank == 0:
    print(f"Weight hashes written to {path}")

def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, ranks):
What is num_world_stages?
rank = torch.distributed.get_rank()
if rank == 4:
    numerics_logger.log_diff(
        output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
Yeah, very confusing. Also, do we care about the pp rank or the global rank? Finally, V-style schedules will have the last stage on rank 0, won't they?
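A hedged sketch of the gating that comment points toward: key the logging off "this rank owns the last pipeline stage" instead of a fixed global rank, which would also cover V-style schedules. The `stages` list and `is_last` attribute are assumptions, not the PR's actual API:

# Sketch: log from whichever rank holds the last stage, rather than rank 4.
if any(stage.is_last for stage in stages):  # `stages`/`is_last` assumed available here
    numerics_logger.log_diff(
        output,
        rank=torch.distributed.get_rank(),
        prefix=f"mb{action.microbatch_index} fwd out",
    )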
If we land #250 first, it fixes the grad issue.
There was a bug in gradient accumulation that is fixed by #250.
Stacked PRs:
Currently, the forward matches per microbatch (no batch invariance).
But for the backward, all grads are None.
Intended usage:
Currently, the forward inputs are the same, but the forward is being run with different RNG state between the two setups, so there are some numerical differences.
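To make that last point concrete, a minimal sketch of pinning the RNG state so both setups run the forward with identical randomness (the helper name is made up for illustration):

import torch

def run_forward_with_fixed_rng(model, inputs, rng_seed):
    # Seed all RNGs before the forward so any remaining differences between the
    # two setups come from numerics, not from divergent random state.
    torch.manual_seed(rng_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rng_seed)
    return model(*inputs)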