|
| 1 | +import bisect |
| 2 | +import gc |
1 | 3 | import itertools |
| 4 | +import psutil |
| 5 | +import time |
| 6 | +import torch |
2 | 7 | from typing import Sequence, Mapping, Dict |
3 | 8 | from comfy_execution.graph import DynamicPrompt |
4 | 9 | from abc import ABC, abstractmethod |
@@ -188,6 +193,9 @@ def clean_unused(self): |
188 | 193 |         self._clean_cache() |
189 | 194 |         self._clean_subcaches() |
190 | 195 |
|
| 196 | +    def poll(self, **kwargs): |
| 197 | +        pass |
| 198 | + |
191 | 199 |     def _set_immediate(self, node_id, value): |
192 | 200 |         assert self.initialized |
193 | 201 |         cache_key = self.cache_key_set.get_data_key(node_id) |
@@ -276,6 +284,9 @@ def all_node_ids(self): |
276 | 284 |     def clean_unused(self): |
277 | 285 |         pass |
278 | 286 |
|
| 287 | +    def poll(self, **kwargs): |
| 288 | +        pass |
| 289 | + |
279 | 290 |     def get(self, node_id): |
280 | 291 |         return None |
281 | 292 |
|
@@ -336,3 +347,75 @@ async def ensure_subcache_for(self, node_id, children_ids): |
336 | 347 |             self._mark_used(child_id) |
337 | 348 |             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) |
338 | 349 |         return self |
| 350 | + |
| 351 | + |
| 352 | +#Iterating the cache for usage analysis can be expensive, so when eviction does trigger, take |
| 353 | +#a sizable chunk out to give breathing space on high-node / low-RAM-per-node flows. |
| 354 | + |
| 355 | +RAM_CACHE_HYSTERESIS = 1.1 |
| 356 | + |
| 357 | +#This is roughly in GB but not exactly. It needs to be non-zero for the heuristic below, |
| 358 | +#and as long as multi-GB models dwarf it the OOM scoring will still be a reasonable approximation. |
| 359 | + |
| 360 | +RAM_CACHE_DEFAULT_RAM_USAGE = 0.1 |
| 361 | + |
| 362 | +#Exponential bias towards evicting older workflows, so stale results get cleaned out |
| 363 | +#in constantly changing setups. |
| 364 | + |
| 365 | +RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 |
| 366 | + |
| 367 | +class RAMPressureCache(LRUCache): |
| 368 | + |
| 369 | +    def __init__(self, key_class): |
| 370 | +        super().__init__(key_class, 0) |
| 371 | +        self.timestamps = {} |
| 372 | + |
| 373 | +    def clean_unused(self): |
| 374 | +        self._clean_subcaches() |
| 375 | + |
| 376 | +    def set(self, node_id, value): |
| 377 | +        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() |
| 378 | +        super().set(node_id, value) |
| 379 | + |
| 380 | +    def get(self, node_id): |
| 381 | +        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() |
| 382 | +        return super().get(node_id) |
| 383 | + |
| 384 | +    def poll(self, ram_headroom): |
| 385 | +        def _ram_gb(): |
| 386 | +            return psutil.virtual_memory().available / (1024**3) |
| 387 | + |
| 388 | +        if _ram_gb() > ram_headroom: |
| 389 | +            return |
| 390 | +        gc.collect() |
| 391 | +        if _ram_gb() > ram_headroom: |
| 392 | +            return |
| 393 | + |
| 394 | +        clean_list = [] |
| 396 | +        for key, (outputs, _) in self.cache.items(): |
| 397 | +            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) |
| 398 | + |
| 399 | +            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE |
| 400 | +            def scan_list_for_ram_usage(outputs): |
| 401 | +                nonlocal ram_usage |
| 402 | +                for output in outputs: |
| 403 | +                    if isinstance(output, list): |
| 404 | +                        scan_list_for_ram_usage(output) |
| 405 | +                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': |
| 406 | +                        #score Tensors at a 50% discount for RAM usage as they are likely to |
| 407 | +                        #be high-value intermediates |
| 408 | +                        ram_usage += (output.numel() * output.element_size()) * 0.5 |
| 409 | +                    elif hasattr(output, "get_ram_usage"): |
| 410 | +                        ram_usage += output.get_ram_usage() |
| 411 | +            scan_list_for_ram_usage(outputs) |
| 412 | + |
| 413 | +            oom_score *= ram_usage |
| 414 | +            #When we have no information on the node RAM usage at all, break OOM score ties |
| 415 | +            #on the last-touch timestamp (least recently touched is evicted first, i.e. pure LRU) |
| 416 | +            bisect.insort(clean_list, (oom_score, -self.timestamps[key], key)) |
| 417 | + |
| 418 | +        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: |
| 419 | +            _, _, key = clean_list.pop() |
| 420 | +            del self.cache[key] |
| 421 | +            gc.collect() |
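
For context on the heuristic, here is a minimal standalone sketch (not part of the diff) of how the eviction ordering falls out of the score above. The constant and the (oom_score, -timestamp, key) sort key mirror the patch; the entries, sizes, ages and timestamps are invented for illustration.

import bisect

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

# Invented entries: (generations since last use, estimated RAM usage, last-touch time, key).
entries = [
    (0, 4.0, 300.0, "current_big"),    # current workflow, ~4 GB of cached outputs
    (5, 4.0, 100.0, "old_big"),        # five generations stale, same size
    (5, 0.1, 100.0, "old_unknown_a"),  # stale, no size info -> default usage only
    (5, 0.1, 200.0, "old_unknown_b"),  # same score, but touched more recently
]

clean_list = []
for age, ram_usage, last_touch, key in entries:
    oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** age * ram_usage
    # Negating the timestamp makes equal scores fall back to least-recently-touched first.
    bisect.insort(clean_list, (oom_score, -last_touch, key))

# poll() evicts from the end of the sorted list, so the order here is:
# old_big (stale and large), current_big, old_unknown_a, then old_unknown_b.
while clean_list:
    print(clean_list.pop()[2])

Eviction then continues until available RAM exceeds ram_headroom * RAM_CACHE_HYSTERESIS, i.e. about 10% above the trigger threshold, so each pass frees a chunk rather than oscillating at the boundary.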
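
The hasattr(output, "get_ram_usage") branch lets output objects report their own footprint instead of being scored at the default. The snippet below is a hypothetical sketch of such an object; the class name and fields are invented, and since the diff does not pin down the expected unit, it returns bytes to stay comparable with the tensor estimate above.

import torch

class CustomNodeOutput:
    # Hypothetical output wrapper; anything exposing get_ram_usage() would be scored this way.
    def __init__(self, samples):
        self.samples = samples  # list of CPU torch.Tensor objects

    def get_ram_usage(self):
        # Rough host-RAM footprint of the wrapped tensors, in bytes.
        return sum(t.numel() * t.element_size() for t in self.samples)

output = CustomNodeOutput([torch.zeros(4, 4, 64, 64)])
print(output.get_ram_usage())  # 65536 float32 elements * 4 bytes = 262144

How poll() gets invoked is outside this diff; presumably the executor calls it with the configured headroom (in GB, matching _ram_gb()) between node executions.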