-
Notifications
You must be signed in to change notification settings - Fork 265
[PT] BatchNorm adaptation #3726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
[PT] BatchNorm adaptation #3726
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR implements BatchNorm adaptation functionality for PyTorch models after pruning. The feature allows updating BatchNorm layer statistics using a calibration dataset, which can improve model accuracy post-pruning without full fine-tuning.
Key Changes:
- Added
batch_norm_adaptationfunction to adapt BatchNorm statistics after pruning - Introduced
set_batchnorm_train_onlycontext manager to selectively set only BatchNorm layers to training mode - Updated type hints to use TypeVar for better type preservation in pruning functions
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| src/nncf/torch/function_hook/pruning/batch_norm_adaptation.py | New module implementing BatchNorm adaptation logic with context manager |
| src/nncf/pruning/prune_model.py | Added public API for batch_norm_adaptation with backend routing |
| src/nncf/init.py | Exported batch_norm_adaptation function in public API |
| tests/torch2/function_hook/pruning/test_bn_adaptation.py | Added comprehensive tests for BatchNorm adaptation functionality |
| examples/pruning/torch/resnet18/main.py | Integrated bn_adaptation mode into ResNet18 pruning example |
| examples/pruning/torch/resnet18/README.md | Updated documentation with bn_adaptation usage instructions |
| src/nncf/torch/function_hook/pruning/prune_model.py | Updated type hints using TypeVar for type preservation |
| src/nncf/torch/function_hook/pruning/magnitude/algo.py | Updated type hints using TypeVar for type preservation |
| src/nncf/torch/function_hook/pruning/rb/algo.py | Updated type hints using TypeVar for type preservation |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| "--mode", | ||
| type=str, | ||
| choices=["magnitude", "rb"], | ||
| choices=["magnitude", "bn_adaptation", "rb"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to give users choice to use bn adaptation regardless of the pruning algorithm? I mean, a --bn_adaptation option would be simple IMHO
Especially because in fact there is a choice: finetune the pruned model or just apply bn adaptation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’m not sure I see a reason to do that.
What pipeline for bn adaptation you suggest for rb pruning?
| if isinstance(input_data, dict): | ||
| model(**input_data) | ||
| elif isinstance(input_data, tuple): | ||
| model(*input_data) | ||
| else: | ||
| model(input_data) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we reuse PTEngine here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, in PT Engine used model.eval()
Changes
Add
nncf.batch_norm_adaptationfunction.Add mode
bn_adaptaionin exampleRelated tickets
174483
Tests
https://github.com/openvinotoolkit/nncf/actions/runs/19162855749/job/54776694802