Skip to content

Conversation

@zewenli98
Copy link
Collaborator

@zewenli98 zewenli98 commented Nov 24, 2025

Description

As I requested, TensorRT 10.14 added an argument trt.SerializationFlag.INCLUDE_REFIT to allow refitted engines to keep refittable. That means engines can be refitted multiple times. Based on the capability, this PR enhances the existing engine caching and refitting features as follows:

  1. To save hard disk space, engine caching will only save weight-stripped engines on disk regardless of compilation_settings.strip_engine_weights. Then, when users pull out the cached engine, it will be automatically refitted and kept refittable.
  2. Compiled TRT modules can be refitted multiple times with refit_module_weights(). e.g.:
for _ in range(3):
    trt_gm = refit_module_weights(trt_gm, exp_program)
  1. Due to some changes, the insertion and pulling of cached engines are located in different places, which causes 🐛 [Bug] Engine cache failed on torch.compile backend=tensorrt #3909. This PR unified the insertion and pulling in _conversion.py.

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@zewenli98 zewenli98 self-assigned this Nov 24, 2025
@meta-cla meta-cla bot added the cla signed label Nov 24, 2025
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: torch_compile labels Nov 24, 2025
@narendasan
Copy link
Collaborator

@cehongwang please take a pass so we have multiple eyes on this PR

@zewenli98 zewenli98 force-pushed the improve_engine_caching branch from a54907e to ea81677 Compare December 4, 2025 18:38
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
)
if settings.strip_engine_weights:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would a torch.compile use try to use strip weights?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added the warning back. Not sure why strip_engine_weights arg doesn't work for torch.compile()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just doenst make sense. Torch compile is not serializable. So why would you ever want a callable that doesnt have the weights in it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that for torch.compile, _pretraced_backend() calls compile_module(), which is the same as the AOT to build TRT engines. If we pass in settings.strip_engine_weights, doesn't it tell the builder to build a weight-stripped engine?

@narendasan narendasan linked an issue Dec 9, 2025 that may be closed by this pull request
@github-actions github-actions bot added the component: core Issues re: The core compiler label Dec 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests component: torch_compile

Projects

None yet

Development

Successfully merging this pull request may close these issues.

📖 [Story] Weightless Engine Building

3 participants