Skip to content

Commit 450a486

Browse files
refactor(ui): make layer adjustments schemas/types composable
1 parent 3380f35 commit 450a486

File tree

4 files changed

+55
-71
lines changed

4 files changed

+55
-71
lines changed

invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsEditor.tsx

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,19 @@ import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityA
55
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
66
import { rasterLayerAdjustmentsCurvesUpdated } from 'features/controlLayers/store/canvasSlice';
77
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
8+
import type { ChannelName, ChannelPoints } from 'features/controlLayers/store/types';
89
import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
910
import { useTranslation } from 'react-i18next';
1011
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
1112

12-
const DEFAULT_POINTS: Array<[number, number]> = [
13+
const DEFAULT_POINTS: ChannelPoints = [
1314
[0, 0],
1415
[255, 255],
1516
];
1617

17-
type Channel = 'master' | 'r' | 'g' | 'b';
18+
type ChannelHistograms = Record<ChannelName, number[] | null>;
1819

19-
type ChannelHistograms = Record<Channel, number[] | null>;
20-
21-
const channelColor: Record<Channel, string> = {
20+
const channelColor: Record<ChannelName, string> = {
2221
master: '#888',
2322
r: '#e53e3e',
2423
g: '#38a169',
@@ -27,7 +26,7 @@ const channelColor: Record<Channel, string> = {
2726

2827
const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v);
2928

30-
const sortPoints = (pts: Array<[number, number]>) =>
29+
const sortPoints = (pts: ChannelPoints) =>
3130
[...pts]
3231
.sort((a, b) => {
3332
const xDiff = a[0] - b[0];
@@ -63,17 +62,17 @@ const CANVAS_STYLE: React.CSSProperties = {
6362

6463
type CurveGraphProps = {
6564
title: string;
66-
channel: Channel;
67-
points: Array<[number, number]> | undefined;
65+
channel: ChannelName;
66+
points: ChannelPoints | undefined;
6867
histogram: number[] | null;
69-
onChange: (pts: Array<[number, number]>) => void;
68+
onChange: (pts: ChannelPoints) => void;
7069
};
7170

7271
const drawHistogram = (
7372
c: HTMLCanvasElement,
74-
channel: Channel,
73+
channel: ChannelName,
7574
histogram: number[] | null,
76-
points: Array<[number, number]>
75+
points: ChannelPoints
7776
) => {
7877
// Use device pixel ratio for crisp rendering on HiDPI displays.
7978
const dpr = window.devicePixelRatio || 1;
@@ -207,7 +206,7 @@ const drawHistogram = (
207206
}
208207
};
209208

210-
const getNearestPointIndex = (c: HTMLCanvasElement, points: Array<[number, number]>, mx: number, my: number) => {
209+
const getNearestPointIndex = (c: HTMLCanvasElement, points: ChannelPoints, mx: number, my: number) => {
211210
const cssWidth = c.clientWidth || CANVAS_WIDTH;
212211
const cssHeight = c.clientHeight || CANVAS_HEIGHT;
213212
const innerWidth = cssWidth - MARGIN_LEFT - MARGIN_RIGHT;
@@ -249,7 +248,7 @@ const canvasYToValueY = (c: HTMLCanvasElement, cy: number) => {
249248
const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) {
250249
const { title, channel, points, histogram, onChange } = props;
251250
const canvasRef = useRef<HTMLCanvasElement | null>(null);
252-
const [localPoints, setLocalPoints] = useState<Array<[number, number]>>(sortPoints(points ?? DEFAULT_POINTS));
251+
const [localPoints, setLocalPoints] = useState<ChannelPoints>(sortPoints(points ?? DEFAULT_POINTS));
253252
const [dragIndex, setDragIndex] = useState<number | null>(null);
254253

255254
useEffect(() => {
@@ -333,7 +332,7 @@ const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) {
333332
);
334333

335334
const commit = useCallback(
336-
(pts: Array<[number, number]>) => {
335+
(pts: ChannelPoints) => {
337336
onChange(sortPoints(pts));
338337
},
339338
[onChange]
@@ -534,17 +533,17 @@ export const RasterLayerCurvesAdjustmentsEditor = memo(() => {
534533
}, [layer?.objects, layer?.adjustments, recalcHistogram]);
535534

536535
const onChangePoints = useCallback(
537-
(channel: Channel, pts: Array<[number, number]>) => {
536+
(channel: ChannelName, pts: ChannelPoints) => {
538537
dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel, points: pts }));
539538
},
540539
[dispatch, entityIdentifier]
541540
);
542541

543542
// Memoize per-channel change handlers to avoid inline lambdas in JSX
544-
const onChangeMaster = useCallback((pts: Array<[number, number]>) => onChangePoints('master', pts), [onChangePoints]);
545-
const onChangeR = useCallback((pts: Array<[number, number]>) => onChangePoints('r', pts), [onChangePoints]);
546-
const onChangeG = useCallback((pts: Array<[number, number]>) => onChangePoints('g', pts), [onChangePoints]);
547-
const onChangeB = useCallback((pts: Array<[number, number]>) => onChangePoints('b', pts), [onChangePoints]);
543+
const onChangeMaster = useCallback((pts: ChannelPoints) => onChangePoints('master', pts), [onChangePoints]);
544+
const onChangeR = useCallback((pts: ChannelPoints) => onChangePoints('r', pts), [onChangePoints]);
545+
const onChangeG = useCallback((pts: ChannelPoints) => onChangePoints('g', pts), [onChangePoints]);
546+
const onChangeB = useCallback((pts: ChannelPoints) => onChangePoints('b', pts), [onChangePoints]);
548547

549548
return (
550549
<Flex

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@ import type {
1919
CanvasEntityType,
2020
CanvasInpaintMaskState,
2121
CanvasMetadata,
22+
ChannelName,
23+
ChannelPoints,
2224
ControlLoRAConfig,
2325
EntityMovedByPayload,
2426
FillStyle,
2527
FLUXReduxImageInfluence,
28+
RasterLayerAdjustments,
2629
RegionalGuidanceRefImageState,
2730
RgbColor,
31+
SimpleConfig,
2832
} from 'features/controlLayers/store/types';
2933
import {
3034
calculateNewSize,
@@ -97,7 +101,6 @@ import {
97101
initialIPAdapter,
98102
initialT2IAdapter,
99103
makeDefaultRasterLayerAdjustments,
100-
type RasterLayerAdjustments,
101104
} from './util';
102105

103106
const slice = createSlice({
@@ -108,14 +111,7 @@ const slice = createSlice({
108111
//#region Raster layers
109112
rasterLayerAdjustmentsSet: (
110113
state,
111-
action: PayloadAction<
112-
EntityIdentifierPayload<
113-
{
114-
adjustments: RasterLayerAdjustments | null;
115-
},
116-
'raster_layer'
117-
>
118-
>
114+
action: PayloadAction<EntityIdentifierPayload<{ adjustments: RasterLayerAdjustments | null }, 'raster_layer'>>
119115
) => {
120116
const { entityIdentifier, adjustments } = action.payload;
121117
const layer = selectEntity(state, entityIdentifier);
@@ -152,14 +148,7 @@ const slice = createSlice({
152148
},
153149
rasterLayerAdjustmentsSimpleUpdated: (
154150
state,
155-
action: PayloadAction<
156-
EntityIdentifierPayload<
157-
{
158-
simple: Partial<RasterLayerAdjustments['simple']>;
159-
},
160-
'raster_layer'
161-
>
162-
>
151+
action: PayloadAction<EntityIdentifierPayload<{ simple: Partial<SimpleConfig> }, 'raster_layer'>>
163152
) => {
164153
const { entityIdentifier, simple } = action.payload;
165154
const layer = selectEntity(state, entityIdentifier);
@@ -173,15 +162,7 @@ const slice = createSlice({
173162
},
174163
rasterLayerAdjustmentsCurvesUpdated: (
175164
state,
176-
action: PayloadAction<
177-
EntityIdentifierPayload<
178-
{
179-
channel: 'master' | 'r' | 'g' | 'b';
180-
points: Array<[number, number]>;
181-
},
182-
'raster_layer'
183-
>
184-
>
165+
action: PayloadAction<EntityIdentifierPayload<{ channel: ChannelName; points: ChannelPoints }, 'raster_layer'>>
185166
) => {
186167
const { entityIdentifier, channel, points } = action.payload;
187168
const layer = selectEntity(state, entityIdentifier);

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -378,36 +378,41 @@ const zControlLoRAConfig = z.object({
378378
});
379379
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
380380

381+
const zSimpleConfig = z.object({
382+
// All simple params normalized to [-1, 1] except sharpness [0, 1]
383+
brightness: z.number().gte(-1).lte(1),
384+
contrast: z.number().gte(-1).lte(1),
385+
saturation: z.number().gte(-1).lte(1),
386+
temperature: z.number().gte(-1).lte(1),
387+
tint: z.number().gte(-1).lte(1),
388+
sharpness: z.number().gte(0).lte(1),
389+
});
390+
export type SimpleConfig = z.infer<typeof zSimpleConfig>;
391+
392+
const zUint8 = z.number().int().min(0).max(255);
393+
const zChannelPoints = z.array(z.tuple([zUint8, zUint8])).min(2);
394+
const zChannelName = z.enum(['master', 'r', 'g', 'b']);
395+
const zCurvesConfig = z.record(zChannelName, zChannelPoints);
396+
export type ChannelName = z.infer<typeof zChannelName>;
397+
export type ChannelPoints = z.infer<typeof zChannelPoints>;
398+
399+
const zRasterLayerAdjustments = z.object({
400+
version: z.literal(1),
401+
enabled: z.boolean(),
402+
collapsed: z.boolean(),
403+
mode: z.enum(['simple', 'curves']),
404+
simple: zSimpleConfig,
405+
curves: zCurvesConfig,
406+
});
407+
export type RasterLayerAdjustments = z.infer<typeof zRasterLayerAdjustments>;
408+
381409
const zCanvasRasterLayerState = zCanvasEntityBase.extend({
382410
type: z.literal('raster_layer'),
383411
position: zCoordinate,
384412
opacity: zOpacity,
385413
objects: z.array(zCanvasObjectState),
386414
// Optional per-layer color adjustments (simple + curves). When undefined, no adjustments are applied.
387-
adjustments: z
388-
.object({
389-
version: z.literal(1),
390-
enabled: z.boolean(),
391-
collapsed: z.boolean(),
392-
mode: z.enum(['simple', 'curves']),
393-
simple: z.object({
394-
// All simple params normalized to [-1, 1] except sharpness [0, 1]
395-
brightness: z.number().gte(-1).lte(1),
396-
contrast: z.number().gte(-1).lte(1),
397-
saturation: z.number().gte(-1).lte(1),
398-
temperature: z.number().gte(-1).lte(1),
399-
tint: z.number().gte(-1).lte(1),
400-
sharpness: z.number().gte(0).lte(1),
401-
}),
402-
curves: z.object({
403-
// Curves are arrays of [x, y] control points in 0..255 space (no strict monotonic checks here)
404-
master: z.array(z.tuple([z.number().int().gte(0).lte(255), z.number().int().gte(0).lte(255)])).min(2),
405-
r: z.array(z.tuple([z.number().int().gte(0).lte(255), z.number().int().gte(0).lte(255)])).min(2),
406-
g: z.array(z.tuple([z.number().int().gte(0).lte(255), z.number().int().gte(0).lte(255)])).min(2),
407-
b: z.array(z.tuple([z.number().int().gte(0).lte(255), z.number().int().gte(0).lte(255)])).min(2),
408-
}),
409-
})
410-
.optional(),
415+
adjustments: zRasterLayerAdjustments.optional(),
411416
});
412417
export type CanvasRasterLayerState = z.infer<typeof zCanvasRasterLayerState>;
413418

invokeai/frontend/web/src/features/controlLayers/store/util.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import type {
1515
Gemini2_5ReferenceImageConfig,
1616
ImageWithDims,
1717
IPAdapterConfig,
18+
RasterLayerAdjustments,
1819
RefImageState,
1920
RgbColor,
2021
T2IAdapterConfig,
@@ -118,8 +119,6 @@ export const initialControlLoRA: ControlLoRAConfig = {
118119
weight: 0.75,
119120
};
120121

121-
export type RasterLayerAdjustments = NonNullable<CanvasRasterLayerState['adjustments']>;
122-
123122
export const makeDefaultRasterLayerAdjustments = (mode: 'simple' | 'curves' = 'simple'): RasterLayerAdjustments => ({
124123
version: 1,
125124
enabled: true,

0 commit comments

Comments
 (0)