Skip to content

Commit 8e2d038

Browse files
committed
up
1 parent aa0cafb commit 8e2d038

File tree

1 file changed

+13
-23
lines changed

1 file changed

+13
-23
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,33 +1436,23 @@ def enable_group_offload(
14361436
f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
14371437
)
14381438

1439+
group_offload_kwargs = {
1440+
"onload_device": onload_device,
1441+
"offload_device": offload_device,
1442+
"offload_type": offload_type,
1443+
"num_blocks_per_group": num_blocks_per_group,
1444+
"non_blocking": non_blocking,
1445+
"use_stream": use_stream,
1446+
"record_stream": record_stream,
1447+
"low_cpu_mem_usage": low_cpu_mem_usage,
1448+
"offload_to_disk_path": offload_to_disk_path,
1449+
}
14391450
for name, component in self.components.items():
14401451
if name not in exclude_modules and isinstance(component, torch.nn.Module):
14411452
if hasattr(component, "enable_group_offload"):
1442-
component.enable_group_offload(
1443-
onload_device=onload_device,
1444-
offload_device=offload_device,
1445-
offload_type=offload_type,
1446-
num_blocks_per_group=num_blocks_per_group,
1447-
non_blocking=non_blocking,
1448-
use_stream=use_stream,
1449-
record_stream=record_stream,
1450-
low_cpu_mem_usage=low_cpu_mem_usage,
1451-
offload_to_disk_path=offload_to_disk_path,
1452-
)
1453+
component.enable_group_offload(**group_offload_kwargs)
14531454
else:
1454-
apply_group_offloading(
1455-
module=component,
1456-
onload_device=onload_device,
1457-
offload_device=offload_device,
1458-
offload_type=offload_type,
1459-
num_blocks_per_group=num_blocks_per_group,
1460-
non_blocking=non_blocking,
1461-
use_stream=use_stream,
1462-
record_stream=record_stream,
1463-
low_cpu_mem_usage=low_cpu_mem_usage,
1464-
offload_to_disk_path=offload_to_disk_path,
1465-
)
1455+
apply_group_offloading(module=component, **group_offload_kwargs)
14661456

14671457
if exclude_modules:
14681458
for module_name in exclude_modules:

0 commit comments

Comments
 (0)