Conversation
- data.py: adapt timm's NaFlexDatasetWrapper as a webdataset PipelineStage to control batching
- timm_model.py and train.py: support dict input for pre-patched inputs
- train.py: update logging to handle variable batch sizes
- transform.py: use a transform factory following timm, so transforms can be applied at different seq lens
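The transform-factory idea from the last bullet can be sketched roughly as below. All names here are hypothetical stand-ins; the real code delegates to timm's `create_transform`, which is only mimicked here:

```python
# Rough sketch (hypothetical names): instead of one fixed-size transform,
# build a factory that yields a transform per target sequence length.
def make_transform_factory(patch_size=16, **base_kwargs):
    def for_seq_len(seq_len):
        # A seq len of N patches maps to a square image of sqrt(N) * patch_size px,
        # e.g. 196 patches -> 224x224 for patch size 16.
        side = int(seq_len ** 0.5) * patch_size
        # Stand-in for timm's create_transform(input_size=..., **base_kwargs);
        # here we just return the config it would be built from.
        return dict(base_kwargs, input_size=(3, side, side))
    return for_seq_len

factory = make_transform_factory(is_training=True)
print(factory(196)['input_size'])   # (3, 224, 224)
print(factory(1024)['input_size'])  # (3, 512, 512)
```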
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a3faf5ac82
```diff
  aug_cfg_dict.pop('gray_scale_prob', None)

- train_transform = create_transform(
+ train_transform = partial(create_transform,
```
Return a callable image transform, not a factory partial
The `use_timm` training path now stores `partial(create_transform, ...)` instead of the transform object itself. In the non-NaFlex loader, each sample calls `preprocess_img(img)`, so the image is passed as a positional argument to `create_transform`, conflicting with the already-bound `input_size` keyword and causing a runtime failure when training with `--aug-cfg use_timm=True`.
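The failure mode described above can be reproduced with a toy stand-in for timm's `create_transform` (the function body here is a placeholder, not timm's real code):

```python
from functools import partial

def create_transform(input_size, is_training=False):
    # Toy stand-in for timm's factory: returns a transform callable.
    return lambda img: (img, input_size)

# Buggy: the loader stores the factory partial itself...
preprocess_img = partial(create_transform, input_size=224)
try:
    # ...so calling it with an image binds the image positionally to
    # input_size, clashing with the already-bound keyword.
    preprocess_img("img")
except TypeError as e:
    print(type(e).__name__)  # TypeError

# Fixed: call the factory once and store the resulting transform.
preprocess_img = create_transform(input_size=224)
print(preprocess_img("img"))  # ('img', 224)
```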
@rwightman first try, I think it would be cool to have this in open_clip, especially since most of the required code is already in timm.
I am doing some first experiments to validate; could you maybe have a look and check if there is something completely wrong in the implementation?
@rwightman @JeniaJitsev @marianna13 example usage to try it out:
```shell
... srun python -u src/open_clip_train/main.py \
    --save-frequency 1 \
    --dataset-type webdataset --dataset-resampled \
    --train-data "..." \
    --warmup 2000 \
    --report-to=tensorboard \
    --epochs=1 \
    --workers=4 \
    --model naflex_ViT-B-16 \
    --name naflex_ViT-B-16 \
    --logs logs \
    --seed 0 \
    --ddp-static-graph \
    --local-loss \
    --gather-with-grad \
    --lr 0.001 \
    --save-most-recent \
    --precision amp_bfloat16 \
    --grad-checkpointing \
    --grad-clip-norm 1 \
    --wd 0.2 \
    --beta1 0.9 \
    --beta2 0.98 \
    --resume latest \
    --use-naflex \
    --naflex-num-train-image-tokens=$((128 * 10**6 * 196)) \
    --naflex-patch-sizes 16 \
    --naflex-seq-lens 128 256 576 784 1024 \
    --naflex-max-image-tokens-per-batch $((512 * 196)) \
    --aug-cfg use_timm=True naflex=True \
    --log-every-n-steps 1
```

i.e. here we train for the same token budget as a B-16 trained on 224x224 images (so 196 tokens per image), with a local batch size of 512 and 128M total samples seen.
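The token-budget arithmetic behind those flags can be sanity-checked in a few lines (this just reproduces the shell arithmetic above, nothing more):

```python
# A ViT-B/16 on 224x224 images sees (224 / 16)^2 = 196 tokens per image.
image_size, patch_size = 224, 16
tokens_per_image = (image_size // patch_size) ** 2
print(tokens_per_image)  # 196

# --naflex-num-train-image-tokens: 128M samples seen * 196 tokens each.
total_samples = 128 * 10**6
num_train_image_tokens = total_samples * tokens_per_image
print(num_train_image_tokens)  # 25088000000

# --naflex-max-image-tokens-per-batch: local batch of 512 "224px-equivalent" images.
max_image_tokens_per_batch = 512 * tokens_per_image
print(max_image_tokens_per_batch)  # 100352
```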