3232from torchvision import transforms
3333from PIL import Image , ImageDraw , ImageFont
3434import warnings
35+ import importlib
3536
3637import comfy .utils
3738import comfy .sd
4849from comfy_extras .nodes_custom_sampler import SamplerCustomAdvanced
4950from comfy_extras .nodes_flux import FluxGuidance
5051
51- from nodes import EmptyLatentImage
5252from nodes import CLIPTextEncode
5353
# --- IMPORT KIJAI (THE GOAT) SAGE ATTENTION UNTIL COMFY CORE IMPLEMENTS A NODE
# Optional integration: both ComfyUI-KJNodes and the `sageattention` package
# must import cleanly before the Sage Attention patch is offered. When either
# is missing we fall back to a single 'disabled' mode so the combo widget
# still renders and execute() can skip patching.
SAGE_ATTENTION_INSTALLED = False
KJNODES_INSTALLED = False
PatchSageAttention = None       # KJNodes patcher node instance, set below when available
sageattn_modes = ['disabled']   # combo options for the 'sage_attention' input

try:
    # NOTE(review): this dotted path only resolves inside a running ComfyUI,
    # where custom_nodes is importable — it fails in plain Python. Confirm
    # against the ComfyUI custom-node loading behavior in use.
    kj_optimization_nodes = importlib.import_module(
        'custom_nodes.ComfyUI-KJNodes.nodes.model_optimization_nodes'
    )
    KJNODES_INSTALLED = True
    print('\t - 🟢 KJ Nodes available.')
except Exception:  # was a bare except: — never swallow SystemExit/KeyboardInterrupt
    print('\t - 🚨 KJNODES NOT AVAILABLE')

try:
    importlib.import_module("sageattention")
    SAGE_ATTENTION_INSTALLED = True
    print('\t - 🟢 Sage Attention available.')
except Exception:  # was a bare except: — see above
    print('\t - 🚨 SAGE ATTENTION NOT AVAILABLE')


if SAGE_ATTENTION_INSTALLED and KJNODES_INSTALLED:
    print('\t - ✅ KJ Nodes & Sage Attention available.')
    print('\t - 🥳🎉 Activating Sage Attention for 🌊🚒 FlowState Flux Engine.')
    # 'PathchSageAttentionKJ' is the attribute name as spelled upstream in
    # KJNodes (typo included) — do not "correct" it here.
    PatchSageAttention = kj_optimization_nodes.PathchSageAttentionKJ()
    sageattn_modes = kj_optimization_nodes.sageattn_modes
5480
55- warnings .filterwarnings ('ignore' , message = 'clean_up_tokenization_spaces' )
56- warnings .filterwarnings ('ignore' , message = 'Torch was not compiled with flash attention' )
57- warnings .filterwarnings ('ignore' , category = FutureWarning )
5881
5982
6083##
@@ -85,6 +108,7 @@ def INPUT_TYPES(s):
85108 'required' : {
86109 'model_name' : TYPE_DIFFUSION_MODELS_LIST (),
87110 'weight_dtype' : TYPE_WEIGHT_DTYPE ,
111+ 'sage_attention' : (sageattn_modes , ),
88112 'clip_1_name' : TYPE_CLIPS_LIST (),
89113 'clip_2_name' : TYPE_CLIPS_LIST (),
90114 'vae_name' : TYPE_VAES_LIST (),
@@ -304,8 +328,8 @@ def sample(self, sampler_components):
304328 return img_batch_out , latent_batch_out
305329
306330 def execute (
307- self , model_name , weight_dtype , clip_1_name , clip_2_name , vae_name , resolution , orientation , latent_type ,
308- custom_width , custom_height , custom_batch_size , image , seed , sampling_algorithm , scheduling_algorithm ,
331+ self , model_name , weight_dtype , sage_attention , clip_1_name , clip_2_name , vae_name , resolution , orientation ,
332+ latent_type , custom_width , custom_height , custom_batch_size , image , seed , sampling_algorithm , scheduling_algorithm ,
309333 guidance , steps , denoise , prompt , input_img = None
310334 ):
311335
@@ -321,7 +345,12 @@ def execute(
321345 f'\n - Loading { clip_2_name } ...'
322346 f'\n - Loading { vae_name } ...\n '
323347 )
324- self .loaded_model = UNETLoader ().load_unet (model_name , weight_dtype )[0 ]
348+ if sage_attention != 'disabled' :
349+ self .loaded_model = UNETLoader ().load_unet (model_name , weight_dtype )[0 ]
350+ self .loaded_model = PatchSageAttention .patch (self .loaded_model , sage_attention )[0 ]
351+ else :
352+ self .loaded_model = UNETLoader ().load_unet (model_name , weight_dtype )[0 ]
353+
325354 self .loaded_clip = DualCLIPLoader ().load_clip (clip_1_name , clip_2_name , 'flux' , 'default' )[0 ]
326355 self .loaded_vae = VAELoader ().load_vae (vae_name )[0 ]
327356 else :
0 commit comments