Skip to content

Commit d2b87e6

Browse files
committed
execution: Convert the cache entry to NamedTuple
As suggested in review: convert this to a named tuple and abstract the tuple type away from graph.py completely.
1 parent f3f526f commit d2b87e6

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

comfy_execution/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,15 @@ def cache_link(self, from_node_id, to_node_id):
209209
self.execution_cache_listeners[from_node_id] = set()
210210
self.execution_cache_listeners[from_node_id].add(to_node_id)
211211

212-
def get_output_cache(self, from_node_id, to_node_id):
212+
def get_cache(self, from_node_id, to_node_id):
213213
if not to_node_id in self.execution_cache:
214214
return None
215215
value = self.execution_cache[to_node_id].get(from_node_id)
216216
if value is None:
217217
return None
218218
#Write back to the main cache on touch.
219219
self.output_cache.set(from_node_id, value)
220-
return value[0]
220+
return value
221221

222222
def cache_update(self, node_id, value):
223223
if node_id in self.execution_cache_listeners:

execution.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ async def get(self, node_id):
8989
return self.is_changed[node_id]
9090

9191

92+
class CacheEntry(NamedTuple):
93+
ui: dict
94+
outputs: list
95+
96+
9297
class CacheType(Enum):
9398
CLASSIC = 0
9499
LRU = 1
@@ -160,14 +165,14 @@ def mark_missing():
160165
if execution_list is None:
161166
mark_missing()
162167
continue # This might be a lazily-evaluated input
163-
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
164-
if cached_output is None:
168+
cached = execution_list.get_cache(input_unique_id, unique_id)
169+
if cached is None or cached.outputs is None:
165170
mark_missing()
166171
continue
167-
if output_index >= len(cached_output):
172+
if output_index >= len(cached.outputs):
168173
mark_missing()
169174
continue
170-
obj = cached_output[output_index]
175+
obj = cached.outputs[output_index]
171176
input_data_all[x] = obj
172177
elif input_category is not None:
173178
input_data_all[x] = [input_data]
@@ -407,10 +412,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
407412
cached = caches.outputs.get(unique_id)
408413
if cached is not None:
409414
if server.client_id is not None:
410-
cached_ui = cached[1] or {}
415+
cached_ui = cached.ui or {}
411416
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
412-
if cached[1] is not None:
413-
ui_outputs[unique_id] = cached[1]
417+
if cached.ui is not None:
418+
ui_outputs[unique_id] = cached.ui
414419
get_progress_state().finish_progress(unique_id)
415420
execution_list.cache_update(unique_id, cached)
416421
return (ExecutionResult.SUCCESS, None, None)
@@ -442,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
442447
for r in result:
443448
if is_link(r):
444449
source_node, source_output = r[0], r[1]
445-
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
446-
for o in node_output:
450+
node_cached = execution_list.get_cache(source_node, unique_id)
451+
for o in node_cached.outputs[source_output]:
447452
resolved_output.append(o)
448453

449454
else:
@@ -563,8 +568,9 @@ async def await_completion():
563568
pending_subgraph_results[unique_id] = cached_outputs
564569
return (ExecutionResult.PENDING, None, None)
565570

566-
execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
567-
caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))
571+
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
572+
execution_list.cache_update(unique_id, cache_entry)
573+
caches.outputs.set(unique_id, cache_entry)
568574

569575
except comfy.model_management.InterruptProcessingException as iex:
570576
logging.info("Processing interrupted")

0 commit comments

Comments
 (0)