
Commit 5e1b314

Moved some tests to L1
1 parent 9ac07a9 commit 5e1b314

9 files changed: +136 -104 lines changed


.github/workflows/build-test-linux-x86_64.yml

Lines changed: 3 additions & 1 deletion
@@ -136,7 +136,7 @@ jobs:
 cd tests/py
 cd dynamo
 python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
-python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
+python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/L0/
 python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
 popd
 
@@ -229,6 +229,8 @@ jobs:
 pushd .
 cd tests/py/dynamo
 python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
+python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/L1/
+
 popd
 
 L1-dynamo-compile-tests:

py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py

Lines changed: 20 additions & 12 deletions
@@ -60,6 +60,7 @@
 logger = logging.getLogger(__name__)
 
 MAX_NUM_OF_ENGINES = 40
+ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER = 4
 
 
 class ResourcePartitioner(_SplitterBase):  # type: ignore
@@ -87,8 +88,9 @@ def __init__(
         assert isinstance(module, torch.fx.GraphModule)
 
         self.module = module
-        self.cpu_memory_budget = (
-            cpu_memory_budget
+        used_rss: int = psutil.Process().memory_info().rss
+        self.remaining_memory_budget = (
+            cpu_memory_budget - used_rss
             if cpu_memory_budget is not None
             else psutil.virtual_memory().available
         )
@@ -114,6 +116,12 @@ def partition_graph(self) -> torch.fx.GraphModule:
         """
         # Delegate nodes based on operator coverage
         subgraphs = self.put_nodes_into_subgraphs()
+        sizes = self.size_of_subgraphs(subgraphs)
+        if (
+            sum(sizes) * ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER
+            < self.remaining_memory_budget
+        ):
+            return self.module
 
         subgraphs = self.break_subgraphs(
             subgraphs, subgraph_size_budget=self.calculate_size_budget()
@@ -172,7 +180,8 @@ def check_topological_order(self, subgraphs: List[Subgraph]) -> bool:
         return True
 
     def calculate_size_budget(
-        self, engine_compilation_memory_usage_multiplier: int = 4
+        self,
+        engine_compilation_memory_usage_multiplier: int = ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER,
     ) -> int:
         """Compute the per-engine size budget in bytes.
 
@@ -188,13 +197,9 @@ def calculate_size_budget(
             int: Budget in bytes for a single accelerated subgraph.
         """
 
-        used_rss: int = psutil.Process().memory_info().rss
-        available_rss = (
-            self.cpu_memory_budget
-            if self.not_set_limit
-            else self.cpu_memory_budget - used_rss
+        return (
+            self.remaining_memory_budget // engine_compilation_memory_usage_multiplier
         )
-        return available_rss // engine_compilation_memory_usage_multiplier
 
     def break_subgraphs(
         self, subgraphs: List[Subgraph], subgraph_size_budget: int
@@ -229,7 +234,7 @@ def break_subgraphs(
         else:
             raise ValueError(
                 "CPU memory budget is too small to compile the model. "
-                + f"CPU memory budget: {self.cpu_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. "
+                + f"CPU memory budget: {self.remaining_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. "
                 + "Consider setting cpu_memory_budget to a larger value."
             )
         for subgraph, size in zip(subgraphs, sizes):
@@ -548,12 +553,15 @@ def resource_partition(
         setattr(gm, name, partitioned_graph)
 
     for name, module in list(gm.named_children()):
+        split = False
        if "_run_on_acc" in name:
             for subname, submodule in module.named_children():
                 if "resource_split" in subname:
+                    split = True
                     setattr(gm, subname, submodule)
-            _inline_module(gm, name)
-            delattr(gm, name)
+            if split:
+                _inline_module(gm, name)
+                delattr(gm, name)
 
     gm.recompile()
     return gm
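
Note: the net effect of the partitioner changes above is that the memory budget is now fixed at construction time as the user-supplied cap minus the process RSS (or all available system memory when no cap is given), and partition_graph() returns the module untouched when the estimated compilation footprint already fits. A minimal standalone sketch of that logic follows; the helper names are illustrative only and are not part of the torch_tensorrt API.

from typing import List, Optional

import psutil

# Mirrors the module-level constant added in this diff.
ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER = 4


def remaining_memory_budget(cpu_memory_budget: Optional[int] = None) -> int:
    # Budget left for engine compilation: the user-supplied cap minus the memory
    # this process already holds, or all available system memory if no cap is set.
    used_rss = psutil.Process().memory_info().rss
    if cpu_memory_budget is not None:
        return cpu_memory_budget - used_rss
    return psutil.virtual_memory().available


def fits_without_splitting(subgraph_sizes: List[int], budget: int) -> bool:
    # Early exit corresponding to the new check in partition_graph(): if compiling
    # every subgraph (with the 4x compilation-overhead multiplier) fits within the
    # remaining budget, resource partitioning can be skipped entirely.
    return sum(subgraph_sizes) * ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER < budget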
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+from torch.fx.passes.splitter_base import Subgraph
+from torch.ops import aten
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt.dynamo import partitioning
+from torch_tensorrt.dynamo.conversion import CompilationSettings
+from torch_tensorrt.dynamo.lowering import (
+    get_decompositions,
+    post_lowering,
+    pre_export_lowering,
+)
+from torch_tensorrt.dynamo.lowering.passes import post_lowering, pre_export_lowering
+from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
+    ResourcePartitioner,
+)
+
+
+class TestResourcePartitioning(TestCase):
+    def test_atomic_subgraph_correction(self):
+        class net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(3, 3, 3, padding=1)
+                self.bn1 = nn.BatchNorm2d(3)
+                self.relu = nn.ReLU()
+                self.fc = nn.Linear(3 * 224 * 224, 10)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                x = self.relu(x)
+                x = torch.flatten(x, 1)
+                x = self.fc(x)
+                return x
+
+        model = net().eval()
+        model.to("cuda")
+        inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+
+        enabled_precisions = {torch.float}
+        use_python_runtime = False
+
+        exp_program = torch.export.export(model, tuple(inputs))
+
+        compilation_options = {
+            "use_python_runtime": use_python_runtime,
+            "enabled_precisions": enabled_precisions,
+            "min_block_size": 1,
+            "immutable_weights": True,
+            "reuse_cached_engines": False,
+            "enable_resource_partitioning": True,
+        }
+        settings = CompilationSettings(**compilation_options)
+
+        exported_program = pre_export_lowering(exp_program, settings)
+        exported_program = exported_program.run_decompositions(
+            get_decompositions(False)
+        )
+
+        gm = exported_program.module()
+        gm = post_lowering(gm, settings)
+
+        partitioned_module, supported_ops = partitioning.fast_partition(
+            gm,
+            min_block_size=settings.min_block_size,
+            torch_executed_ops=settings.torch_executed_ops,
+            require_full_compilation=settings.require_full_compilation,
+            skip_fusion=True,
+        )
+
+        for name, _ in partitioned_module.named_children():
+            submodule = getattr(partitioned_module, name)
+            if (
+                not isinstance(submodule, torch.fx.graph_module.GraphModule)
+                or "_run_on_acc" not in name
+            ):
+                continue
+            partitioner = ResourcePartitioner(
+                submodule,
+                submodule_name=name,
+                cpu_memory_budget=2 * 1024 * 1024 * 1024,
+            )
+            subgraphs = partitioner.put_nodes_into_subgraphs()
+            new_subgraphs = []
+            current_subgraph = []
+            # Split the subgraph into two subgraphs by the ReLU node, which breaks the fusion group.
+            for node in subgraphs[0].nodes:
+                if node.op == "call_function" and node.target == aten.relu.default:
+                    new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph))
+                    current_subgraph = []
+                current_subgraph.append(node)
+            if current_subgraph:
+                new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph))
+
+            leaf_node = partitioner.get_leaf_node(new_subgraphs[0].nodes)
+            broken_fusion = partitioner.step_if_break_fusion(
+                new_subgraphs,
+                leaf_node,
+                set(new_subgraphs[0].nodes),
+                set(new_subgraphs[1].nodes),
+            )
+            # The fusion was broken
+            assert broken_fusion
+
+            # The fusion should be fixed after the step
+            partitioner._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
+
+            break
+
+
+if __name__ == "__main__":
+    run_tests()

tests/py/dynamo/partitioning/test_resource_partitioning.py renamed to tests/py/dynamo/partitioning/L1/test_resource_partitioning.py

Lines changed: 0 additions & 91 deletions
@@ -326,97 +326,6 @@ def forward(self, x):
 
         torch._dynamo.reset()
 
-    def test_atomic_subgraph_correction(self):
-        class net(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv1 = nn.Conv2d(3, 3, 3, padding=1)
-                self.bn1 = nn.BatchNorm2d(3)
-                self.relu = nn.ReLU()
-                self.fc = nn.Linear(3 * 224 * 224, 10)
-
-            def forward(self, x):
-                x = self.conv1(x)
-                x = self.bn1(x)
-                x = self.relu(x)
-                x = torch.flatten(x, 1)
-                x = self.fc(x)
-                return x
-
-        model = net().eval()
-        model.to("cuda")
-        inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
-
-        enabled_precisions = {torch.float}
-        use_python_runtime = False
-
-        exp_program = torch.export.export(model, tuple(inputs))
-
-        compilation_options = {
-            "use_python_runtime": use_python_runtime,
-            "enabled_precisions": enabled_precisions,
-            "min_block_size": 1,
-            "immutable_weights": True,
-            "reuse_cached_engines": False,
-            "enable_resource_partitioning": True,
-        }
-        settings = CompilationSettings(**compilation_options)
-
-        exported_program = pre_export_lowering(exp_program, settings)
-        exported_program = exported_program.run_decompositions(
-            get_decompositions(False)
-        )
-
-        gm = exported_program.module()
-        gm = post_lowering(gm, settings)
-
-        partitioned_module, supported_ops = partitioning.fast_partition(
-            gm,
-            min_block_size=settings.min_block_size,
-            torch_executed_ops=settings.torch_executed_ops,
-            require_full_compilation=settings.require_full_compilation,
-            skip_fusion=True,
-        )
-
-        for name, _ in partitioned_module.named_children():
-            submodule = getattr(partitioned_module, name)
-            if (
-                not isinstance(submodule, torch.fx.graph_module.GraphModule)
-                or "_run_on_acc" not in name
-            ):
-                continue
-            partitioner = ResourcePartitioner(
-                submodule,
-                submodule_name=name,
-                cpu_memory_budget=2 * 1024 * 1024 * 1024,
-            )
-            subgraphs = partitioner.put_nodes_into_subgraphs()
-            new_subgraphs = []
-            current_subgraph = []
-            # Split the subgraph into two subgraphs by the ReLU node, which breaks the fusion group.
-            for node in subgraphs[0].nodes:
-                if node.op == "call_function" and node.target == aten.relu.default:
-                    new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph))
-                    current_subgraph = []
-                current_subgraph.append(node)
-            if current_subgraph:
-                new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph))
-
-            leaf_node = partitioner.get_leaf_node(new_subgraphs[0].nodes)
-            broken_fusion = partitioner.step_if_break_fusion(
-                new_subgraphs,
-                leaf_node,
-                set(new_subgraphs[0].nodes),
-                set(new_subgraphs[1].nodes),
-            )
-            # The fusion was broken
-            assert broken_fusion
-
-            # The fusion should be fixed after the step
-            partitioner._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
-
-            break
-
     def test_resource_partitioning_with_global_capability_partitioning(self):
         class net(nn.Module):
             def __init__(self):
