1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -105,6 +105,7 @@ class LatentPreviewMethod(enum.Enum):
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
3 changes: 3 additions & 0 deletions comfy/model_patcher.py
@@ -275,6 +275,9 @@ def model_size(self):
self.size = comfy.model_management.module_size(self.model)
return self.size

def get_ram_usage(self):
return self.model_size()

def loaded_size(self):
return self.model.model_loaded_weight_memory

14 changes: 14 additions & 0 deletions comfy/sd.py
@@ -143,6 +143,9 @@ def clone(self):
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n

def get_ram_usage(self):
return self.patcher.get_ram_usage()

def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)

@@ -293,6 +296,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None

self.downscale_index_formula = None
self.upscale_index_formula = None
@@ -595,6 +599,16 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):

self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()

def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size

def get_ram_usage(self):
return self.model_size()

def throw_exception_if_invalid(self):
if self.first_stage_model is None:
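ModelPatcher, CLIP, and VAE now all expose `get_ram_usage()` returning the model size in bytes. A hedged sketch of the duck-typed contract the RAM pressure cache relies on; `CustomHolder` is a hypothetical class used only for illustration:

```python
# Hypothetical object participating in the same duck-typed protocol: the cache
# scanner only checks hasattr(obj, "get_ram_usage") and expects a byte count.
class CustomHolder:
    def __init__(self, resident_bytes: int):
        self._bytes = resident_bytes

    def get_ram_usage(self) -> int:
        # Like ModelPatcher/CLIP/VAE, report the resident size in bytes.
        return self._bytes

holder = CustomHolder(2 * 1024**3)   # a made-up 2 GiB object
assert holder.get_ram_usage() == 2 * 1024**3
```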
83 changes: 83 additions & 0 deletions comfy_execution/caching.py
@@ -1,4 +1,9 @@
import bisect
import gc
import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
@@ -188,6 +193,9 @@ def clean_unused(self):
self._clean_cache()
self._clean_subcaches()

def poll(self, **kwargs):
pass

def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ def all_node_ids(self):
def clean_unused(self):
pass

def poll(self, **kwargs):
pass

def get(self, node_id):
return None

@@ -336,3 +347,75 @@ async def ensure_subcache_for(self, node_id, children_ids):
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self


#Iterating the cache for usage analysis can be expensive, so once eviction triggers,
#take a sizeable chunk out to give breathing space on high-node / low-RAM-per-node flows.

RAM_CACHE_HYSTERESIS = 1.1

#Roughly in GB. It needs to be non-zero for the heuristic below, and as long as
#multi-GB models dwarf this value the OOM scoring approximation holds up.

RAM_CACHE_DEFAULT_RAM_USAGE = 0.1

#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

class RAMPressureCache(LRUCache):

def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}

def clean_unused(self):
self._clean_subcaches()

def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)

def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)

def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)

if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return

clean_list = []

for key, (outputs, _) in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])

ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)

oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))

while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
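A worked, self-contained sketch of the eviction ordering that `poll()` builds; the values are hypothetical and the units simplified to GB. Each entry's score is the age bias raised to the number of generations since last use, multiplied by its approximate RAM usage, so old and large entries are popped first:

```python
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

def oom_score(age_in_generations: int, ram_usage_gb: float) -> float:
    # Older entries (larger age) and larger entries both raise the score.
    return (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** age_in_generations) * ram_usage_gb

entries = {
    "fresh_small": oom_score(0, 0.1),  # 0.10
    "fresh_large": oom_score(0, 8.0),  # 8.00
    "stale_large": oom_score(5, 8.0),  # ~29.70
}
# Eviction order under pressure (highest score first):
# ['stale_large', 'fresh_large', 'fresh_small']
print(sorted(entries, key=entries.get, reverse=True))
```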
9 changes: 7 additions & 2 deletions comfy_execution/graph.py
@@ -209,10 +209,15 @@ def cache_link(self, from_node_id, to_node_id):
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)

def get_output_cache(self, from_node_id, to_node_id):
def get_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
return None
return self.execution_cache[to_node_id].get(from_node_id)
value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value

def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners:
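A minimal sketch, using hypothetical stand-in classes, of the write-back-on-touch pattern `get_cache` introduces: reading a value from the per-execution cache re-inserts it into the main output cache, so eviction scoring sees it as recently used.

```python
class TouchThroughLookup:
    """Hypothetical stand-in for the ExecutionList cache plumbing."""
    def __init__(self, output_cache):
        self.output_cache = output_cache   # e.g. a RAMPressureCache
        self.execution_cache = {}          # {to_node_id: {from_node_id: value}}

    def get_cache(self, from_node_id, to_node_id):
        value = self.execution_cache.get(to_node_id, {}).get(from_node_id)
        if value is None:
            return None
        # Touching an entry refreshes its recency in the main cache so it is
        # less likely to be evicted while the current workflow still needs it.
        self.output_cache.set(from_node_id, value)
        return value
```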
81 changes: 46 additions & 35 deletions execution.py
@@ -21,6 +21,7 @@
NullCache,
HierarchicalCache,
LRUCache,
RAMPressureCache,
)
from comfy_execution.graph import (
DynamicPrompt,
@@ -88,49 +89,56 @@ async def get(self, node_id):
return self.is_changed[node_id]


class CacheEntry(NamedTuple):
ui: dict
outputs: list


class CacheType(Enum):
CLASSIC = 0
LRU = 1
NONE = 2
RAM_PRESSURE = 3


class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
cache_size = cache_args.get("lru", 0)
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()

self.all = [self.outputs, self.ui, self.objects]
self.all = [self.outputs, self.objects]

# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)

def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

def init_null_cache(self):
self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Hierarchical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache()

def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
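The main.py wiring is not part of this diff, so the following mapping from the CLI flags to the `cache_type`/`cache_args` pair is a hedged sketch; `args.cache_ram` and `args.cache_lru` come from cli_args.py, but the exact selection order is an assumption.

```python
from types import SimpleNamespace

# Hypothetical glue code: the real flag-to-cache wiring lives in main.py,
# outside this diff.
def select_cache(args):
    if args.cache_none:
        return CacheType.NONE, {}
    if args.cache_lru > 0:
        return CacheType.LRU, {"lru": args.cache_lru}
    if args.cache_ram:
        return CacheType.RAM_PRESSURE, {"ram": args.cache_ram}
    return CacheType.CLASSIC, {}

# e.g. `--cache-ram 8` would end up as:
cache_type, cache_args = select_cache(
    SimpleNamespace(cache_none=False, cache_lru=0, cache_ram=8.0))
# -> (CacheType.RAM_PRESSURE, {"ram": 8.0}), which PromptExecutor forwards to CacheSet.
```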

Expand All @@ -157,14 +165,14 @@ def mark_missing():
if execution_list is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
if cached_output is None:
cached = execution_list.get_cache(input_unique_id, unique_id)
if cached is None or cached.outputs is None:
mark_missing()
continue
if output_index >= len(cached_output):
if output_index >= len(cached.outputs):
mark_missing()
continue
obj = cached_output[output_index]
obj = cached.outputs[output_index]
input_data_all[x] = obj
elif input_category is not None:
input_data_all[x] = [input_data]
@@ -393,20 +401,23 @@ def format_value(x):
else:
return str(x)

async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None:
cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)

input_data_all = None
@@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
for o in node_output:
node_cached = execution_list.get_cache(source_node, unique_id)
for o in node_cached.outputs[source_output]:
resolved_output.append(o)

else:
@@ -506,15 +517,15 @@ async def await_completion():
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
ui_outputs[unique_id] = {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
"parent_node": parent_node_id,
"real_node_id": real_node_id,
},
"output": output_ui
})
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
@@ -557,8 +568,9 @@ async def await_completion():
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)

caches.outputs.set(unique_id, output_data)
execution_list.cache_update(unique_id, output_data)
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)

except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -603,14 +615,14 @@ async def await_completion():
return (ExecutionResult.SUCCESS, None, None)

class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
def __init__(self, server, cache_type=False, cache_args=None):
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.reset()

def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
self.success = True

@@ -685,6 +697,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@@ -698,7 +711,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
break

assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
Expand All @@ -707,18 +720,16 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

ui_outputs = {}
meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,