
[WIP] NaFlex support#1147

Open
mehdidc wants to merge 2 commits into mlfoundations:main from mehdidc:naflex

Conversation

@mehdidc
Contributor

@mehdidc mehdidc commented Apr 14, 2026

@rwightman first try. I think it would be cool to have this in open_clip, especially since most of the required code is already in timm.
I am running some first experiments to validate; could you have a look and
check whether anything is completely wrong in the implementation?

@rwightman @JeniaJitsev @marianna13 example usage to try it out:

...
srun python -u src/open_clip_train/main.py \
    --save-frequency 1 \
    --dataset-type webdataset --dataset-resampled \
    --train-data "..." \
    --warmup 2000 \
    --report-to=tensorboard \
    --epochs=1 \
    --workers=4 \
    --model naflex_ViT-B-16 \
    --name naflex_ViT-B-16 \
    --logs logs \
    --seed 0 \
    --ddp-static-graph \
    --local-loss \
    --gather-with-grad \
    --lr 0.001 \
    --save-most-recent \
    --precision amp_bfloat16 \
    --grad-checkpointing \
    --grad-clip-norm 1 \
    --wd 0.2 \
    --beta1 0.9 \
    --beta2 0.98 \
    --resume latest \
    --use-naflex \
    --naflex-num-train-image-tokens=$((128 * 10**6 * 196)) \
    --naflex-patch-sizes 16 \
    --naflex-seq-lens 128 256 576 784 1024 \
    --naflex-max-image-tokens-per-batch $((512 * 196)) \
    --aug-cfg use_timm=True naflex=True \
    --log-every-n-steps 1

i.e., here we train for the same token budget as a B/16 trained on 224x224 images (196 tokens per image), with a local batch size of 512 and 128M total samples seen.
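The token budget above works out as follows; here is a quick sanity check in Python that mirrors the shell arithmetic in the command:

```python
# Sanity check of the NaFlex token-budget flags in the command above.
patch_size = 16
image_size = 224
tokens_per_image = (image_size // patch_size) ** 2   # 14 x 14 = 196 for a B/16 at 224x224

samples_seen = 128 * 10**6                           # 128M total samples seen
total_train_tokens = samples_seen * tokens_per_image # --naflex-num-train-image-tokens
max_tokens_per_batch = 512 * tokens_per_image        # --naflex-max-image-tokens-per-batch

print(tokens_per_image)      # 196
print(total_train_tokens)    # 25088000000
print(max_tokens_per_batch)  # 100352
```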

- data.py: adapt timm's NaFlexDatasetWrapper as a webdataset PipelineStage to control batching
- timm_model.py and train.py: support dict input for pre-patched inputs
- train.py: update logging to handle variable batch sizes
- transform.py: use a transform factory, following timm, so transforms can be applied at different seq lens
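For the dict-input change, a minimal sketch of the dispatch in an image tower's forward is below. The key names `patches`/`patch_valid` are illustrative assumptions, not necessarily the exact keys used in this PR or in timm:

```python
# Sketch: route a pre-patched NaFlex batch (a dict from the data pipeline)
# vs. a plain image tensor through the same encode path.
def encode_image(model, x):
    if isinstance(x, dict):
        # NaFlex path: the loader already patchified the images, so pass the
        # flattened patches and their validity (padding) mask straight through.
        # (Key names are assumed for illustration.)
        return model(patches=x["patches"], patch_valid=x["patch_valid"])
    # Standard path: raw image tensor; the model patchifies internally.
    return model(x)

class DummyTower:
    """Stand-in for a vision tower that accepts either input form."""
    def __call__(self, x=None, patches=None, patch_valid=None):
        return "naflex" if patches is not None else "dense"

tower = DummyTower()
print(encode_image(tower, "image_tensor"))                                 # dense
print(encode_image(tower, {"patches": [[0.0]], "patch_valid": [[True]]}))  # naflex
```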

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a3faf5ac82

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

```diff
 aug_cfg_dict.pop('gray_scale_prob', None)

-train_transform = create_transform(
+train_transform = partial(create_transform,
```


P1: Return a callable image transform, not a factory partial

The use_timm training path now stores partial(create_transform, ...) instead of the transform object itself. In the non-NaFlex loader, each sample calls preprocess_img(img), so the image is passed as a positional argument to create_transform, conflicting with the already-bound input_size keyword and causing a runtime failure for --aug-cfg use_timm=True training.
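The failure mode the reviewer describes can be reproduced in isolation. The sketch below uses a simplified stand-in for timm's create_transform (the real signature takes input_size plus many augmentation kwargs), showing why storing a partial of the factory breaks per-sample calls:

```python
from functools import partial

# Simplified stand-in for timm's create_transform factory (signature is
# illustrative: input_size first, returning a callable transform).
def create_transform(input_size, is_training=False):
    def transform(img):
        return f"resized({img}, {input_size})"
    return transform

# Correct: call the factory once and store the returned transform;
# the loader then calls preprocess_img(img) per sample.
preprocess_img = create_transform(224)
print(preprocess_img("cat.png"))  # resized(cat.png, 224)

# Buggy: store a partial of the factory. The per-sample call passes the
# image positionally into the factory, colliding with the bound keyword.
preprocess_img = partial(create_transform, input_size=224)
try:
    preprocess_img("cat.png")
except TypeError as e:
    print(f"TypeError: {e}")  # got multiple values for argument 'input_size'
```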


