Commit 6d0a780

improve engine caching and fix bugs
1 parent 92ae286 commit 6d0a780

File tree

6 files changed: +232 -134 lines changed


py/torch_tensorrt/dynamo/_refit.py

Lines changed: 6 additions & 5 deletions
@@ -52,7 +52,7 @@
 logger = logging.getLogger(__name__)


-@needs_refit
+@needs_refit  # type: ignore[misc]
 def construct_refit_mapping(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],
@@ -85,7 +85,7 @@ def construct_refit_mapping(
     return weight_refit_map


-@needs_refit
+@needs_refit  # type: ignore[misc]
 def construct_refit_mapping_from_weight_name_map(
     weight_name_map: dict[Any, Any],
     state_dict: dict[Any, Any],
@@ -128,7 +128,7 @@ def construct_refit_mapping_from_weight_name_map(
     return engine_weight_map


-@needs_refit
+@needs_refit  # type: ignore[misc]
 def _refit_single_trt_engine_with_gm(
     new_gm: torch.fx.GraphModule,
     old_engine: trt.ICudaEngine,
@@ -211,7 +211,7 @@ def _refit_single_trt_engine_with_gm(
         raise AssertionError("Refitting failed.")


-@needs_refit
+@needs_refit  # type: ignore[misc]
 def refit_module_weights(
     compiled_module: torch.fx.GraphModule | ExportedProgram,
     new_weight_module: ExportedProgram,
@@ -484,9 +484,10 @@ def refit_module_weights(
                 weight_name_map=None,
             )

-            # clear EXCLUDE_WEIGHTS flag
+            # clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable
             serialization_config = engine.create_serialization_config()
             serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
+            serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
             serialized_engine = engine.serialize_with_config(serialization_config)

             if isinstance(compiled_submodule, PythonTorchTensorRTModule):
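
The added INCLUDE_REFIT flag matters because, as the removed TODO in _TRTInterpreter.py further below notes, TensorRT drops the refit capability when EXCLUDE_WEIGHTS is cleared during serialization. A minimal sketch of the flag handling this hunk introduces, assuming `engine` is an already-refitted trt.ICudaEngine; the helper name `serialize_refittable_plan` is illustrative, not part of the codebase:

import io

import tensorrt as trt


def serialize_refittable_plan(engine: trt.ICudaEngine) -> bytes:
    # Re-include the weights that were stripped for caching and keep the
    # plan refittable for future weight updates (mirrors the hunk above).
    config = engine.create_serialization_config()
    config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)  # write weights back into the plan
    config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)  # keep the plan refittable
    plan = engine.serialize_with_config(config)

    # Copy TRT's IHostMemory buffer into plain bytes, as the surrounding code does.
    with io.BytesIO() as buf:
        buf.write(plan)
        return buf.getvalue()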

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 0 additions & 1 deletion
@@ -167,7 +167,6 @@ def __setstate__(self, state: dict[str, Any]) -> None:
     "engine_capability",
     "hardware_compatible",
     "refit_identical_engine_weights",
-    "strip_engine_weights",  # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
     "immutable_weights",
     "enable_weight_streaming",
     "tiling_optimization_level",

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 0 additions & 4 deletions
@@ -157,10 +157,6 @@ def _pretraced_backend(
                 logger.warning(
                     "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
                 )
-            if settings.strip_engine_weights:
-                logger.error(
-                    "strip_engine_weights arg is not supported for torch.compile()"
-                )
             trt_compiled = compile_module(
                 gm,
                 torchtrt_inputs,
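
With this check gone, strip_engine_weights is no longer reported as an error on the torch.compile path; weight-stripped caching is handled inside the conversion layer instead (see _conversion.py below). For reference, a minimal sketch of driving this backend through torch.compile; the model and the options values are illustrative, only the backend name comes from the warning string above:

import torch
import torch_tensorrt  # noqa: F401  # registers the "torch_tensorrt" dynamo backend

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
).eval().cuda()
x = torch.randn(4, 16, device="cuda")

# _pretraced_backend above runs as part of this backend's compilation pipeline.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"enabled_precisions": {torch.float16}},  # illustrative option
)
print(compiled(x).shape)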

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 94 deletions
@@ -31,7 +31,7 @@
 from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
-from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -594,79 +594,6 @@ def _save_weight_mapping(self) -> None:
         gc.collect()
         torch.cuda.empty_cache()

-    @needs_refit  # type: ignore[misc]
-    def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
-        # query the cached TRT engine
-        cached_data = self.engine_cache.check(hash_val)  # type: ignore[union-attr]
-        if cached_data is not None:  # hit the cache
-            (
-                serialized_engine,
-                self._input_names,
-                self._output_names,
-                cached_engine_input_specs,
-                engine_compilation_settings,
-                self.weight_name_map,
-                self.ctx.requires_output_allocator,
-            ) = cached_data
-
-            setting_compatiblity, incompattible_settings = settings_are_compatible(
-                self.compilation_settings, engine_compilation_settings
-            )
-            assert (
-                setting_compatiblity
-            ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"
-
-            for i, e in enumerate(
-                [
-                    Input.equivalent_spec(c, i)
-                    for c, i in zip(cached_engine_input_specs, self.input_specs)
-                ]
-            ):
-                assert (
-                    e
-                ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"
-
-            _LOGGER.info(
-                "Found the cached engine that corresponds to this graph. It is directly loaded."
-            )
-
-            # refit the cached engine with the new graph module
-            if not self.compilation_settings.strip_engine_weights:
-                runtime = trt.Runtime(TRT_LOGGER)
-                engine = runtime.deserialize_cuda_engine(serialized_engine)
-
-                from torch_tensorrt.dynamo._refit import (
-                    _refit_single_trt_engine_with_gm,
-                )
-
-                _refit_single_trt_engine_with_gm(
-                    new_gm=self.module,
-                    old_engine=engine,
-                    input_list=self.input_specs,
-                    settings=self.compilation_settings,
-                    weight_name_map=self.weight_name_map,
-                )
-
-                # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
-                # # EXCLUDE_WEIGHTS flag must be cleared
-                # serialization_config = engine.create_serialization_config()
-                # serialization_config.clear_flag(
-                #     trt.SerializationFlag.EXCLUDE_WEIGHTS
-                # )
-                # serialized_engine = engine.serialize_with_config(
-                #     serialization_config
-                # )
-                # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller
-
-            return TRTInterpreterResult(
-                engine,
-                self._input_names,
-                self._output_names,
-                self.weight_name_map,
-                self.ctx.requires_output_allocator,
-            )
-        return None
-
     def run(
         self,
         strict_type_constraints: bool = False,
@@ -682,26 +609,6 @@ def run(
         Return:
             TRTInterpreterResult
         """
-        # self.engine_cache could be None if:
-        # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
-        # 2) both cache_built_engines and reuse_cached_engines are False
-        if (
-            self.engine_cache is not None
-            and not self.compilation_settings.immutable_weights
-        ):
-            if (
-                self.compilation_settings.cache_built_engines
-                or self.compilation_settings.reuse_cached_engines
-            ):
-                hash_val = self.engine_cache.get_hash(
-                    self.module, self.input_specs, self.compilation_settings
-                )
-
-                if self.compilation_settings.reuse_cached_engines:
-                    interpreter_result = self._pull_cached_engine(hash_val)
-                    if interpreter_result is not None:  # hit the cache
-                        return interpreter_result  # type: ignore[no-any-return]
-
         self._construct_trt_network_def()
         _LOGGER.debug(
             f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 139 additions & 24 deletions
@@ -4,19 +4,24 @@
 import logging
 from typing import Any, List, NamedTuple, Optional, Sequence

+import tensorrt as trt
 import torch
 from torch_tensorrt._enums import dtype
-from torch_tensorrt._features import ENABLED_FEATURES
+from torch_tensorrt._features import ENABLED_FEATURES, needs_refit
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
-from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
+from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
+from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
+    TRTInterpreter,
+    TRTInterpreterResult,
+)
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import (
     get_cpu_memory_usage,
     get_output_dtypes,
     release_host_and_device_memory,
 )
+from torch_tensorrt.logging import TRT_LOGGER

 logger = logging.getLogger(__name__)

@@ -63,6 +68,128 @@ def interpret_module_to_result(
         SerializedInterpreterResult
     """

+    def _insert_engine_to_cache(
+        hash_val: str, interpreter_result: TRTInterpreterResult
+    ) -> None:  # type: ignore[unused-ignore]
+        # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
+        if engine_cache.check(hash_val) is not None:  # type: ignore[union-attr]
+            logger.info(f"Engine already exists in cache for hash: {hash_val}")
+            return
+        if not settings.strip_engine_weights:
+            # set EXCLUDE_WEIGHTS flag to strip weights
+            serialization_config = (
+                interpreter_result.engine.create_serialization_config()
+            )
+            serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
+            weight_stripped_serialized_engine = (
+                interpreter_result.engine.serialize_with_config(serialization_config)
+            )
+        else:
+            weight_stripped_serialized_engine = interpreter_result.engine.serialize()
+
+        # Insert weight-stripped engine to cache
+        engine_cache.insert(  # type: ignore[union-attr]
+            hash_val,
+            (
+                weight_stripped_serialized_engine,
+                interpreter_result.input_names,
+                interpreter_result.output_names,
+                inputs,
+                settings,
+                interpreter_result.weight_name_map,
+                interpreter_result.requires_output_allocator,
+            ),
+        )
+        logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}")
+
+    @needs_refit  # type: ignore[misc]
+    def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
+        # query the cached TRT engine
+        cached_data = engine_cache.check(hash_val)  # type: ignore[union-attr]
+        if cached_data is not None:  # hit the cache
+            (
+                serialized_engine,  # weight-stripped engine
+                input_names,
+                output_names,
+                cached_engine_inputs,
+                cached_engine_compilation_settings,
+                weight_name_map,
+                requires_output_allocator,
+            ) = cached_data
+
+            setting_compatiblity, incompattible_settings = settings_are_compatible(
+                settings, cached_engine_compilation_settings
+            )
+            assert (
+                setting_compatiblity
+            ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})"
+
+            for i, e in enumerate(
+                [
+                    Input.equivalent_spec(c, i)
+                    for c, i in zip(cached_engine_inputs, inputs)
+                ]
+            ):
+                assert (
+                    e
+                ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}"
+
+            logger.info(
+                "Found the cached engine that corresponds to this graph. It is directly loaded."
+            )
+
+            # refit the cached engine with the new graph module
+            if not settings.strip_engine_weights:
+                runtime = trt.Runtime(TRT_LOGGER)
+                engine = runtime.deserialize_cuda_engine(
+                    serialized_engine
+                )  # weight-stripped engine
+
+                from torch_tensorrt.dynamo._refit import (
+                    _refit_single_trt_engine_with_gm,
+                )
+
+                # weight-stripped engine --in place--> weight-included engine
+                _refit_single_trt_engine_with_gm(
+                    new_gm=module,
+                    old_engine=engine,
+                    input_list=inputs,
+                    settings=settings,
+                    weight_name_map=weight_name_map,
+                )
+
+                # EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set
+                serialization_config = engine.create_serialization_config()
+                serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
+                serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
+                serialized_engine = engine.serialize_with_config(serialization_config)
+                # Start from here, the engine is weight-included and refittable
+
+            with io.BytesIO() as engine_bytes:
+                engine_bytes.write(serialized_engine)
+                serialized_engine = engine_bytes.getvalue()
+
+            return SerializedInterpreterResult(
+                serialized_engine=serialized_engine,
+                input_names=input_names,
+                output_names=output_names,
+                weight_name_map=weight_name_map,
+                requires_output_allocator=requires_output_allocator,
+            )
+        return None
+
+    # engine_cache could be None if:
+    # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
+    # 2) both cache_built_engines and reuse_cached_engines are False
+    if engine_cache is not None and not settings.immutable_weights:
+        if settings.cache_built_engines or settings.reuse_cached_engines:
+            hash_val = engine_cache.get_hash(module, inputs, settings)
+
+            if settings.reuse_cached_engines:
+                serialized_interpreter_result = _pull_cached_engine(hash_val)
+                if serialized_interpreter_result is not None:  # hit the cache
+                    return serialized_interpreter_result  # type: ignore[no-any-return]
+
     output_dtypes = infer_module_output_dtypes(
         module, truncate_double=settings.truncate_double
     )
@@ -86,32 +213,20 @@ def interpret_module_to_result(
         f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
     )

-    serialized_engine = interpreter_result.engine.serialize()
-    with io.BytesIO() as engine_bytes:
-        engine_bytes.write(serialized_engine)
-        serialized_engine = engine_bytes.getvalue()
-    logger.debug(
-        f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
-    )
-
     # Engine caching only for refittable engines
     if (
         not settings.immutable_weights
         and settings.cache_built_engines
         and engine_cache is not None
     ):
-        hash_val = engine_cache.get_hash(module, inputs, settings)
-        engine_cache.insert(
-            hash_val,
-            (
-                serialized_engine,
-                interpreter_result.input_names,
-                interpreter_result.output_names,
-                inputs,
-                settings,
-                interpreter_result.weight_name_map,
-                interpreter_result.requires_output_allocator,
-            ),
+        _insert_engine_to_cache(hash_val, interpreter_result)
+
+    serialized_engine = interpreter_result.engine.serialize()
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(serialized_engine)
+        serialized_engine = engine_bytes.getvalue()
+    logger.debug(
+        f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
     )

     serialized_interpreter_result = SerializedInterpreterResult(
@@ -122,7 +237,7 @@ def interpret_module_to_result(
         requires_output_allocator=interpreter_result.requires_output_allocator,
     )

-    return serialized_interpreter_result
+    return serialized_interpreter_result  # type: ignore[no-any-return]


 def convert_module(
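
Taken together, the _conversion.py changes make weight-stripped caching the default: _insert_engine_to_cache always stores a weight-stripped plan, and _pull_cached_engine refits it with the caller's weights on a hit. A hedged usage sketch of the settings these helpers read; cache_built_engines, reuse_cached_engines, and immutable_weights appear in the diff, while engine_cache_dir and the other compile arguments are assumed from the public API and may differ by version:

import torch
import torch_tensorrt

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
).eval().cuda()
inputs = [torch.randn(4, 16, device="cuda")]

# First compilation builds the engine and takes the _insert_engine_to_cache path;
# a later compilation of the same graph/settings takes the _pull_cached_engine path
# and refits the cached, weight-stripped engine with the current weights.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    immutable_weights=False,    # caching is skipped for immutable-weight engines
    cache_built_engines=True,   # enables caching of built engines
    reuse_cached_engines=True,  # enables reuse (and refit) of cached engines
    engine_cache_dir="./trt_engine_cache",  # assumed name for the on-disk cache location
)
print(trt_model(*inputs).shape)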
