
Conversation

@EIFY (Contributor) commented on Nov 26, 2024

torch.nn.init.trunc_normal_() defaults to truncation at (a, b), not (a * std, b * std). So to conform to JAX's variance_scaling(..., distribution="truncated_normal", ...) we need to multiply the truncation bounds by std ourselves. We can see this by initializing a test model. Here is the repo's JAX ViT-S/16:

>>> import jax.numpy
>>> import jax.random
>>> from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.models import ViT
>>> from algorithmic_efficiency.workloads.imagenet_vit.workload import decode_variant
>>> vit = ViT(**decode_variant('S/16'))
>>> x = jax.numpy.zeros((1, 224, 224, 3), jax.numpy.float32)
>>> params = vit.init(jax.random.key(0), x)
>>> for w in [params['params']['conv_patch_extract']['kernel'], params['params']['pre_logits']['kernel']]:
...   print(w.min(), w.max())
...
-0.08204417 0.08203908
-0.11602508 0.116011634
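
These bounds are consistent with JAX's variance_scaling formula for distribution="truncated_normal", which truncates the underlying normal at (-2, 2) and then scales by std = sqrt(scale / fan_in) / 0.87962566103423978, so the observed range should be roughly ±2 * std. A quick sanity check, assuming fan_in = 16 * 16 * 3 = 768 for the patch-extraction conv and fan_in = 384 for pre_logits (illustrative values, not taken from the repo's code):

import math

# Expected truncation bounds under JAX-style variance_scaling (scale = 1.0, fan_in mode).
for name, fan_in in [("conv_patch_extract", 16 * 16 * 3), ("pre_logits", 384)]:
    std = math.sqrt(1.0 / fan_in) / 0.87962566103423978
    print(name, 2 * std)  # roughly 0.0820 and 0.1160, matching the min/max values above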

Here is the repo's PyTorch ViT-S/16 before the fix:

>>> from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.models import ViT
>>> from algorithmic_efficiency.workloads.imagenet_vit.workload import decode_variant
>>> vit = ViT(**decode_variant('S/16'))
>>> for w in [vit.conv_patch_extract.weight, vit.pre_logits.weight]:
...   print(w.min(), w.max())
...
tensor(-0.2119, grad_fn=<MinBackward1>) tensor(0.1907, grad_fn=<MaxBackward1>)
tensor(-0.2749, grad_fn=<MinBackward1>) tensor(0.2512, grad_fn=<MaxBackward1>)

Here is the repo's PyTorch ViT-S/16 after the fix:

>>> from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.models import ViT
>>> from algorithmic_efficiency.workloads.imagenet_vit.workload import decode_variant
>>> vit = ViT(**decode_variant('S/16'))
>>> for w in [vit.conv_patch_extract.weight, vit.pre_logits.weight]:
...   print(w.min(), w.max())
... 
tensor(-0.0820, grad_fn=<MinBackward1>) tensor(0.0820, grad_fn=<MaxBackward1>)
tensor(-0.1160, grad_fn=<MinBackward1>) tensor(0.1160, grad_fn=<MaxBackward1>)
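
For reference, a minimal sketch of what the corrected initialization amounts to (the helper name and fan-in handling are illustrative, not the repo's exact code): compute std the way JAX's variance_scaling does, then pass scaled truncation bounds instead of relying on trunc_normal_'s defaults a=-2., b=2.

import math
import torch

def jax_like_trunc_normal_(tensor, fan_in, scale=1.0):
    # Hypothetical helper mirroring variance_scaling(scale, 'fan_in', 'truncated_normal'):
    # the constant keeps the stddev of the truncated distribution at sqrt(scale / fan_in).
    std = math.sqrt(scale / fan_in) / 0.87962566103423978
    # Before the fix: trunc_normal_(tensor, std=std) truncated at the absolute defaults
    # (-2, 2) rather than (-2 * std, 2 * std), producing the overly wide ranges above.
    torch.nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-2 * std, b=2 * std)

Applied to a 384-wide pre_logits weight (fan_in = 384), this gives bounds of roughly ±0.116, matching both the JAX output and the post-fix PyTorch output.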

Affected current workloads include imagenet_vit, imagenet_resnet, fastmri, and ogbg, along with (retired? test?) workloads cifar and mnist.

I hope this bug doesn't drastically upend the results so far but I don't know 😬

@EIFY requested a review from a team as a code owner on November 26, 2024, 22:34
@github-actions (bot) commented on Nov 26, 2024

MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅

@EIFY (Contributor, Author) commented on Nov 27, 2024

recheck

@priyakasimbeg (Contributor) commented on Nov 27, 2024

Hi Jason, thanks for sending this PR in. This is a very interesting and good catch!
Did you sign the CLA with your GitHub email and username to unblock the CLA check?

It's hard to say at the moment what the final effect of this difference in initialization will be. My guess is that it probably won't upend the results, but we can double-check this.

@EIFY (Contributor, Author) commented on Nov 27, 2024

> Hi Jason, thanks for sending this PR in. This is a very interesting and good catch! Did you sign the CLA with your GitHub email and username to unblock the CLA check?

I have signed and emailed the CLA. I think the system has identified me as a signee but just hasn't rerun the pull_request_target check automatically, so I have triggered it again with a no-change amend.

Commit: torch.nn.init.trunc_normal_() defaults to truncation at (a, b), not (a * std, b * std).
@priyakasimbeg changed the base branch from main to dev on December 4, 2024, 20:11
@priyakasimbeg self-requested a review on December 12, 2024, 18:18
@priyakasimbeg merged commit fe90379 into mlcommons:dev on December 12, 2024 (33 of 36 checks passed)
@github-actions (bot) locked and limited the conversation to collaborators on December 12, 2024
@EIFY deleted the torch-init-fix branch on December 12, 2024, 20:12