From 97de521c70185623bd266dafbd16d1c1842063bc Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 01:01:37 +0000 Subject: [PATCH 01/14] Add build_line(...) util function. --- invokeai/backend/util/build_line.py | 6 ++++++ tests/backend/util/test_build_line.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 invokeai/backend/util/build_line.py create mode 100644 tests/backend/util/test_build_line.py diff --git a/invokeai/backend/util/build_line.py b/invokeai/backend/util/build_line.py new file mode 100644 index 00000000000..77cf98d8df6 --- /dev/null +++ b/invokeai/backend/util/build_line.py @@ -0,0 +1,6 @@ +from typing import Callable + + +def build_line(x1: float, y1: float, x2: float, y2: float) -> Callable[[float], float]: + """Build a linear function given two points on the line (x1, y1) and (x2, y2).""" + return lambda x: (y2 - y1) / (x2 - x1) * (x - x1) + y1 diff --git a/tests/backend/util/test_build_line.py b/tests/backend/util/test_build_line.py new file mode 100644 index 00000000000..9ed115cfe90 --- /dev/null +++ b/tests/backend/util/test_build_line.py @@ -0,0 +1,19 @@ +import math + +import pytest + +from invokeai.backend.util.build_line import build_line + + +@pytest.mark.parametrize( + ["x1", "y1", "x2", "y2", "x3", "y3"], + [ + (0, 0, 1, 1, 2, 2), # y = x + (0, 1, 1, 2, 2, 3), # y = x + 1 + (0, 0, 1, 2, 2, 4), # y = 2x + (0, 1, 1, 0, 2, -1), # y = -x + 1 + (0, 5, 1, 5, 2, 5), # y = 0 + ], +) +def test_build_line(x1: float, y1: float, x2: float, y2: float, x3: float, y3: float): + assert math.isclose(build_line(x1, y1, x2, y2)(x3), y3, rel_tol=1e-9) From 93c15c9958e106f209e779c698956275ec377cf9 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 02:21:47 +0000 Subject: [PATCH 02/14] Rough draft of TrajectoryGuidanceExtension. --- .../flux/trajectory_guidance_extension.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 invokeai/backend/flux/trajectory_guidance_extension.py diff --git a/invokeai/backend/flux/trajectory_guidance_extension.py b/invokeai/backend/flux/trajectory_guidance_extension.py new file mode 100644 index 00000000000..b126c16f454 --- /dev/null +++ b/invokeai/backend/flux/trajectory_guidance_extension.py @@ -0,0 +1,90 @@ +import torch + +from invokeai.backend.util.build_line import build_line + + +class TrajectoryGuidanceExtension: + def __init__( + self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor | None, trajectory_guidance_strength: float + ): + """Initialize TrajectoryGuidanceExtension. + + Args: + init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format. + inpaint_mask (torch.Tensor | None): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 + will be re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the + inpainted region with the background. In 'packed' format. If None, will be treated as a mask of all 1s. + trajectory_guidance_strength (float): A value in [0, 1] specifying the strength of the trajectory guidance. + A value of 0.0 is equivalent to vanilla image-to-image. A value of 1.0 will guide the denoising process + very close to the original latents. + """ + assert 0.0 <= trajectory_guidance_strength <= 1.0 + self._init_latents = init_latents + self._trajectory_guidance_strength = trajectory_guidance_strength + if inpaint_mask is None: + # The inpaing mask is None, so we initialize a mask with a single value of 1.0. 
+            # This value will be broadcasted and treated as a mask of all 1s.
+            self._inpaint_mask = torch.ones(1, device=init_latents.device, dtype=init_latents.dtype)
+        else:
+            self._inpaint_mask = self._inpaint_mask
+
+    def step(
+        self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float
+    ) -> torch.Tensor:
+        # Handle gradient cutoff.
+        # TODO(ryand): This logic is a bit arbitrary. Think about how to clean it up.
+        timestep_cutoff = 0.5
+        if t_prev > timestep_cutoff:
+            # Early in the denoising process, use the smaller mask.
+            # I.e. treat gradient values as 0.0.
+            mask = self._inpaint_mask.where(self._inpaint_mask >= (1.0 - 1e-3), 0.0)
+        else:
+            # After the cut-off, use the larger mask.
+            # I.e. treat gradient values as 1.0.
+            mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + 1e-3), 1.0)
+            # mask = (self._inpaint_mask > (0.0 + 1e-5)).float()
+
+        # Calculate the change_ratio based on the trajectory guidance strength.
+        change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength)
+        change_ratio_at_cutoff = 1.0
+        t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength)
+        change_ratio = 1.0
+        if t_prev > t_cutoff:
+            # If we are before the cutoff, lineaarly interpolate between the change_ratio at t=1.0 and the change_ratio
+            # at the cutoff.
+            change_ratio = build_line(x1=1.0, y1=change_ratio_at_t_1, x2=t_cutoff, y2=change_ratio_at_cutoff)(t_prev)
+            # change_ratio = change_ratio_at_t_1 + (change_ratio_at_cutoff - change_ratio_at_t_1) * (1.0 - t_prev) / (
+            #     1.0 - t_cutoff
+            # )
+
+        mask = mask * change_ratio
+
+        # NOTE(ryand): During inpainting, it is common to guide the denoising process by noising the initial latents for
+        # the current timestep and then blending the predicted intermediate latents with the noised initial latents.
+        # For example, we could do this here with something like:
+        # ```
+        # noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
+        # return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
+        # ```
+        # Instead of guiding based on the noised initial latents, we have decided to guide based on the noise prediction
+        # that points towards the initial latents. The difference between these guidance strategies is minor, but
+        # qualitatively we found the latter to produce slightly better results. When the guidance strength is 0.0 or 1.0
+        # there is no difference between the two strategies.
+        #
+        # We experimented with a number of related guidance strategies, but not exhaustively. It's entirely possible
+        # that there's a much better way to do this.
+
+        # Calculate noise guidance.
+        # What noise should the model have predicted at this timestep to step towards self._init_latents?
+        # Derivation:
+        # > t_prev_latents = t_curr_latents + (t_prev - t_curr) * pred_noise
+        # > t_0_latents = t_curr_latents + (0 - t_curr) * init_traj_noise
+        # > t_0_latents = t_curr_latents - t_curr * init_traj_noise
+        # > init_traj_noise = (t_curr_latents - t_0_latents) / t_curr
+        init_traj_noise = (t_curr_latents - self._init_latents) / t_curr
+
+        # Blend the init_traj_noise with the pred_noise according to the inpaint mask and the trajectory guidance.
+        noise = pred_noise * mask + init_traj_noise * (1.0 - mask)
+
+        # Take a denoising step.
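+        # (A standard rectified-flow Euler step: t_prev < t_curr, so this moves the latents along the blended
+        # noise direction towards t=0. With noise == init_traj_noise, stepping all the way to t_prev = 0 would
+        # land exactly on self._init_latents per the derivation above, while noise == pred_noise follows the
+        # model's own trajectory.)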
+ return t_curr_latents + (t_prev - t_curr) * noise From e8357afd3a953589ccfa16307517bc6d5f87051b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 02:41:52 +0000 Subject: [PATCH 03/14] Add traj_guidance_strength to FluxDenoiseInvocation. --- invokeai/app/invocations/flux_denoise.py | 21 +++++++++++++-------- invokeai/backend/flux/denoise.py | 16 +++++++++------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 7035d62f365..840a615ce5e 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -20,7 +20,6 @@ from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.denoise import denoise -from invokeai.backend.flux.inpaint_extension import InpaintExtension from invokeai.backend.flux.model import Flux from invokeai.backend.flux.sampling_utils import ( clip_timestep_schedule, @@ -30,6 +29,7 @@ pack, unpack, ) +from invokeai.backend.flux.trajectory_guidance_extension import TrajectoryGuidanceExtension from invokeai.backend.lora.lora_model_raw import LoRAModelRaw from invokeai.backend.lora.lora_patcher import LoRAPatcher from invokeai.backend.model_manager.config import ModelFormat @@ -68,6 +68,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): description=FieldDescriptions.denoising_start, ) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) + trajectory_guidance_strength: float = InputField( + default=0.0, + ge=0.0, + le=1.0, + description="Value indicating how strongly to guide the denoising process towards the initial latents (during image-to-image). Range [0, 1]. A value of 0.0 is equivalent to vanilla image-to-image. A value of 1.0 will guide the denoising process very close to the original latents.", + ) transformer: TransformerField = InputField( description=FieldDescriptions.flux_model, input=Input.Connection, @@ -181,14 +187,13 @@ def _run_diffusion( # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly. assert image_seq_len == x.shape[1] - # Prepare inpaint extension. - inpaint_extension: InpaintExtension | None = None - if inpaint_mask is not None: - assert init_latents is not None - inpaint_extension = InpaintExtension( + # Prepare trajectory guidance extension. 
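+        # Note that, unlike the InpaintExtension it replaces, this extension is attached whenever init_latents
+        # are provided (i.e. for all image-to-image runs), not only when an inpaint mask is present.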
+ traj_guidance_extension: TrajectoryGuidanceExtension | None = None + if init_latents is not None: + traj_guidance_extension = TrajectoryGuidanceExtension( init_latents=init_latents, inpaint_mask=inpaint_mask, - noise=noise, + trajectory_guidance_strength=self.trajectory_guidance_strength, ) with ( @@ -236,7 +241,7 @@ def _run_diffusion( timesteps=timesteps, step_callback=self._build_step_callback(context), guidance=self.guidance, - inpaint_extension=inpaint_extension, + traj_guidance_extension=traj_guidance_extension, ) x = unpack(x.float(), self.height, self.width) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index ec0d238b410..e837dfc7607 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -3,8 +3,8 @@ import torch from tqdm import tqdm -from invokeai.backend.flux.inpaint_extension import InpaintExtension from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.trajectory_guidance_extension import TrajectoryGuidanceExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -20,7 +20,7 @@ def denoise( timesteps: list[float], step_callback: Callable[[PipelineIntermediateState], None], guidance: float, - inpaint_extension: InpaintExtension | None, + traj_guidance_extension: TrajectoryGuidanceExtension | None, # noqa: F821 ): step = 0 # guidance_vec is ignored for schnell. @@ -36,12 +36,14 @@ def denoise( timesteps=t_vec, guidance=guidance_vec, ) - preview_img = img - t_curr * pred - img = img + (t_prev - t_curr) * pred - if inpaint_extension is not None: - img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev) - preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0) + if traj_guidance_extension is not None: + img = traj_guidance_extension.step(t_curr_latents=img, pred_noise=pred, t_curr=t_curr, t_prev=t_prev) + # TODO(ryand): Generate a better preview image. + preview_img = img + else: + preview_img = img - t_curr * pred + img = img + (t_prev - t_curr) * pred step_callback( PipelineIntermediateState( From f0aad5882de0086b1e211d657f16d3bee452abb9 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 14:04:53 +0000 Subject: [PATCH 04/14] Fixup docs in the TrajectoryGuidanceExtension. --- .../flux/trajectory_guidance_extension.py | 46 +++++++++++++++---- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/flux/trajectory_guidance_extension.py b/invokeai/backend/flux/trajectory_guidance_extension.py index b126c16f454..47834170f38 100644 --- a/invokeai/backend/flux/trajectory_guidance_extension.py +++ b/invokeai/backend/flux/trajectory_guidance_extension.py @@ -4,6 +4,35 @@ class TrajectoryGuidanceExtension: + """An implementation of trajectory guidance for FLUX. + + What is trajectory guidance? + ---------------------------- + With SD 1 and SDXL, the amount of change in image-to-image denoising is largely controlled by the denoising_start + parameter. Doing the same thing with the FLUX model does not work as well, because the FLUX model converges very + quickly (roughly time 1.0 to 0.9) to the structure of the final image. The result of this model characteristic is + that you typically get one of two outcomes: + 1) a result that is very similar to the original image + 2) a result that is very different from the original image, as though it was generated from the text prompt with + pure noise. 
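+    In practice, which outcome you get tends to track whether denoising_start falls before or after this early
+    convergence window, which is why denoising_start alone offers only coarse control over the amount of change.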
+ + To address this issue with image-to-image workflows with FLUX, we employ the concept of trajectory guidance. The + idea is that in addition to controlling the denoising_start parameter (i.e. the amount of noise added to the + original image), we can also guide the denoising process to stay close to the trajectory that would reproduce the + original. By controlling the strength of the trajectory guidance throughout the denoising process, we can achieve + FLUX image-to-image behavior with the same level of control offered by SD1 and SDXL. + + What is the trajectory_guidance_strength? + ----------------------------------------- + In the limit, we could apply a different trajectory guidance 'strength' for every latent value in every timestep. + This would be impractical for a user, so instead we have engineered a strength schedule that is more convenient to + use. The `trajectory_guidance_strength` parameter is a single scalar value that maps to a schedule. The engineered + schedule is defined as: + 1) An initial change_ratio at t=1.0. + 2) A linear ramp up to change_ratio=1.0 at t = t_cutoff. + 3) A constant change_ratio=1.0 after t = t_cutoff. + """ + def __init__( self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor | None, trajectory_guidance_strength: float ): @@ -44,32 +73,29 @@ def step( mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + 1e-3), 1.0) # mask = (self._inpaint_mask > (0.0 + 1e-5)).float() - # Calculate the change_ratio based on the trajectory guidance strength. + # Calculate the change_ratio based on the trajectory_guidance_strength. change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength) change_ratio_at_cutoff = 1.0 t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength) change_ratio = 1.0 if t_prev > t_cutoff: - # If we are before the cutoff, lineaarly interpolate between the change_ratio at t=1.0 and the change_ratio + # If we are before the cutoff, linearly interpolate between the change_ratio at t=1.0 and the change_ratio # at the cutoff. change_ratio = build_line(x1=1.0, y1=change_ratio_at_t_1, x2=t_cutoff, y2=change_ratio_at_cutoff)(t_prev) - # change_ratio = change_ratio_at_t_1 + (change_ratio_at_cutoff - change_ratio_at_t_1) * (1.0 - t_prev) / ( - # 1.0 - t_cutoff - # ) mask = mask * change_ratio # NOTE(ryand): During inpainting, it is common to guide the denoising process by noising the initial latents for # the current timestep and then blending the predicted intermediate latents with the noised initial latents. - # For example, we could do this here with something like: + # For example: # ``` - # noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents - # return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask) + # noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents + # return t_prev_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask) # ``` # Instead of guiding based on the noised initial latents, we have decided to guide based on the noise prediction # that points towards the initial latents. The difference between these guidance strategies is minor, but - # qualitatively we found the latter to produce slightly better results. When the guidance strength is 0.0 or 1.0 - # there is no difference between the two strategies. + # qualitatively we found the latter to produce slightly better results. 
When change_ratio is 0.0 or 1.0 there is + # no difference between the two strategies. # # We experimented with a number of related guidance strategies, but not exhaustively. It's entirely possible # that there's a much better way to do this. From b6748fb1e10d2b68f5c1841d5e37a4a1869b740b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 14:15:59 +0000 Subject: [PATCH 05/14] Fix typo --- invokeai/backend/flux/trajectory_guidance_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/flux/trajectory_guidance_extension.py b/invokeai/backend/flux/trajectory_guidance_extension.py index 47834170f38..39e1a3501f1 100644 --- a/invokeai/backend/flux/trajectory_guidance_extension.py +++ b/invokeai/backend/flux/trajectory_guidance_extension.py @@ -55,7 +55,7 @@ def __init__( # This value will be broadcasted and treated as a mask of all 1s. self._inpaint_mask = torch.ones(1, device=init_latents.device, dtype=init_latents.dtype) else: - self._inpaint_mask = self._inpaint_mask + self._inpaint_mask = inpaint_mask def step( self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float From 2f82171dffc81da9f2fada0590bfd80360296a70 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 14:48:06 +0000 Subject: [PATCH 06/14] Tidy up the logic for inpainting mask adjustment in FLUX TrajectoryGuidanceExtension. --- .../flux/trajectory_guidance_extension.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/invokeai/backend/flux/trajectory_guidance_extension.py b/invokeai/backend/flux/trajectory_guidance_extension.py index 39e1a3501f1..9b779e33382 100644 --- a/invokeai/backend/flux/trajectory_guidance_extension.py +++ b/invokeai/backend/flux/trajectory_guidance_extension.py @@ -57,23 +57,32 @@ def __init__( else: self._inpaint_mask = inpaint_mask + def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor: + """Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep.""" + # As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of + # 1.0. This helps to produce more coherent seams around the inpainted region. We experimented with a (small) + # number of promotion strategies (e.g. gradual promotion based on timestep), but found that a simple cutoff + # threshold worked well. + # We use a small epsilon to avoid any potential issues with floating point precision. + eps = 1e-4 + mask_gradient_t_cutoff = 0.5 + if t_prev > mask_gradient_t_cutoff: + # Early in the denoising process, use the inpaint mask as-is. + return self._inpaint_mask + else: + # After the cut-off, promote all non-zero mask values to 1.0. + mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + eps), 1.0) + + return mask + def step( self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float ) -> torch.Tensor: # Handle gradient cutoff. - # TODO(ryand): This logic is a bit arbitrary. Think about how to clean it up. - timestep_cutoff = 0.5 - if t_prev > timestep_cutoff: - # Early in the denoising process, use the smaller mask. - # I.e. treat gradient values as 0.0. - mask = self._inpaint_mask.where(self._inpaint_mask >= (1.0 - 1e-3), 0.0) - else: - # After the cut-off, use the larger mask. - # I.e. treat gradient values as 1.0. 
- mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + 1e-3), 1.0) - # mask = (self._inpaint_mask > (0.0 + 1e-5)).float() + mask = self._apply_mask_gradient_adjustment(t_prev) # Calculate the change_ratio based on the trajectory_guidance_strength. + # These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually. change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength) change_ratio_at_cutoff = 1.0 t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength) From d3d1b49ff263369542a56e33d0780874647c2fc8 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 20 Sep 2024 14:29:40 -0400 Subject: [PATCH 07/14] feat(ui): add optimized denoising toggle to linear UI for FLUX image to image and inpainting --- invokeai/frontend/web/public/locales/en.json | 7 ++++ .../InformationalPopover/constants.ts | 3 +- .../controlLayers/store/paramsSlice.ts | 7 ++++ .../util/graph/generation/buildFLUXGraph.ts | 41 ++++++++++++++++--- .../ParamOptimizedDenoisingToggle.tsx | 35 ++++++++++++++++ .../ImageSettingsAccordion.tsx | 5 ++- .../frontend/web/src/services/api/schema.ts | 10 ++++- 7 files changed, 99 insertions(+), 9 deletions(-) create mode 100644 invokeai/frontend/web/src/features/parameters/components/Advanced/ParamOptimizedDenoisingToggle.tsx diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 4165555e0ae..ba31249801b 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1040,6 +1040,7 @@ "strength": "Strength", "symmetry": "Symmetry", "tileSize": "Tile Size", + "optimizedDenoising": "Optimized Denoising", "type": "Type", "postProcessing": "Post-Processing (Shift + U)", "processImage": "Process Image", @@ -1539,6 +1540,12 @@ "paragraphs": [ "Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout." ] + }, + "optimizedDenoising": { + "heading": "Optimized Denoising", + "paragraphs": [ + "Enable optimized denoising for enhanced image-to-image transformations with Flux models. This setting improves detail and clarity during generation, but may be turned off to preserve more of your original image." 
+ ] } }, "unifiedCanvas": { diff --git a/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts b/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts index fae89007872..93496ee4af8 100644 --- a/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts +++ b/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts @@ -58,7 +58,8 @@ export type Feature = | 'upscaleModel' | 'scale' | 'creativity' - | 'structure'; + | 'structure' + | 'optimizedDenoising'; export type PopoverData = PopoverProps & { image?: string; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 672dac5c9a6..4dd4b963559 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -40,6 +40,7 @@ export type ParamsState = { cfgRescaleMultiplier: ParameterCFGRescaleMultiplier; guidance: ParameterGuidance; img2imgStrength: ParameterStrength; + optimizedDenoisingEnabled: boolean; iterations: number; scheduler: ParameterScheduler; seed: ParameterSeed; @@ -83,6 +84,7 @@ const initialState: ParamsState = { cfgRescaleMultiplier: 0, guidance: 4, img2imgStrength: 0.75, + optimizedDenoisingEnabled: true, iterations: 1, scheduler: 'euler', seed: 0, @@ -141,6 +143,9 @@ export const paramsSlice = createSlice({ setImg2imgStrength: (state, action: PayloadAction) => { state.img2imgStrength = action.payload; }, + setOptimizedDenoisingEnabled: (state, action: PayloadAction) => { + state.optimizedDenoisingEnabled = action.payload; + }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -273,6 +278,7 @@ export const { setScheduler, setSeed, setImg2imgStrength, + setOptimizedDenoisingEnabled, setSeamlessXAxis, setSeamlessYAxis, setShouldRandomizeSeed, @@ -341,6 +347,7 @@ export const selectInfillPatchmatchDownscaleSize = createParamsSelector( ); export const selectInfillColorValue = createParamsSelector((params) => params.infillColorValue); export const selectImg2imgStrength = createParamsSelector((params) => params.img2imgStrength); +export const selectOptimizedDenoisingEnabled = createParamsSelector((params) => params.optimizedDenoisingEnabled); export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt); export const selectNegativePrompt = createParamsSelector((params) => params.negativePrompt); export const selectPositivePrompt2 = createParamsSelector((params) => params.positivePrompt2); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index fc551c49df4..801d6499cb7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -37,7 +37,17 @@ export const buildFLUXGraph = async ( const { originalSize, scaledSize } = getSizes(bbox); - const { model, guidance, seed, steps, fluxVAE, t5EncoderModel, clipEmbedModel, img2imgStrength } = params; + const { + model, + guidance, + seed, + steps, + fluxVAE, + t5EncoderModel, + clipEmbedModel, + img2imgStrength, + optimizedDenoisingEnabled, + } = params; assert(model, 'No model found in state'); assert(t5EncoderModel, 'No T5 Encoder model found in state'); @@ -68,7 +78,8 @@ export const 
buildFLUXGraph = async ( guidance, num_steps: steps, seed, - denoising_start: 0, // denoising_start should be 0 when latents are not provided + trajectory_guidance_strength: 0, + denoising_start: 0, denoising_end: 1, width: scaledSize.width, height: scaledSize.height, @@ -113,6 +124,8 @@ export const buildFLUXGraph = async ( clip_embed_model: clipEmbedModel, }); + const denoisingStart = 1 - img2imgStrength; + if (generationMode === 'txt2img') { canvasOutput = addTextToImage(g, l2i, originalSize, scaledSize); } else if (generationMode === 'img2img') { @@ -125,9 +138,15 @@ export const buildFLUXGraph = async ( originalSize, scaledSize, bbox, - 1 - img2imgStrength, + denoisingStart, false ); + if (optimizedDenoisingEnabled) { + g.updateNode(noise, { + denoising_start: 0, + trajectory_guidance_strength: img2imgStrength, + }); + } } else if (generationMode === 'inpaint') { canvasOutput = await addInpaint( state, @@ -139,9 +158,15 @@ export const buildFLUXGraph = async ( modelLoader, originalSize, scaledSize, - 1 - img2imgStrength, + denoisingStart, false ); + if (optimizedDenoisingEnabled) { + g.updateNode(noise, { + denoising_start: 0, + trajectory_guidance_strength: img2imgStrength, + }); + } } else if (generationMode === 'outpaint') { canvasOutput = await addOutpaint( state, @@ -153,9 +178,15 @@ export const buildFLUXGraph = async ( modelLoader, originalSize, scaledSize, - 1 - img2imgStrength, + denoisingStart, false ); + if (optimizedDenoisingEnabled) { + g.updateNode(noise, { + denoising_start: 0, + trajectory_guidance_strength: img2imgStrength, + }); + } } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamOptimizedDenoisingToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamOptimizedDenoisingToggle.tsx new file mode 100644 index 00000000000..85b69fff4b4 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamOptimizedDenoisingToggle.tsx @@ -0,0 +1,35 @@ +import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { + selectOptimizedDenoisingEnabled, + setOptimizedDenoisingEnabled, +} from 'features/controlLayers/store/paramsSlice'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + +export const ParamOptimizedDenoisingToggle = memo(() => { + const optimizedDenoisingEnabled = useAppSelector(selectOptimizedDenoisingEnabled); + const dispatch = useAppDispatch(); + + const onChange = useCallback( + (event: ChangeEvent) => { + dispatch(setOptimizedDenoisingEnabled(event.target.checked)); + }, + [dispatch] + ); + + const { t } = useTranslation(); + + return ( + + + {t('parameters.optimizedDenoising')} + + + + ); +}); + +ParamOptimizedDenoisingToggle.displayName = 'ParamOptimizedDenoisingToggle'; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx index 10f6f05269c..3022f3f60e9 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx +++ 
b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx @@ -3,8 +3,9 @@ import { Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-a import { EMPTY_ARRAY } from 'app/store/constants'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import { selectIsFLUX, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; +import { ParamOptimizedDenoisingToggle } from 'features/parameters/components/Advanced/ParamOptimizedDenoisingToggle'; import BboxScaledHeight from 'features/parameters/components/Bbox/BboxScaledHeight'; import BboxScaledWidth from 'features/parameters/components/Bbox/BboxScaledWidth'; import BboxScaleMethod from 'features/parameters/components/Bbox/BboxScaleMethod'; @@ -58,6 +59,7 @@ export const ImageSettingsAccordion = memo(() => { id: 'image-settings-advanced', defaultIsOpen: false, }); + const isFLUX = useAppSelector(selectIsFLUX); return ( { + {isFLUX && } diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 59006b21010..e24bf2c0baf 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -5706,7 +5706,7 @@ export type components = { }; /** * FLUX LoRA Collection Loader - * @description Applies a collection of FLUX LoRAs to the provided UNet and CLIP models. + * @description Applies a collection of LoRAs to a FLUX transformer. */ FLUXLoRACollectionLoader: { /** @@ -6371,6 +6371,12 @@ export type components = { * @default 1 */ denoising_end?: number; + /** + * Trajectory Guidance Strength + * @description Value indicating how strongly to guide the denoising process towards the initial latents (during image-to-image). Range [0, 1]. A value of 0.0 is equivalent to vanilla image-to-image. A value of 1.0 will guide the denoising process very close to the original latents. 
+ * @default 0 + */ + trajectory_guidance_strength?: number; /** * Transformer * @description Flux model (Transformer) to load @@ -6478,7 +6484,7 @@ export type components = { * @description Transformer * @default null */ - transformer: components["schemas"]["TransformerField"]; + transformer: components["schemas"]["TransformerField"] | null; /** * type * @default flux_lora_loader_output From 0d0f6a14fa478c8c470f96997c5e26925eb0735b Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 20 Sep 2024 14:56:45 -0400 Subject: [PATCH 08/14] fix(ui): invert trajectory guidance strength value --- .../nodes/util/graph/generation/buildFLUXGraph.ts | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 801d6499cb7..ece20b94d13 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -124,7 +124,7 @@ export const buildFLUXGraph = async ( clip_embed_model: clipEmbedModel, }); - const denoisingStart = 1 - img2imgStrength; + const denoisingValue = 1 - img2imgStrength; if (generationMode === 'txt2img') { canvasOutput = addTextToImage(g, l2i, originalSize, scaledSize); @@ -138,13 +138,13 @@ export const buildFLUXGraph = async ( originalSize, scaledSize, bbox, - denoisingStart, + denoisingValue, false ); if (optimizedDenoisingEnabled) { g.updateNode(noise, { denoising_start: 0, - trajectory_guidance_strength: img2imgStrength, + trajectory_guidance_strength: denoisingValue, }); } } else if (generationMode === 'inpaint') { @@ -158,13 +158,13 @@ export const buildFLUXGraph = async ( modelLoader, originalSize, scaledSize, - denoisingStart, + denoisingValue, false ); if (optimizedDenoisingEnabled) { g.updateNode(noise, { denoising_start: 0, - trajectory_guidance_strength: img2imgStrength, + trajectory_guidance_strength: denoisingValue, }); } } else if (generationMode === 'outpaint') { @@ -178,13 +178,13 @@ export const buildFLUXGraph = async ( modelLoader, originalSize, scaledSize, - denoisingStart, + denoisingValue, false ); if (optimizedDenoisingEnabled) { g.updateNode(noise, { denoising_start: 0, - trajectory_guidance_strength: img2imgStrength, + trajectory_guidance_strength: denoisingValue, }); } } From e50f71ec53ebd35798ca90e3e14d6d084cb6c546 Mon Sep 17 00:00:00 2001 From: maryhipp Date: Fri, 20 Sep 2024 15:41:44 -0400 Subject: [PATCH 09/14] bump version of flux_denoise node, update default workflows --- invokeai/app/invocations/flux_denoise.py | 2 +- .../FLUX Image to Image.json | 187 ++++++++++-------- .../default_workflows/Flux Text to Image.json | 80 +++++--- 3 files changed, 156 insertions(+), 113 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 840a615ce5e..77e02fbcd30 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -43,7 +43,7 @@ title="FLUX Denoise", tags=["image", "flux"], category="image", - version="2.0.0", + version="2.1.0", classification=Classification.Prototype, ) class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/services/workflow_records/default_workflows/FLUX Image to Image.json b/invokeai/app/services/workflow_records/default_workflows/FLUX Image to Image.json index ba0b00ddae6..c6b604473b5 100644 --- 
a/invokeai/app/services/workflow_records/default_workflows/FLUX Image to Image.json +++ b/invokeai/app/services/workflow_records/default_workflows/FLUX Image to Image.json @@ -2,7 +2,7 @@ "name": "FLUX Image to Image", "author": "InvokeAI", "description": "A simple image-to-image workflow using a FLUX dev model. ", - "version": "1.0.4", + "version": "1.1.0", "contact": "", "tags": "image2image, flux, image-to-image", "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.", @@ -23,67 +23,33 @@ "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "fieldName": "vae_model" }, - { - "nodeId": "ace0258f-67d7-4eee-a218-6fff27065214", - "fieldName": "denoising_start" - }, { "nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "fieldName": "prompt" }, { - "nodeId": "ace0258f-67d7-4eee-a218-6fff27065214", - "fieldName": "num_steps" + "nodeId": "2981a67c-480f-4237-9384-26b68dbf912b", + "fieldName": "image" } ], "meta": { "version": "3.0.0", - "category": "default" + "category": "user" }, "nodes": [ { - "id": "2981a67c-480f-4237-9384-26b68dbf912b", - "type": "invocation", - "data": { - "id": "2981a67c-480f-4237-9384-26b68dbf912b", - "type": "flux_vae_encode", - "version": "1.0.0", - "label": "", - "notes": "", - "isOpen": true, - "isIntermediate": true, - "useCache": true, - "inputs": { - "image": { - "name": "image", - "label": "", - "value": { - "image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png" - } - }, - "vae": { - "name": "vae", - "label": "" - } - } - }, - "position": { - "x": 732.7680166609682, - "y": -24.37398171806909 - } - }, - { - "id": "ace0258f-67d7-4eee-a218-6fff27065214", + "id": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", "type": "invocation", "data": { - "id": "ace0258f-67d7-4eee-a218-6fff27065214", + "id": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", "type": "flux_denoise", - "version": "1.0.0", + "version": "2.0.0", "label": "", "notes": "", "isOpen": true, "isIntermediate": true, "useCache": true, + "nodePack": "invokeai", "inputs": { "board": { "name": "board", @@ -111,6 +77,11 @@ "label": "", "value": 1 }, + "trajectory_guidance_strength": { + "name": "trajectory_guidance_strength", + "label": "", + "value": 0.0 + }, "transformer": { "name": "transformer", "label": "" @@ -131,7 +102,7 @@ }, "num_steps": { "name": "num_steps", - "label": "Steps (Recommend 30 for Dev, 4 for Schnell)", + "label": "", "value": 30 }, "guidance": { @@ -147,8 +118,36 @@ } }, "position": { - "x": 1182.8836633018684, - "y": -251.38882958913183 + "x": 1159.584057771928, + "y": -175.90561201366845 + } + }, + { + "id": "2981a67c-480f-4237-9384-26b68dbf912b", + "type": "invocation", + "data": { + "id": "2981a67c-480f-4237-9384-26b68dbf912b", + "type": "flux_vae_encode", + "version": "1.0.0", + "label": "", + "notes": "", + "isOpen": true, + "isIntermediate": true, + "useCache": true, + "inputs": { + "image": { + "name": "image", + "label": "" + }, + "vae": { + "name": "vae", + "label": "" + } + } + }, + "position": { + "x": 732.7680166609682, + "y": -24.37398171806909 } }, { @@ -202,18 +201,32 @@ "inputs": { "model": { "name": "model", - "label": "Model (dev variant recommended for Image-to-Image)" + "label": "Model (dev variant recommended for Image-to-Image)", + "value": { + "key": "b4990a6c-0899-48e9-969b-d6f3801acc6a", + "hash": 
"random:aad8f7bc19ce76541dfb394b62a30f77722542b66e48064a9f25453263b45fba", + "name": "FLUX Dev (Quantized)_2", + "base": "flux", + "type": "main" + } }, "t5_encoder_model": { "name": "t5_encoder_model", - "label": "" + "label": "", + "value": { + "key": "d18d5575-96b6-4da3-b3d8-eb58308d6705", + "hash": "random:f2f9ed74acdfb4bf6fec200e780f6c25f8dd8764a35e65d425d606912fdf573a", + "name": "t5_bnb_int8_quantized_encoder", + "base": "any", + "type": "t5_encoder" + } }, "clip_embed_model": { "name": "clip_embed_model", "label": "", "value": { - "key": "fa23a584-b623-415d-832a-21b5098ff1a1", - "hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70", + "key": "5a19d7e5-8d98-43cd-8a81-87515e4b3b4e", + "hash": "random:4bd08514c08fb6ff04088db9aeb45def3c488e8b5fd09a35f2cc4f2dc346f99f", "name": "clip-vit-large-patch14", "base": "any", "type": "clip_embed" @@ -223,8 +236,8 @@ "name": "vae_model", "label": "", "value": { - "key": "74fc82ba-c0a8-479d-a890-2126f82da758", - "hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068", + "key": "9172beab-5c1d-43f0-b2f0-6e0b956710d9", + "hash": "random:c54dde288e5fa2e6137f1c92e9d611f598049e6f16e360207b6d96c9f5a67ba0", "name": "FLUX.1-schnell_ae", "base": "flux", "type": "vae" @@ -308,68 +321,68 @@ ], "edges": [ { - "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height", + "id": "reactflow__edge-eebd7252-0bd8-401a-bb26-2b8bc64892falatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents", "type": "default", - "source": "2981a67c-480f-4237-9384-26b68dbf912b", - "target": "ace0258f-67d7-4eee-a218-6fff27065214", - "sourceHandle": "height", - "targetHandle": "height" + "source": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", + "target": "7e5172eb-48c1-44db-a770-8fd83e1435d1", + "sourceHandle": "latents", + "targetHandle": "latents" }, { - "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width", + "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-eebd7252-0bd8-401a-bb26-2b8bc64892fatransformer", "type": "default", - "source": "2981a67c-480f-4237-9384-26b68dbf912b", - "target": "ace0258f-67d7-4eee-a218-6fff27065214", - "sourceHandle": "width", - "targetHandle": "width" + "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", + "target": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", + "sourceHandle": "transformer", + "targetHandle": "transformer" + }, + { + "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-eebd7252-0bd8-401a-bb26-2b8bc64892fapositive_text_conditioning", + "type": "default", + "source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", + "target": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", + "sourceHandle": "conditioning", + "targetHandle": "positive_text_conditioning" }, { - "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents", + "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-eebd7252-0bd8-401a-bb26-2b8bc64892falatents", "type": "default", "source": "2981a67c-480f-4237-9384-26b68dbf912b", - "target": "ace0258f-67d7-4eee-a218-6fff27065214", + "target": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", "sourceHandle": "latents", "targetHandle": "latents" }, { - "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae", + "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-eebd7252-0bd8-401a-bb26-2b8bc64892fawidth", "type": "default", - "source": 
"f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", - "target": "2981a67c-480f-4237-9384-26b68dbf912b", - "sourceHandle": "vae", - "targetHandle": "vae" + "source": "2981a67c-480f-4237-9384-26b68dbf912b", + "target": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", + "sourceHandle": "width", + "targetHandle": "width" }, { - "id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents", + "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-eebd7252-0bd8-401a-bb26-2b8bc64892faheight", "type": "default", - "source": "ace0258f-67d7-4eee-a218-6fff27065214", - "target": "7e5172eb-48c1-44db-a770-8fd83e1435d1", - "sourceHandle": "latents", - "targetHandle": "latents" + "source": "2981a67c-480f-4237-9384-26b68dbf912b", + "target": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", + "sourceHandle": "height", + "targetHandle": "height" }, { - "id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed", + "id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-eebd7252-0bd8-401a-bb26-2b8bc64892faseed", "type": "default", "source": "4754c534-a5f3-4ad0-9382-7887985e668c", - "target": "ace0258f-67d7-4eee-a218-6fff27065214", + "target": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", "sourceHandle": "value", "targetHandle": "seed" }, { - "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer", + "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae", "type": "default", "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", - "target": "ace0258f-67d7-4eee-a218-6fff27065214", - "sourceHandle": "transformer", - "targetHandle": "transformer" - }, - { - "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning", - "type": "default", - "source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", - "target": "ace0258f-67d7-4eee-a218-6fff27065214", - "sourceHandle": "conditioning", - "targetHandle": "positive_text_conditioning" + "target": "2981a67c-480f-4237-9384-26b68dbf912b", + "sourceHandle": "vae", + "targetHandle": "vae" }, { "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae", diff --git a/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json b/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json index a4bd2ce0632..b30f4deb96c 100644 --- a/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json +++ b/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json @@ -2,7 +2,7 @@ "name": "FLUX Text to Image", "author": "InvokeAI", "description": "A simple text-to-image workflow using FLUX dev or schnell models.", - "version": "1.0.4", + "version": "1.0.5", "contact": "", "tags": "text2image, flux", "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. 
We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.", @@ -26,29 +26,26 @@ { "nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "fieldName": "prompt" - }, - { - "nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd", - "fieldName": "num_steps" } ], "meta": { "version": "3.0.0", - "category": "default" + "category": "user" }, "nodes": [ { - "id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd", + "id": "4ecda92d-ee0e-45ca-aa35-6e9410ac39b9", "type": "invocation", "data": { - "id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd", + "id": "4ecda92d-ee0e-45ca-aa35-6e9410ac39b9", "type": "flux_denoise", - "version": "1.0.0", + "version": "2.1.0", "label": "", "notes": "", "isOpen": true, "isIntermediate": true, "useCache": true, + "nodePack": "invokeai", "inputs": { "board": { "name": "board", @@ -76,6 +73,11 @@ "label": "", "value": 1 }, + "trajectory_guidance_strength": { + "name": "trajectory_guidance_strength", + "label": "", + "value": 0 + }, "transformer": { "name": "transformer", "label": "" @@ -96,8 +98,8 @@ }, "num_steps": { "name": "num_steps", - "label": "Steps (Recommend 30 for Dev, 4 for Schnell)", - "value": 30 + "label": "", + "value": 4 }, "guidance": { "name": "guidance", @@ -112,8 +114,8 @@ } }, "position": { - "x": 1186.1868226120378, - "y": -214.9459927686657 + "x": 1161.0101524413685, + "y": -223.33548695623742 } }, { @@ -167,19 +169,47 @@ "inputs": { "model": { "name": "model", - "label": "" + "label": "", + "value": { + "key": "b4990a6c-0899-48e9-969b-d6f3801acc6a", + "hash": "random:aad8f7bc19ce76541dfb394b62a30f77722542b66e48064a9f25453263b45fba", + "name": "FLUX Dev (Quantized)_2", + "base": "flux", + "type": "main" + } }, "t5_encoder_model": { "name": "t5_encoder_model", - "label": "" + "label": "", + "value": { + "key": "d18d5575-96b6-4da3-b3d8-eb58308d6705", + "hash": "random:f2f9ed74acdfb4bf6fec200e780f6c25f8dd8764a35e65d425d606912fdf573a", + "name": "t5_bnb_int8_quantized_encoder", + "base": "any", + "type": "t5_encoder" + } }, "clip_embed_model": { "name": "clip_embed_model", - "label": "" + "label": "", + "value": { + "key": "5a19d7e5-8d98-43cd-8a81-87515e4b3b4e", + "hash": "random:4bd08514c08fb6ff04088db9aeb45def3c488e8b5fd09a35f2cc4f2dc346f99f", + "name": "clip-vit-large-patch14", + "base": "any", + "type": "clip_embed" + } }, "vae_model": { "name": "vae_model", - "label": "" + "label": "", + "value": { + "key": "9172beab-5c1d-43f0-b2f0-6e0b956710d9", + "hash": "random:c54dde288e5fa2e6137f1c92e9d611f598049e6f16e360207b6d96c9f5a67ba0", + "name": "FLUX.1-schnell_ae", + "base": "flux", + "type": "vae" + } } } }, @@ -259,33 +289,33 @@ ], "edges": [ { - "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer", + "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4ecda92d-ee0e-45ca-aa35-6e9410ac39b9transformer", "type": "default", "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", - "target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd", + "target": "4ecda92d-ee0e-45ca-aa35-6e9410ac39b9", "sourceHandle": "transformer", "targetHandle": "transformer" }, { - "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning", + "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4ecda92d-ee0e-45ca-aa35-6e9410ac39b9positive_text_conditioning", "type": "default", "source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", - "target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd", + "target": 
"4ecda92d-ee0e-45ca-aa35-6e9410ac39b9", "sourceHandle": "conditioning", "targetHandle": "positive_text_conditioning" }, { - "id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed", + "id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4ecda92d-ee0e-45ca-aa35-6e9410ac39b9seed", "type": "default", "source": "4754c534-a5f3-4ad0-9382-7887985e668c", - "target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd", + "target": "4ecda92d-ee0e-45ca-aa35-6e9410ac39b9", "sourceHandle": "value", "targetHandle": "seed" }, { - "id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents", + "id": "reactflow__edge-4ecda92d-ee0e-45ca-aa35-6e9410ac39b9latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents", "type": "default", - "source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd", + "source": "4ecda92d-ee0e-45ca-aa35-6e9410ac39b9", "target": "7e5172eb-48c1-44db-a770-8fd83e1435d1", "sourceHandle": "latents", "targetHandle": "latents" From 98af514484449ebb80acef333132081b03adb3f9 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 20 Sep 2024 15:45:57 -0400 Subject: [PATCH 10/14] reband to Optimized Inpainting and only apply to inpainting graphs --- invokeai/frontend/web/public/locales/en.json | 6 +++--- .../nodes/util/graph/generation/buildFLUXGraph.ts | 12 ------------ .../Advanced/ParamOptimizedDenoisingToggle.tsx | 2 +- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index ba31249801b..e5d09e2d70a 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1040,7 +1040,7 @@ "strength": "Strength", "symmetry": "Symmetry", "tileSize": "Tile Size", - "optimizedDenoising": "Optimized Denoising", + "optimizedInpainting": "Optimized Inpainting", "type": "Type", "postProcessing": "Post-Processing (Shift + U)", "processImage": "Process Image", @@ -1542,9 +1542,9 @@ ] }, "optimizedDenoising": { - "heading": "Optimized Denoising", + "heading": "Optimized Inpainting", "paragraphs": [ - "Enable optimized denoising for enhanced image-to-image transformations with Flux models. This setting improves detail and clarity during generation, but may be turned off to preserve more of your original image." + "Enable optimized denoising for enhanced inpainting transformations with Flux models. This setting improves detail and clarity during generation, but may be turned off to preserve more of your original image." 
] } }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index ece20b94d13..097e6ff177a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -141,12 +141,6 @@ export const buildFLUXGraph = async ( denoisingValue, false ); - if (optimizedDenoisingEnabled) { - g.updateNode(noise, { - denoising_start: 0, - trajectory_guidance_strength: denoisingValue, - }); - } } else if (generationMode === 'inpaint') { canvasOutput = await addInpaint( state, @@ -181,12 +175,6 @@ export const buildFLUXGraph = async ( denoisingValue, false ); - if (optimizedDenoisingEnabled) { - g.updateNode(noise, { - denoising_start: 0, - trajectory_guidance_strength: denoisingValue, - }); - } } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamOptimizedDenoisingToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamOptimizedDenoisingToggle.tsx index 85b69fff4b4..d7bb3ddfd17 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamOptimizedDenoisingToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamOptimizedDenoisingToggle.tsx @@ -25,7 +25,7 @@ export const ParamOptimizedDenoisingToggle = memo(() => { return ( - {t('parameters.optimizedDenoising')} + {t('parameters.optimizedInpainting')} From a4a0cc6d108aa120c86e638b3b953bc0852e6b29 Mon Sep 17 00:00:00 2001 From: maryhipp Date: Fri, 20 Sep 2024 15:48:31 -0400 Subject: [PATCH 11/14] more default workflow updates --- .../default_workflows/FLUX Image to Image.json | 4 ++-- .../default_workflows/Flux Text to Image.json | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/workflow_records/default_workflows/FLUX Image to Image.json b/invokeai/app/services/workflow_records/default_workflows/FLUX Image to Image.json index c6b604473b5..27fb448ba81 100644 --- a/invokeai/app/services/workflow_records/default_workflows/FLUX Image to Image.json +++ b/invokeai/app/services/workflow_records/default_workflows/FLUX Image to Image.json @@ -34,7 +34,7 @@ ], "meta": { "version": "3.0.0", - "category": "user" + "category": "default" }, "nodes": [ { @@ -43,7 +43,7 @@ "data": { "id": "eebd7252-0bd8-401a-bb26-2b8bc64892fa", "type": "flux_denoise", - "version": "2.0.0", + "version": "2.1.0", "label": "", "notes": "", "isOpen": true, diff --git a/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json b/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json index b30f4deb96c..b62332ae3b5 100644 --- a/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json +++ b/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json @@ -2,7 +2,7 @@ "name": "FLUX Text to Image", "author": "InvokeAI", "description": "A simple text-to-image workflow using FLUX dev or schnell models.", - "version": "1.0.5", + "version": "1.1.0", "contact": "", "tags": "text2image, flux", "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. 
We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.", @@ -30,7 +30,7 @@ ], "meta": { "version": "3.0.0", - "category": "user" + "category": "default" }, "nodes": [ { From 16ca540ecedbf857160dc7fab52f7bdeccf3a578 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 20:18:06 +0000 Subject: [PATCH 12/14] Pre-compute trajectory guidance schedule params rather than calculating on each step. --- .../flux/trajectory_guidance_extension.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/invokeai/backend/flux/trajectory_guidance_extension.py b/invokeai/backend/flux/trajectory_guidance_extension.py index 9b779e33382..b6329a1b6bc 100644 --- a/invokeai/backend/flux/trajectory_guidance_extension.py +++ b/invokeai/backend/flux/trajectory_guidance_extension.py @@ -49,7 +49,6 @@ def __init__( """ assert 0.0 <= trajectory_guidance_strength <= 1.0 self._init_latents = init_latents - self._trajectory_guidance_strength = trajectory_guidance_strength if inpaint_mask is None: # The inpaing mask is None, so we initialize a mask with a single value of 1.0. # This value will be broadcasted and treated as a mask of all 1s. @@ -57,6 +56,13 @@ def __init__( else: self._inpaint_mask = inpaint_mask + # Calculate the params that define the trajectory guidance schedule. + # These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually. + self._trajectory_guidance_strength = trajectory_guidance_strength + self._change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength) + self._change_ratio_at_cutoff = 1.0 + self._t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength) + def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor: """Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep.""" # As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of @@ -75,24 +81,25 @@ def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor: return mask + def _get_change_ratio(self, t_prev: float) -> float: + """Get the change_ratio for t_prev based on the change schedule.""" + change_ratio = 1.0 + if t_prev > self._t_cutoff: + # If we are before the cutoff, linearly interpolate between the change_ratio at t=1.0 and the change_ratio + # at the cutoff. + change_ratio = build_line( + x1=1.0, y1=self._change_ratio_at_t_1, x2=self._t_cutoff, y2=self._change_ratio_at_cutoff + )(t_prev) + + return change_ratio + def step( self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float ) -> torch.Tensor: # Handle gradient cutoff. mask = self._apply_mask_gradient_adjustment(t_prev) - # Calculate the change_ratio based on the trajectory_guidance_strength. - # These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually. - change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength) - change_ratio_at_cutoff = 1.0 - t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength) - change_ratio = 1.0 - if t_prev > t_cutoff: - # If we are before the cutoff, linearly interpolate between the change_ratio at t=1.0 and the change_ratio - # at the cutoff. 
- change_ratio = build_line(x1=1.0, y1=change_ratio_at_t_1, x2=t_cutoff, y2=change_ratio_at_cutoff)(t_prev) - - mask = mask * change_ratio + mask = mask * self._get_change_ratio(t_prev) # NOTE(ryand): During inpainting, it is common to guide the denoising process by noising the initial latents for # the current timestep and then blending the predicted intermediate latents with the noised initial latents. From cd3a7bdb5ee3a37265e7c426bd9fbc88d95c5863 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 20:34:49 +0000 Subject: [PATCH 13/14] Assert that change_ratio is in the expected range in TrajectoryGuidanceExtension. --- invokeai/backend/flux/trajectory_guidance_extension.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/invokeai/backend/flux/trajectory_guidance_extension.py b/invokeai/backend/flux/trajectory_guidance_extension.py index b6329a1b6bc..333dea21944 100644 --- a/invokeai/backend/flux/trajectory_guidance_extension.py +++ b/invokeai/backend/flux/trajectory_guidance_extension.py @@ -91,6 +91,9 @@ def _get_change_ratio(self, t_prev: float) -> float: x1=1.0, y1=self._change_ratio_at_t_1, x2=self._t_cutoff, y2=self._change_ratio_at_cutoff )(t_prev) + # The change_ratio should be in the range [0, 1]. Assert that we didn't make any mistakes. + eps = 1e-5 + assert 0.0 - eps <= change_ratio <= 1.0 + eps return change_ratio def step( From a43a045b04bf8c01a1108cef6187593adcca201c Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 20 Sep 2024 21:08:41 +0000 Subject: [PATCH 14/14] Fix preview image to work well with FLUX trajectory guidance. --- invokeai/backend/flux/denoise.py | 12 ++++++------ .../backend/flux/trajectory_guidance_extension.py | 5 ++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index e837dfc7607..72f87e2aefc 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -38,12 +38,12 @@ def denoise( ) if traj_guidance_extension is not None: - img = traj_guidance_extension.step(t_curr_latents=img, pred_noise=pred, t_curr=t_curr, t_prev=t_prev) - # TODO(ryand): Generate a better preview image. - preview_img = img - else: - preview_img = img - t_curr * pred - img = img + (t_prev - t_curr) * pred + pred = traj_guidance_extension.update_noise( + t_curr_latents=img, pred_noise=pred, t_curr=t_curr, t_prev=t_prev + ) + + preview_img = img - t_curr * pred + img = img + (t_prev - t_curr) * pred step_callback( PipelineIntermediateState( diff --git a/invokeai/backend/flux/trajectory_guidance_extension.py b/invokeai/backend/flux/trajectory_guidance_extension.py index 333dea21944..81bd0db6c63 100644 --- a/invokeai/backend/flux/trajectory_guidance_extension.py +++ b/invokeai/backend/flux/trajectory_guidance_extension.py @@ -96,7 +96,7 @@ def _get_change_ratio(self, t_prev: float) -> float: assert 0.0 - eps <= change_ratio <= 1.0 + eps return change_ratio - def step( + def update_noise( self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float ) -> torch.Tensor: # Handle gradient cutoff. @@ -131,5 +131,4 @@ def step( # Blend the init_traj_noise with the pred_noise according to the inpaint mask and the trajectory guidance. noise = pred_noise * mask + init_traj_noise * (1.0 - mask) - # Take a denoising step. - return t_curr_latents + (t_prev - t_curr) * noise + return noise