Commit 4f4c551
Implement RAM Pressure cache
Implement a cache that is sensitive to RAM pressure. When RAM headroom drops below a configurable threshold, RAM-expensive nodes are evicted from the cache. Models and tensors are measured directly for RAM usage, and an OOM score is computed for each node from that usage. Note that, due to indirection through shared objects (like a model patcher), multiple nodes can be accounted the same RAM as their individual usage. The intent is that this frees whole chains of nodes, particularly model loaders and their associated LoRAs, since they all score similarly and therefore sort close to each other. The cache is biased towards unloading model nodes mid-flow while keeping results such as text encodings and the VAE.
1 parent 0c95f22 commit 4f4c551
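
Roughly, the eviction ordering described above reduces to a single scoring function. The sketch below is illustrative only (the sizes and generation counts are made up); the constant name matches the one introduced in comfy_execution/caching.py further down.

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

def oom_score(ram_usage_gb, generations_since_last_use):
    # Older results get an exponential penalty, so stale workflows drain out first;
    # within a workflow, the biggest items (models and their LoRA-patched clones) dominate.
    return (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** generations_since_last_use) * ram_usage_gb

# A 6GB checkpoint untouched for two generations outranks a 0.5GB text encoding from the
# current generation, so the model chain is evicted first:
print(oom_score(6.0, 2))   # ~10.1
print(oom_score(0.5, 0))   # 0.5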

4 files changed, +103 -7 lines changed

comfy/cli_args.py

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ class LatentPreviewMethod(enum.Enum):
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold, the cache removes large items to free RAM. Default 4GB.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
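
For reference, a minimal sketch of how the nargs='?' / const / default combination above behaves for the three ways the flag can be passed (standard argparse semantics, shown outside ComfyUI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0)

print(parser.parse_args([]).cache_ram)                    # 0   -> RAM pressure cache disabled
print(parser.parse_args(["--cache-ram"]).cache_ram)       # 4.0 -> bare flag uses the 4GB default headroom
print(parser.parse_args(["--cache-ram", "8"]).cache_ram)  # 8.0 -> explicit headroom threshold in GB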

comfy_execution/caching.py

Lines changed: 83 additions & 0 deletions
@@ -1,4 +1,9 @@
+import bisect
+import gc
 import itertools
+import psutil
+import time
+import torch
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
 from abc import ABC, abstractmethod
@@ -188,6 +193,9 @@ def clean_unused(self):
         self._clean_cache()
         self._clean_subcaches()

+    def poll(self, **kwargs):
+        pass
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ def all_node_ids(self):
     def clean_unused(self):
         pass

+    def poll(self, **kwargs):
+        pass
+
     def get(self, node_id):
         return None

@@ -336,3 +347,75 @@ async def ensure_subcache_for(self, node_id, children_ids):
             self._mark_used(child_id)
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
+
+
+#Iterating the cache for usage analysis might be expensive, so if we trigger, make sure
+#to take a chunk out to give breathing space on high-node / low-RAM-per-node flows.
+RAM_CACHE_HYSTERESIS = 1.1
+
+#This is roughly in GB, but not exactly. It needs to be non-zero for the heuristic below,
+#and as long as multi-GB models dwarf it, it approximates OOM scoring well enough.
+RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
+
+#Exponential bias towards evicting older workflows so garbage gets taken out
+#in constantly changing setups.
+RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
+
+class RAMPressureCache(LRUCache):
+
+    def __init__(self, key_class):
+        super().__init__(key_class, 0)
+        self.timestamps = {}
+
+    def clean_unused(self):
+        self._clean_subcaches()
+
+    def set(self, node_id, value):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        super().set(node_id, value)
+
+    def get(self, node_id):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        return super().get(node_id)
+
+    def poll(self, ram_headroom):
+        def _ram_gb():
+            return psutil.virtual_memory().available / (1024**3)
+
+        if _ram_gb() > ram_headroom:
+            return
+        gc.collect()
+        if _ram_gb() > ram_headroom:
+            return
+
+        clean_list = []
+
+        for key, (outputs, _) in self.cache.items():
+            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
+
+            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
+            def scan_list_for_ram_usage(outputs):
+                nonlocal ram_usage
+                for output in outputs:
+                    if isinstance(output, list):
+                        scan_list_for_ram_usage(output)
+                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
+                        #Score tensors at a 50% discount for RAM usage as they are likely to
+                        #be high-value intermediates.
+                        ram_usage += (output.numel() * output.element_size()) * 0.5
+                    elif hasattr(output, "get_ram_usage"):
+                        ram_usage += output.get_ram_usage()
+            scan_list_for_ram_usage(outputs)
+
+            oom_score *= ram_usage
+            #If we have no information on the node's RAM usage at all, break OOM score
+            #ties on the last-touch timestamp (pure LRU).
+            bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
+
+        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
+            _, _, key = clean_list.pop()
+            del self.cache[key]
+            gc.collect()
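
The hasattr(output, "get_ram_usage") branch above is duck-typed: anything a node returns can opt in to RAM accounting by exposing that method, which is also how shared objects (such as a model patcher referenced by several nodes) end up charged against each node that holds them. A hypothetical example, not part of this commit, with the class name and layout made up for illustration:

class LoadedModelStub:
    # Hypothetical stand-in for a model-like object a loader node might return.
    def __init__(self, state_dict):
        self.state_dict = state_dict

    def get_ram_usage(self):
        # Total bytes of all tensors this object keeps resident in system RAM;
        # poll() adds this to the owning node's OOM score.
        return sum(t.numel() * t.element_size() for t in self.state_dict.values())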

execution.py

Lines changed: 16 additions & 6 deletions
@@ -21,6 +21,7 @@
     NullCache,
     HierarchicalCache,
     LRUCache,
+    RAMPressureCache,
 )
 from comfy_execution.graph import (
     DynamicPrompt,
@@ -92,16 +93,20 @@ class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
     NONE = 2
+    RAM_PRESSURE = 3


 class CacheSet:
-    def __init__(self, cache_type=None, cache_size=None):
+    def __init__(self, cache_type=None, cache_args={}):
         if cache_type == CacheType.NONE:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.RAM_PRESSURE:
+            cache_ram = cache_args.get("ram", 16.0)
+            self.init_ram_cache(cache_ram)
+            logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
-            if cache_size is None:
-                cache_size = 0
+            cache_size = cache_args.get("lru", 0)
             self.init_lru_cache(cache_size)
             logging.info("Using LRU cache")
         else:
@@ -118,6 +123,10 @@ def init_lru_cache(self, cache_size):
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)

+    def init_ram_cache(self, min_headroom):
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
     def init_null_cache(self):
         self.outputs = NullCache()
         self.objects = NullCache()
@@ -600,14 +609,14 @@ async def await_completion():
     return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
+    def __init__(self, server, cache_type=False, cache_args=None):
+        self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()

     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
         self.status_messages = []
         self.success = True

@@ -705,6 +714,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                 execution_list.unstage_node_execution()
             else: # result == ExecutionResult.SUCCESS:
                 execution_list.complete_node_execution()
+                self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
         else:
             # Only execute when the while-loop ends without break
             self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
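
Because the base cache and NullCache gain no-op poll(**kwargs) stubs in comfy_execution/caching.py above, this unconditional poll after every completed node is harmless for the classic, LRU, and null caches; only RAMPressureCache overrides it to evict entries when headroom runs low.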

main.py

Lines changed: 3 additions & 1 deletion
@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
     cache_type = execution.CacheType.CLASSIC
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
+    elif args.cache_ram > 0:
+        cache_type = execution.CacheType.RAM_PRESSURE
     elif args.cache_none:
         cache_type = execution.CacheType.NONE

-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
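
Note the precedence in the chain above: --cache-lru wins over --cache-ram, which wins over --cache-none. Both the LRU size and the RAM headroom are always forwarded in cache_args, and CacheSet reads whichever entry matches the selected cache_type.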
