Commit 665f74a

remove redundant empty_cache in parallel forward (#161)
1 parent 8b056d8 commit 665f74a

File tree

1 file changed: +4 −5 lines changed


diffsynth_engine/utils/parallel.py

Lines changed: 4 additions & 5 deletions
@@ -295,8 +295,9 @@ def wrap_for_parallel(module: Union[PreTrainedModel, BasePipeline]):
 
             if (name := data[0]) == "unload_module":
                 module = None
+                empty_cache()
             elif name == "load_module":
-                init_fn, kwargs = to_device(data[1:], device=device)
+                init_fn, kwargs = data[1:]
                 module = wrap_for_parallel(init_fn(**kwargs))
             elif module is None:
                 res = RuntimeError("module is not initialized")
@@ -307,12 +308,10 @@ def wrap_for_parallel(module: Union[PreTrainedModel, BasePipeline]):
                 with torch.no_grad():
                     res = getattr(module, name)(*args, **kwargs)
 
-            data, args, kwargs = None, None, None
-            torch.cuda.synchronize()
-            empty_cache()
-            dist.barrier()
             if rank == 0:
                 queue_out.put(res)
+            data, args, kwargs = None, None, None
+            dist.barrier()
         except Exception:
             import traceback
 
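For orientation, below is a minimal, hypothetical sketch of the worker-loop pattern this commit touches. It is not the actual diffsynth_engine/utils/parallel.py source: the worker_loop name, the queue_in/queue_out arguments, the surrounding while/try structure, and the placeholder error handling are assumptions made for illustration. It shows the post-commit flow, where empty_cache() runs only when a module is unloaded and each request ends with a single dist.barrier() after rank 0 reports the result, rather than a per-iteration synchronize/empty_cache/barrier sequence.

import torch
import torch.distributed as dist


def worker_loop(queue_in, queue_out, rank: int):
    """Hypothetical per-rank worker loop illustrating the post-commit flow."""
    module = None
    while True:
        try:
            res = None
            data = queue_in.get()                  # (method_name, *payload)
            args, kwargs = (), {}
            if (name := data[0]) == "unload_module":
                module = None
                torch.cuda.empty_cache()           # free cached GPU memory only on unload
            elif name == "load_module":
                init_fn, kwargs = data[1:]         # payload kept as-is; no to_device call
                module = init_fn(**kwargs)
            elif module is None:
                res = RuntimeError("module is not initialized")
            else:
                args, kwargs = data[1:]
                with torch.no_grad():
                    res = getattr(module, name)(*args, **kwargs)
            if rank == 0:
                queue_out.put(res)                 # only rank 0 reports back to the caller
            data, args, kwargs = None, None, None  # drop references before the next request
            dist.barrier()                         # single per-request sync point across ranks
        except Exception:
            import traceback
            traceback.print_exc()

The apparent intent is that calling empty_cache() (plus an explicit torch.cuda.synchronize()) on every forward request is unnecessary overhead, so the cache flush is restricted to module unloading while dist.barrier() alone keeps the ranks in lockstep.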