[5725362] AutoCast Fixes for models with external data #731
base: main
Conversation
```python
has_external_data = any(
    init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
    for init in self.model.graph.initializer
)
```
Can we reuse this function to check for external data?
```python
def has_external_data(onnx_model_path: str):
```
@ajrasane I'm not sure I want to introduce dependencies from modelopt.torch here just for this.
Since modelopt/torch/_deploy/utils/torch_onnx.py is already importing quite a few utils from modelopt.onnx.utils, how about I move this function to modelopt.onnx.utils and import it in modelopt.torch?
Done
@ajrasane please revisit - since I edited torch utils, I now need a review from modelopt-torch-deploy-codeowners 🙏
Signed-off-by: Gal Hubara Agam <[email protected]>
Signed-off-by: Gal Hubara Agam <[email protected]>
Signed-off-by: Gal Hubara Agam <[email protected]>
Force-pushed from 1a0711c to 7a2d91a
📝 Walkthrough

The changes enhance ONNX model handling with external data support by introducing detection and branching logic in ReferenceRunner, adding model size logging to PrecisionConverter, and updating utils to support external data workflows with modified validation and return signatures.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Runner as ReferenceRunner
    participant GetRunner as _get_ort_runner()
    participant Model as ONNX Model
    participant FS as File System
    participant ORT as ONNXRuntime

    Runner->>GetRunner: run(model)
    GetRunner->>Model: Check initializers for external data
    alt External Data Detected or Size > 2GB
        GetRunner->>FS: Create temp .onnx file
        GetRunner->>Model: Save with external data enabled
        GetRunner->>FS: Write temp ONNX file
        GetRunner->>ORT: Create InferenceSession from file
        ORT-->>GetRunner: Session ready
    else In-Memory Path
        GetRunner->>ORT: Create Session via BytesFromOnnx
        ORT-->>GetRunner: Session ready
    end
    GetRunner-->>Runner: Return OnnxrtRunner(session)
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks: ✅ 2 passed, ❌ 1 failed (warning)
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@modelopt/onnx/autocast/precisionconverter.py`:
- Around line 85-88: Remove the temporary debug method print_byte_size: it uses
print() with a "GAGAM" prefix and lacks a docstring; either delete this method
from the class or replace it with a properly documented utility that uses the
module logger (e.g., logging.getLogger(__name__)) and a D102-compliant
docstring, and if keeping functionality, rename to something descriptive (e.g.,
log_model_byte_size) and call self.model.SerializeToString() to compute size
before logging the result instead of printing.
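A logger-based replacement along the lines this comment suggests could look like the sketch below. The name `log_model_byte_size` is the hypothetical one proposed above, and any object exposing protobuf's `SerializeToString()` works:

```python
import logging

logger = logging.getLogger(__name__)


def log_model_byte_size(model) -> int:
    """Log and return the serialized byte size of an ONNX ModelProto."""
    size = len(model.SerializeToString())
    logger.debug("Model size: %d bytes", size)
    return size
```

Returning the size lets callers reuse it for the 2GB threshold check instead of serializing the model a second time.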
- Around line 1298-1303: Remove the debug instrumentation by deleting the calls
to self.print_byte_size(...) inside the _sanity_check method and remove the
print_byte_size method definition entirely; ensure _sanity_check only performs
onnx_utils.check_model(self.model) (and any existing sanity_ok handling) without
emitting byte-size prints, and run tests to confirm no remaining references to
print_byte_size remain in the class.
In `@modelopt/onnx/autocast/referencerunner.py`:
- Around line 160-165: The temp ONNX file created with
tempfile.NamedTemporaryFile(delete=False) and passed to
onnx_utils.save_onnx/InferenceSession is never removed; store the path (e.g.,
self._temp_model_path) and ensure cleanup once the session is no longer needed
by either creating the temp inside a TemporaryDirectory context (mirroring
utils.check_model) that you keep open for the session lifetime or add a
dedicated cleanup method (e.g., close()/cleanup_temp_model) that deletes the
file and call it when the runner is disposed; optionally register that cleanup
with atexit as a fallback to avoid accumulating files.
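One way to keep the temp file alive for the session's lifetime, as this comment suggests, is a small holder that owns a `TemporaryDirectory` and registers an `atexit` fallback. This is a sketch only; `TempModelHolder` is a hypothetical name, not code from this PR:

```python
import atexit
import os
import tempfile


class TempModelHolder:
    """Owns a temp directory for a saved model and guarantees cleanup."""

    def __init__(self):
        self._tmpdir = tempfile.TemporaryDirectory()
        # Fallback: clean up at interpreter exit if close() is never called.
        atexit.register(self._tmpdir.cleanup)

    @property
    def model_path(self) -> str:
        # Path where the caller should save the ONNX model (plus its
        # external data files, which land in the same directory).
        return os.path.join(self._tmpdir.name, "model.onnx")

    def close(self):
        # Safe to call more than once; the atexit fallback becomes a no-op.
        self._tmpdir.cleanup()
```

Holding the directory (rather than a lone `NamedTemporaryFile(delete=False)`) also removes the external-data sidecar files, which the current code would leave behind.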
In `@modelopt/onnx/utils.py`:
- Line 661: The line force-setting save_as_external_data = True is a temporary
debug override that bypasses the prior size-based logic; remove this hardcoded
assignment (the "save_as_external_data = True # GAGAM: for debug" statement) so
the function/method that computes save_as_external_data earlier in the scope can
determine the value normally, and ensure no other debug-only overrides remain in
the same function or surrounding block.
🧹 Nitpick comments (3)
modelopt/onnx/utils.py (1)

647-649: Consider reverting the log level from `warning` to `debug`.

This message logs model size and external data usage for every save operation. Using the `warning` level will generate noise in production logs for what is essentially informational/diagnostic output. The original `debug` level was more appropriate unless there's a specific reason to always surface this information.

Suggested change:

```diff
- logger.warning(
+ logger.debug(
      f"Model size: {model_size} bytes, using external data: {save_as_external_data}"
  )
```

modelopt/onnx/autocast/referencerunner.py (2)
131-152: Consider extracting external data detection to a shared utility.

This logic duplicates functionality that exists in `modelopt/torch/_deploy/utils/torch_onnx.py` (`has_external_data` and `check_model_uses_external_data`). Per a past review comment, consider moving the detection logic to `modelopt/onnx/utils.py` so it can be reused across the codebase. The current implementation is correct, but consolidating would improve maintainability.
166-166: Lambda closure captures `session` correctly, but consider clarity.

The lambda `lambda: session` works because `session` is captured by reference. However, this pattern can be confusing. `OnnxrtRunner` expects a callable that returns a session; passing a prebuilt session via a lambda is acceptable but slightly unusual.
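If the lambda stays, one way to make the intent explicit is a tiny factory that binds the already-built session. The name `make_session_factory` is illustrative, not part of the codebase:

```python
def make_session_factory(session):
    # OnnxrtRunner expects a zero-argument callable that returns a session;
    # returning a closure over the prebuilt session makes the binding explicit
    # and keeps the session alive for as long as the factory is referenced.
    return lambda: session
```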
Signed-off-by: Gal Hubara Agam <[email protected]>
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #731       +/-   ##
==========================================
- Coverage   74.19%   74.15%    -0.04%
==========================================
  Files         192      191        -1
  Lines       19238    19258       +20
==========================================
+ Hits        14273    14281        +8
- Misses       4965     4977       +12
==========================================
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Gal Hubara Agam <[email protected]>
gcunhase left a comment:
LGTM, thanks
Signed-off-by: Gal Hubara Agam <[email protected]>
Force-pushed from 2ae99d0 to 00ea80c
```python
tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
tmp_file.close()
tmp_file_path = tmp_file.name
onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
logger.debug(f"Model with all outputs saved to {tmp_file_path}")
session = ort.InferenceSession(tmp_file_path, providers=self.providers)
```
FYI Polygraphy's SaveOnnx can handle models with external data. Also SessionFromOnnx can accept paths.
Thanks @pranavm-nvidia. I'll take a look and see if I can refactor.
If it's a quick fix for you, feel free to push a commit to this PR and I'll review.
Left a suggestion with the change. One thing I'm not sure about: does your onnx_utils.save_onnx do anything special besides saving the model? On quick inspection, it seems like it's also setting a custom IR version. If that's still required, you'll probably need to add a line like `modified_model.ir_version = 10` prior to calling Polygraphy's save_onnx.
```python
try:
    # Try to estimate size by serializing the model
    # If it fails or exceeds 2GB, we need file-based approach
    model_size = len(self.model.SerializeToString())
```
Suggested change:

```diff
- model_size = len(self.model.SerializeToString())
+ model_size = model.ByteSize()
```
```python
def _get_ort_runner(self, model):
    import onnxruntime as ort
    from polygraphy.backend.onnx import BytesFromOnnx
```
Suggested change:

```diff
- from polygraphy.backend.onnx import BytesFromOnnx
+ from polygraphy.backend.onnx import BytesFromOnnx, save_onnx
```
```python
if has_external_data:
    logger.debug("Model has external data, using file-based approach")
    # Get the actual ONNX ModelProto from ModifyOutputs wrapper
    modified_model = model()

    # Use a persistent temp file to handle external data files properly
    tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
    tmp_file.close()
    tmp_file_path = tmp_file.name
    onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
    logger.debug(f"Model with all outputs saved to {tmp_file_path}")
    session = ort.InferenceSession(tmp_file_path, providers=self.providers)
    runners = [OnnxrtRunner(lambda: session)]

else:
    # For models without external data, use the original BytesFromOnnx approach (no tmp files)
    logger.debug("Model has no external data, using BytesFromOnnx approach")
    serialize_onnx = BytesFromOnnx(model)
    build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
    runners = [OnnxrtRunner(build_onnxrt_session)]
```
Suggested change:

```diff
  if has_external_data:
      logger.debug("Model has external data, using file-based approach")
      # Get the actual ONNX ModelProto from ModifyOutputs wrapper
      modified_model = model()
      # Use a persistent temp file to handle external data files properly
-     tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
-     tmp_file.close()
-     tmp_file_path = tmp_file.name
-     onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
-     logger.debug(f"Model with all outputs saved to {tmp_file_path}")
-     session = ort.InferenceSession(tmp_file_path, providers=self.providers)
-     runners = [OnnxrtRunner(lambda: session)]
+     outdir = tempfile.TemporaryDirectory()
+     tmp_file_path = os.path.join(outdir.name, "tmp_model.onnx")
+     save_onnx(modified_model, tmp_file_path, external_data_path="ext.data")
+     logger.debug(f"Model with all outputs saved to {tmp_file_path}")
+     build_onnxrt_session = SessionFromOnnx(tmp_file_path, providers=self.providers)
  else:
      # For models without external data, use the original BytesFromOnnx approach (no tmp files)
      logger.debug("Model has no external data, using BytesFromOnnx approach")
      serialize_onnx = BytesFromOnnx(model)
      build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
  runners = [OnnxrtRunner(build_onnxrt_session)]
```
What does this PR do?
Type of change: Bug fix
Overview: Fix AutoCast ReferenceRunner to handle large models.
Models above 2GB cannot be serialized to a string, which is what Polygraphy does under the hood. Use a temporary file instead to save the modified ONNX model with all tensors marked as outputs.
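The size cutoff driving this fix can be sketched as a pure decision helper. The names and exact predicate below are illustrative; the PR's actual branch also inspects the graph's initializers for external data:

```python
TWO_GB = 2 * 1024**3  # protobuf's hard limit on a single serialized message


def needs_file_based_session(model_byte_size: int, has_external_data: bool) -> bool:
    # In-memory serialization (the BytesFromOnnx path) fails at or above
    # 2 GiB, so fall back to saving the model to disk with external data
    # and creating the InferenceSession from the file path instead.
    return has_external_data or model_byte_size >= TWO_GB
```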