@@ -9,6 +9,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import type { PartialAppConfig } from 'app/types/invokeai';
 import { useFocusRegionWatcher } from 'common/hooks/focus';
 import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
+import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
 import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
 import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
 import { useReadinessWatcher } from 'features/queue/store/readiness';
@@ -58,6 +59,7 @@ export const GlobalHookIsolator = memo(
     useSyncQueueStatus();
     useFocusRegionWatcher();
     useWorkflowBuilderWatcher();
+    useDynamicPromptsWatcher();
 
     return null;
   }
@@ -22,7 +22,6 @@ import { addImageToDeleteSelectedListener } from 'app/store/middleware/listenerM
 import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
 import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
 import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
-import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
 import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
 import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
 import type { AppDispatch, RootState } from 'app/store/store';
@@ -95,7 +94,4 @@ addAppConfigReceivedListener(startAppListening);
 // Ad-hoc upscale workflwo
 addAdHocPostProcessingRequestedListener(startAppListening);
 
-// Prompts
-addDynamicPromptsListener(startAppListening);
-
 addSetDefaultSettingsListener(startAppListening);

This file was deleted.

@@ -0,0 +1,95 @@
+import { useAppStore } from 'app/store/nanostores/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import {
+  isErrorChanged,
+  isLoadingChanged,
+  parsingErrorChanged,
+  promptsChanged,
+  selectDynamicPromptsMaxPrompts,
+} from 'features/dynamicPrompts/store/dynamicPromptsSlice';
+import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
+import { selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
+import { debounce } from 'lodash-es';
+import { useEffect, useMemo } from 'react';
+import { utilitiesApi } from 'services/api/endpoints/utilities';
+
+const DYNAMIC_PROMPTS_DEBOUNCE_MS = 1000;
+
+/**
+ * This hook watches for changes to state that should trigger dynamic prompts to be updated.
+ */
+export const useDynamicPromptsWatcher = () => {
+  const { getState, dispatch } = useAppStore();
+  // The prompt to process is derived from the preset-modified prompts
+  const presetModifiedPrompts = useAppSelector(selectPresetModifiedPrompts);
+  const maxPrompts = useAppSelector(selectDynamicPromptsMaxPrompts);
+
+  const dynamicPrompting = useFeatureStatus('dynamicPrompting');
+
+  const debouncedUpdateDynamicPrompts = useMemo(
+    () =>
+      debounce(async (positivePrompt: string, maxPrompts: number) => {
+        // Try to fetch the dynamic prompts and store in state
+        try {
+          const req = dispatch(
+            utilitiesApi.endpoints.dynamicPrompts.initiate(
+              {
+                prompt: positivePrompt,
+                max_prompts: maxPrompts,
+              },
+              { subscribe: false }
+            )
+          );
+
+          const res = await req.unwrap();
+
+          dispatch(promptsChanged(res.prompts));
+          dispatch(parsingErrorChanged(res.error));
+          dispatch(isErrorChanged(false));
+        } catch {
+          dispatch(isErrorChanged(true));
+          dispatch(isLoadingChanged(false));
+        }
+      }, DYNAMIC_PROMPTS_DEBOUNCE_MS),
+    [dispatch]
+  );
+
+  useEffect(() => {
+    if (!dynamicPrompting) {
+      return;
+    }
+
+    const { positivePrompt } = presetModifiedPrompts;
+
+    // Before we execute, imperatively check the dynamic prompts query cache to see if we have already fetched this prompt
+    const state = getState();
+
+    const cachedPrompts = utilitiesApi.endpoints.dynamicPrompts.select({
+      prompt: positivePrompt,
+      max_prompts: maxPrompts,
+    })(state).data;
+
+    if (cachedPrompts) {
+      // Yep we already did this prompt, use the cached result
+      dispatch(promptsChanged(cachedPrompts.prompts));
+      dispatch(parsingErrorChanged(cachedPrompts.error));
+      return;
+    }
+
+    // If the prompt is not in the cache, check if we should process it - this is just looking for dynamic prompts syntax
+    if (!getShouldProcessPrompt(positivePrompt)) {
+      dispatch(promptsChanged([positivePrompt]));
+      dispatch(parsingErrorChanged(undefined));
+      dispatch(isErrorChanged(false));
+      return;
+    }
+
+    // If we are here, we need to process the prompt
+    if (!state.dynamicPrompts.isLoading) {
+      dispatch(isLoadingChanged(true));
+    }
+
+    debouncedUpdateDynamicPrompts(positivePrompt, maxPrompts);
+  }, [debouncedUpdateDynamicPrompts, dispatch, dynamicPrompting, getState, maxPrompts, presetModifiedPrompts]);
+};
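Aside (not part of the diff): a minimal sketch of the cache-then-fetch pattern the new hook relies on, for readers less familiar with imperative RTK Query usage. The helper name getDynamicPromptsWithCacheCheck is hypothetical; the endpoint, store types, and options are the ones used in the hook above.

import type { AppDispatch, RootState } from 'app/store/store';
import { utilitiesApi } from 'services/api/endpoints/utilities';

// Hypothetical helper illustrating the pattern: read the RTK Query cache
// imperatively, and only dispatch a fetch on a cache miss.
export const getDynamicPromptsWithCacheCheck = async (
  arg: { prompt: string; max_prompts: number },
  dispatch: AppDispatch,
  getState: () => RootState
) => {
  // select(arg) returns the cached query entry for this exact argument, if any.
  const cached = utilitiesApi.endpoints.dynamicPrompts.select(arg)(getState()).data;
  if (cached) {
    return cached;
  }
  // Cache miss: dispatch an unsubscribed request so the result is cached without
  // keeping a component subscription alive; unwrap() resolves with data or throws.
  const req = dispatch(utilitiesApi.endpoints.dynamicPrompts.initiate(arg, { subscribe: false }));
  return await req.unwrap();
};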
@@ -36,12 +36,6 @@ export const dynamicPromptsSlice = createSlice({
     maxPromptsChanged: (state, action: PayloadAction<number>) => {
       state.maxPrompts = action.payload;
     },
-    maxPromptsReset: (state) => {
-      state.maxPrompts = initialDynamicPromptsState.maxPrompts;
-    },
-    combinatorialToggled: (state) => {
-      state.combinatorial = !state.combinatorial;
-    },
     promptsChanged: (state, action: PayloadAction<string[]>) => {
       state.prompts = action.payload;
       state.isLoading = false;
@@ -63,8 +57,6 @@ export const dynamicPromptsSlice = createSlice({
 
 export const {
   maxPromptsChanged,
-  maxPromptsReset,
-  combinatorialToggled,
   promptsChanged,
   parsingErrorChanged,
   isErrorChanged,
@@ -8,7 +8,7 @@ import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from '
 import { assert } from 'tsafe';
 
 import { addLoRAs } from './generation/addLoRAs';
-import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils';
+import { getBoardField, selectPresetModifiedPrompts } from './graphBuilderUtils';
 
 export const buildMultidiffusionUpscaleGraph = async (
   state: RootState
@@ -97,7 +97,7 @@
 
   if (model.base === 'sdxl') {
     const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
-      getPresetModifiedPrompts(state);
+      selectPresetModifiedPrompts(state);
 
     posCond = g.addNode({
       type: 'sdxl_compel_prompt',
@@ -130,7 +130,7 @@
       negative_style_prompt: negativeStylePrompt,
     });
   } else {
-    const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
+    const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
 
     posCond = g.addNode({
       type: 'compel',
@@ -16,8 +16,8 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import {
   CANVAS_OUTPUT_PREFIX,
   getBoardField,
-  getPresetModifiedPrompts,
   getSizes,
+  selectPresetModifiedPrompts,
 } from 'features/nodes/util/graph/graphBuilderUtils';
 import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
 import type { Invocation } from 'services/api/types';
@@ -45,7 +45,7 @@ export const buildCogView4Graph = async (
   assert(model, 'No model found in state');
 
   const { originalSize, scaledSize } = getSizes(bbox);
-  const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
+  const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
 
   const g = new Graph(getPrefixedId('cogview4_graph'));
   const modelLoader = g.addNode({
@@ -19,8 +19,8 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import {
   CANVAS_OUTPUT_PREFIX,
   getBoardField,
-  getPresetModifiedPrompts,
   getSizes,
+  selectPresetModifiedPrompts,
 } from 'features/nodes/util/graph/graphBuilderUtils';
 import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
 import { t } from 'i18next';
@@ -91,7 +91,7 @@ export const buildFLUXGraph = async (
     guidance = 30;
   }
 
-  const { positivePrompt } = getPresetModifiedPrompts(state);
+  const { positivePrompt } = selectPresetModifiedPrompts(state);
 
   const g = new Graph(getPrefixedId('flux_graph'));
   const modelLoader = g.addNode({
@@ -20,8 +20,8 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import {
   CANVAS_OUTPUT_PREFIX,
   getBoardField,
-  getPresetModifiedPrompts,
   getSizes,
+  selectPresetModifiedPrompts,
 } from 'features/nodes/util/graph/graphBuilderUtils';
 import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
 import { selectMainModelConfig } from 'services/api/endpoints/models';
@@ -62,7 +62,7 @@ export const buildSD1Graph = async (
   assert(model, 'No model found in state');
 
   const fp32 = vaePrecision === 'fp32';
-  const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
+  const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
   const { originalSize, scaledSize } = getSizes(bbox);
 
   const g = new Graph(getPrefixedId('sd1_graph'));
@@ -15,8 +15,8 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import {
   CANVAS_OUTPUT_PREFIX,
   getBoardField,
-  getPresetModifiedPrompts,
   getSizes,
+  selectPresetModifiedPrompts,
 } from 'features/nodes/util/graph/graphBuilderUtils';
 import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
 import { selectMainModelConfig } from 'services/api/endpoints/models';
@@ -56,7 +56,7 @@ export const buildSD3Graph = async (
   } = params;
 
   const { originalSize, scaledSize } = getSizes(bbox);
-  const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
+  const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
 
   const g = new Graph(getPrefixedId('sd3_graph'));
   const modelLoader = g.addNode({
@@ -20,8 +20,8 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import {
   CANVAS_OUTPUT_PREFIX,
   getBoardField,
-  getPresetModifiedPrompts,
   getSizes,
+  selectPresetModifiedPrompts,
 } from 'features/nodes/util/graph/graphBuilderUtils';
 import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
 import { selectMainModelConfig } from 'services/api/endpoints/models';
@@ -67,7 +67,8 @@ export const buildSDXLGraph = async (
 
   const fp32 = vaePrecision === 'fp32';
   const { originalSize, scaledSize } = getSizes(bbox);
-  const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
+  const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
+    selectPresetModifiedPrompts(state);
 
   const g = new Graph(getPrefixedId('sdxl_graph'));
   const modelLoader = g.addNode({