refactor(ui): org state in prep for new flow

This commit is contained in:
psychedelicious
2025-05-21 16:26:58 +10:00
parent c9cd0a87be
commit 5f2f12f803
6 changed files with 145 additions and 130 deletions

View File

@@ -0,0 +1,10 @@
import type { z } from 'zod';
/**
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.
* @param schema The zod schema to create a type guard from.
* @returns A type guard function for the schema.
*/
export const buildZodTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
};

View File

@@ -70,7 +70,9 @@ import type {
T2IAdapterConfig,
} from './types';
import {
DEFAULT_ASPECT_RATIO_CONFIG,
getEntityIdentifier,
getInitialState,
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isImagenAspectRatioID,
@@ -93,55 +95,9 @@ import {
initialT2IAdapter,
} from './util';
/**
* Gets a fresh canvas initial state with no references in memory to existing objects.
*/
const getInitialState = (): CanvasState => {
const initialInpaintMaskState = getInpaintMaskState(getPrefixedId('inpaint_mask'));
const initialState: CanvasState = {
_version: 3,
selectedEntityIdentifier: getEntityIdentifier(initialInpaintMaskState),
bookmarkedEntityIdentifier: getEntityIdentifier(initialInpaintMaskState),
rasterLayers: {
isHidden: false,
entities: [],
},
controlLayers: {
isHidden: false,
entities: [],
},
inpaintMasks: {
isHidden: false,
entities: [initialInpaintMaskState],
},
regionalGuidance: {
isHidden: false,
entities: [],
},
referenceImages: { entities: [] },
bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: {
id: '1:1',
value: 1,
isLocked: false,
},
scaleMethod: 'auto',
scaledSize: {
width: 512,
height: 512,
},
modelBase: 'sd-1',
},
};
return initialState;
};
const initialState = getInitialState();
export const canvasSlice = createSlice({
name: 'canvas',
initialState,
initialState: getInitialState(),
reducers: {
// undoable canvas state
//#region Raster layers
@@ -1409,7 +1365,7 @@ export const canvasSlice = createSlice({
state.bbox.rect.width = width;
state.bbox.rect.height = height;
} else {
state.bbox.aspectRatio = deepClone(initialState.bbox.aspectRatio);
state.bbox.aspectRatio = deepClone(DEFAULT_ASPECT_RATIO_CONFIG);
state.bbox.rect.width = optimalDimension;
state.bbox.rect.height = optimalDimension;
}
@@ -1975,7 +1931,7 @@ const migrate = (state: any): any => {
export const canvasPersistConfig: PersistConfig<CanvasState> = {
name: canvasSlice.name,
initialState,
initialState: getInitialState(),
migrate,
persistDenylist: [],
};

View File

@@ -1,3 +1,4 @@
import type { Selector } from '@reduxjs/toolkit';
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
@@ -23,6 +24,9 @@ import { assert } from 'tsafe';
*/
export const selectCanvasSlice = (state: RootState) => state.canvas.present;
export const createCanvasSelector = <T>(selector: Selector<CanvasState, T>) =>
createSelector(selectCanvasSlice, selector);
/**
* Selects the total canvas entity count:
* - Regions
@@ -33,7 +37,7 @@ export const selectCanvasSlice = (state: RootState) => state.canvas.present;
*
* All entities are counted, regardless of their state.
*/
const selectEntityCountAll = createSelector(selectCanvasSlice, (canvas) => {
const selectEntityCountAll = createCanvasSelector((canvas) => {
return (
canvas.regionalGuidance.entities.length +
canvas.referenceImages.entities.length +
@@ -45,24 +49,29 @@ const selectEntityCountAll = createSelector(selectCanvasSlice, (canvas) => {
const isVisibleEntity = (entity: CanvasRenderableEntityState) => entity.isEnabled && entity.objects.length > 0;
export const selectActiveRasterLayerEntities = createSelector(selectCanvasSlice, (canvas) =>
canvas.rasterLayers.entities.filter(isVisibleEntity)
export const selectRasterLayerEntities = createCanvasSelector((canvas) => canvas.rasterLayers.entities);
export const selectActiveRasterLayerEntities = createSelector(selectRasterLayerEntities, (entities) =>
entities.filter(isVisibleEntity)
);
export const selectActiveControlLayerEntities = createSelector(selectCanvasSlice, (canvas) =>
canvas.controlLayers.entities.filter(isVisibleEntity)
export const selectControlLayerEntities = createCanvasSelector((canvas) => canvas.controlLayers.entities);
export const selectActiveControlLayerEntities = createSelector(selectControlLayerEntities, (entities) =>
entities.filter(isVisibleEntity)
);
export const selectActiveInpaintMaskEntities = createSelector(selectCanvasSlice, (canvas) =>
canvas.inpaintMasks.entities.filter(isVisibleEntity)
export const selectInpaintMaskEntities = createCanvasSelector((canvas) => canvas.inpaintMasks.entities);
export const selectActiveInpaintMaskEntities = createSelector(selectInpaintMaskEntities, (entities) =>
entities.filter(isVisibleEntity)
);
export const selectActiveRegionalGuidanceEntities = createSelector(selectCanvasSlice, (canvas) =>
canvas.regionalGuidance.entities.filter(isVisibleEntity)
export const selectRegionalGuidanceEntities = createCanvasSelector((canvas) => canvas.regionalGuidance.entities);
export const selectActiveRegionalGuidanceEntities = createSelector(selectRegionalGuidanceEntities, (entities) =>
entities.filter(isVisibleEntity)
);
export const selectActiveReferenceImageEntities = createSelector(selectCanvasSlice, (canvas) =>
canvas.referenceImages.entities.filter((e) => e.isEnabled)
export const selectReferenceImageEntities = createCanvasSelector((canvas) => canvas.referenceImages.entities);
export const selectActiveReferenceImageEntities = createSelector(selectReferenceImageEntities, (entities) =>
entities.filter((e) => e.isEnabled)
);
/**
@@ -192,14 +201,6 @@ export function selectEntityIdentifierBelowThisOne<T extends CanvasRenderableEnt
return entity as Extract<CanvasEntityState, T> | undefined;
}
export const selectRasterLayerEntities = createSelector(selectCanvasSlice, (canvas) => canvas.rasterLayers.entities);
export const selectControlLayerEntities = createSelector(selectCanvasSlice, (canvas) => canvas.controlLayers.entities);
export const selectInpaintMaskEntities = createSelector(selectCanvasSlice, (canvas) => canvas.inpaintMasks.entities);
export const selectRegionalGuidanceEntities = createSelector(
selectCanvasSlice,
(canvas) => canvas.regionalGuidance.entities
);
/**
* Selected an entity from the canvas slice. If the entity is not found, an error is thrown.
*
@@ -218,7 +219,7 @@ export function selectEntityOrThrow<T extends CanvasEntityIdentifier>(
}
export const selectEntityExists = <T extends CanvasEntityIdentifier>(entityIdentifier: T) => {
return createSelector(selectCanvasSlice, (canvas) => Boolean(selectEntity(canvas, entityIdentifier)));
return createCanvasSelector((canvas) => Boolean(selectEntity(canvas, entityIdentifier)));
};
/**
@@ -299,7 +300,7 @@ export function selectRegionalGuidanceReferenceImage(
return entity.referenceImages.find(({ id }) => id === referenceImageId);
}
export const selectBbox = createSelector(selectCanvasSlice, (canvas) => canvas.bbox);
export const selectBbox = createCanvasSelector((canvas) => canvas.bbox);
export const selectSelectedEntityIdentifier = createSelector(
selectCanvasSlice,
@@ -331,10 +332,10 @@ export const selectSelectedEntityFill = createSelector(
}
);
const selectRasterLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.rasterLayers.isHidden);
const selectControlLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.controlLayers.isHidden);
const selectInpaintMasksIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.inpaintMasks.isHidden);
const selectRegionalGuidanceIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.regionalGuidance.isHidden);
const selectRasterLayersIsHidden = createCanvasSelector((canvas) => canvas.rasterLayers.isHidden);
const selectControlLayersIsHidden = createCanvasSelector((canvas) => canvas.controlLayers.isHidden);
const selectInpaintMasksIsHidden = createCanvasSelector((canvas) => canvas.inpaintMasks.isHidden);
const selectRegionalGuidanceIsHidden = createCanvasSelector((canvas) => canvas.regionalGuidance.isHidden);
/**
* Returns the hidden selector for the given entity type.
@@ -372,7 +373,7 @@ export const buildSelectIsSelected = (entityIdentifier: CanvasEntityIdentifier)
* Other entities are considered empty if they have no objects.
*/
export const buildSelectHasObjects = (entityIdentifier: CanvasEntityIdentifier) => {
return createSelector(selectCanvasSlice, (canvas) => {
return createCanvasSelector((canvas) => {
const entity = selectEntity(canvas, entityIdentifier);
if (!entity) {
@@ -385,10 +386,10 @@ export const buildSelectHasObjects = (entityIdentifier: CanvasEntityIdentifier)
});
};
export const selectWidth = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.width);
export const selectHeight = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.height);
export const selectAspectRatioID = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.id);
export const selectAspectRatioValue = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.value);
export const selectWidth = createCanvasSelector((canvas) => canvas.bbox.rect.width);
export const selectHeight = createCanvasSelector((canvas) => canvas.bbox.rect.height);
export const selectAspectRatioID = createCanvasSelector((canvas) => canvas.bbox.aspectRatio.id);
export const selectAspectRatioValue = createCanvasSelector((canvas) => canvas.bbox.aspectRatio.value);
export const selectScaledSize = createSelector(selectBbox, (bbox) => bbox.scaledSize);
export const selectScaleMethod = createSelector(selectBbox, (bbox) => bbox.scaleMethod);
export const selectBboxRect = createSelector(selectBbox, (bbox) => bbox.rect);
@@ -407,3 +408,21 @@ export const selectCanvasMetadata = createSelector(
return { canvas_v2_metadata };
}
);
export const selectIsSessionStarted = createCanvasSelector(({ isSessionStarted }) => isSessionStarted);
export const selectIsCanvasEmpty = createCanvasSelector(
({ controlLayers, inpaintMasks, rasterLayers, regionalGuidance }) => {
// Check it all manually - could use lodash isEqual, but this selector will be called very often!
// Also note - we do not care about ref images, as they are technically not part of canvas
return (
controlLayers.entities.length === 0 &&
controlLayers.isHidden === false &&
inpaintMasks.entities.length === 0 &&
inpaintMasks.isHidden === false &&
rasterLayers.entities.length === 0 &&
rasterLayers.isHidden === false &&
regionalGuidance.entities.length === 0 &&
regionalGuidance.isHidden === false
);
}
);

View File

@@ -1,3 +1,4 @@
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
@@ -438,51 +439,87 @@ export const isFluxKontextAspectRatioID = (v: unknown): v is z.infer<typeof zFlu
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
const zAspectRatioConfig = z.object({
id: zAspectRatioID,
value: z.number().gt(0),
isLocked: z.boolean(),
});
type AspectRatioConfig = z.infer<typeof zAspectRatioConfig>;
export const DEFAULT_ASPECT_RATIO_CONFIG: AspectRatioConfig = {
id: '1:1',
value: 1,
isLocked: false,
};
const zBboxState = z.object({
rect: z.object({
x: z.number().int(),
y: z.number().int(),
width: zParameterImageDimension,
height: zParameterImageDimension,
}),
aspectRatio: zAspectRatioConfig,
scaledSize: z.object({
width: zParameterImageDimension,
height: zParameterImageDimension,
}),
scaleMethod: zBoundingBoxScaleMethod,
modelBase: zMainModelBase,
});
const zCanvasState = z.object({
_version: z.literal(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
inpaintMasks: z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasInpaintMaskState),
}),
rasterLayers: z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRasterLayerState),
}),
controlLayers: z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasControlLayerState),
}),
regionalGuidance: z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
}),
referenceImages: z.object({
entities: z.array(zCanvasReferenceImageState),
}),
bbox: z.object({
rect: z.object({
x: z.number().int(),
y: z.number().int(),
width: zParameterImageDimension,
height: zParameterImageDimension,
}),
aspectRatio: z.object({
id: zAspectRatioID,
value: z.number().gt(0),
isLocked: z.boolean(),
}),
scaledSize: z.object({
width: zParameterImageDimension,
height: zParameterImageDimension,
}),
scaleMethod: zBoundingBoxScaleMethod,
modelBase: zMainModelBase,
_version: z.literal(3).default(3),
isSessionStarted: z.boolean().default(false),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
inpaintMasks: z
.object({
isHidden: z.boolean(),
entities: z.array(zCanvasInpaintMaskState),
})
.default({ isHidden: false, entities: [] }),
rasterLayers: z
.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRasterLayerState),
})
.default({ isHidden: false, entities: [] }),
controlLayers: z
.object({
isHidden: z.boolean(),
entities: z.array(zCanvasControlLayerState),
})
.default({ isHidden: false, entities: [] }),
regionalGuidance: z
.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
})
.default({ isHidden: false, entities: [] }),
referenceImages: z
.object({
entities: z.array(zCanvasReferenceImageState),
})
.default({ entities: [] }),
bbox: zBboxState.default({
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
scaleMethod: 'auto',
scaledSize: {
width: 512,
height: 512,
},
modelBase: 'sd-1',
}),
});
export type CanvasState = z.infer<typeof zCanvasState>;
/**
* Gets a fresh canvas initial state with no references in memory to existing objects.
*/
const CANVAS_INITIAL_STATE = zCanvasState.parse({});
export const getInitialState = () => deepClone(CANVAS_INITIAL_STATE);
export const zCanvasMetadata = z.object({
inpaintMasks: z.array(zCanvasInpaintMaskState),
rasterLayers: z.array(zCanvasRasterLayerState),

View File

@@ -1,11 +1,12 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour);
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export const isSeedBehaviour = (v: unknown): v is SeedBehaviour => zSeedBehaviour.safeParse(v).success;
export interface DynamicPromptsState {
_version: 1;

View File

@@ -1,5 +1,6 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common';
import { z } from 'zod';
@@ -15,21 +16,12 @@ import { z } from 'zod';
* simply be the zod schema's safeParse function
*/
/**
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.
* @param schema The zod schema to create a type guard from.
* @returns A type guard function for the schema.
*/
export const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
};
/**
* Helper to create a zod schema and a type guard from it.
* @param schema The zod schema to create a type guard from.
* @returns A tuple containing the zod schema and the type guard function.
*/
const buildParameter = <T extends z.ZodTypeAny>(schema: T) => [schema, buildTypeGuard(schema)] as const;
export const buildParameter = <T extends z.ZodTypeAny>(schema: T) => [schema, buildZodTypeGuard(schema)] as const;
// #region Positive prompt
export const [zParameterPositivePrompt, isParameterPositivePrompt] = buildParameter(z.string());