@@ -1436,33 +1436,23 @@ def enable_group_offload(
1436
1436
f"The following modules are not present in pipeline: { ', ' .join (unknown )} . Ignore if this is expected."
1437
1437
)
1438
1438
1439
+ group_offload_kwargs = {
1440
+ "onload_device" : onload_device ,
1441
+ "offload_device" : offload_device ,
1442
+ "offload_type" : offload_type ,
1443
+ "num_blocks_per_group" : num_blocks_per_group ,
1444
+ "non_blocking" : non_blocking ,
1445
+ "use_stream" : use_stream ,
1446
+ "record_stream" : record_stream ,
1447
+ "low_cpu_mem_usage" : low_cpu_mem_usage ,
1448
+ "offload_to_disk_path" : offload_to_disk_path ,
1449
+ }
1439
1450
for name , component in self .components .items ():
1440
1451
if name not in exclude_modules and isinstance (component , torch .nn .Module ):
1441
1452
if hasattr (component , "enable_group_offload" ):
1442
- component .enable_group_offload (
1443
- onload_device = onload_device ,
1444
- offload_device = offload_device ,
1445
- offload_type = offload_type ,
1446
- num_blocks_per_group = num_blocks_per_group ,
1447
- non_blocking = non_blocking ,
1448
- use_stream = use_stream ,
1449
- record_stream = record_stream ,
1450
- low_cpu_mem_usage = low_cpu_mem_usage ,
1451
- offload_to_disk_path = offload_to_disk_path ,
1452
- )
1453
+ component .enable_group_offload (** group_offload_kwargs )
1453
1454
else :
1454
- apply_group_offloading (
1455
- module = component ,
1456
- onload_device = onload_device ,
1457
- offload_device = offload_device ,
1458
- offload_type = offload_type ,
1459
- num_blocks_per_group = num_blocks_per_group ,
1460
- non_blocking = non_blocking ,
1461
- use_stream = use_stream ,
1462
- record_stream = record_stream ,
1463
- low_cpu_mem_usage = low_cpu_mem_usage ,
1464
- offload_to_disk_path = offload_to_disk_path ,
1465
- )
1455
+ apply_group_offloading (module = component , ** group_offload_kwargs )
1466
1456
1467
1457
if exclude_modules :
1468
1458
for module_name in exclude_modules :
0 commit comments