mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 23:35:12 -05:00
refactor(ui): org state in prep for new flow
This commit is contained in:
10
invokeai/frontend/web/src/common/util/zodUtils.ts
Normal file
10
invokeai/frontend/web/src/common/util/zodUtils.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import type { z } from 'zod';
|
||||
|
||||
/**
|
||||
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.
|
||||
* @param schema The zod schema to create a type guard from.
|
||||
* @returns A type guard function for the schema.
|
||||
*/
|
||||
export const buildZodTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
|
||||
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
|
||||
};
|
||||
@@ -70,7 +70,9 @@ import type {
|
||||
T2IAdapterConfig,
|
||||
} from './types';
|
||||
import {
|
||||
DEFAULT_ASPECT_RATIO_CONFIG,
|
||||
getEntityIdentifier,
|
||||
getInitialState,
|
||||
isChatGPT4oAspectRatioID,
|
||||
isFluxKontextAspectRatioID,
|
||||
isImagenAspectRatioID,
|
||||
@@ -93,55 +95,9 @@ import {
|
||||
initialT2IAdapter,
|
||||
} from './util';
|
||||
|
||||
/**
|
||||
* Gets a fresh canvas initial state with no references in memory to existing objects.
|
||||
*/
|
||||
const getInitialState = (): CanvasState => {
|
||||
const initialInpaintMaskState = getInpaintMaskState(getPrefixedId('inpaint_mask'));
|
||||
const initialState: CanvasState = {
|
||||
_version: 3,
|
||||
selectedEntityIdentifier: getEntityIdentifier(initialInpaintMaskState),
|
||||
bookmarkedEntityIdentifier: getEntityIdentifier(initialInpaintMaskState),
|
||||
rasterLayers: {
|
||||
isHidden: false,
|
||||
entities: [],
|
||||
},
|
||||
controlLayers: {
|
||||
isHidden: false,
|
||||
entities: [],
|
||||
},
|
||||
inpaintMasks: {
|
||||
isHidden: false,
|
||||
entities: [initialInpaintMaskState],
|
||||
},
|
||||
regionalGuidance: {
|
||||
isHidden: false,
|
||||
entities: [],
|
||||
},
|
||||
referenceImages: { entities: [] },
|
||||
bbox: {
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
aspectRatio: {
|
||||
id: '1:1',
|
||||
value: 1,
|
||||
isLocked: false,
|
||||
},
|
||||
scaleMethod: 'auto',
|
||||
scaledSize: {
|
||||
width: 512,
|
||||
height: 512,
|
||||
},
|
||||
modelBase: 'sd-1',
|
||||
},
|
||||
};
|
||||
return initialState;
|
||||
};
|
||||
|
||||
const initialState = getInitialState();
|
||||
|
||||
export const canvasSlice = createSlice({
|
||||
name: 'canvas',
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
// undoable canvas state
|
||||
//#region Raster layers
|
||||
@@ -1409,7 +1365,7 @@ export const canvasSlice = createSlice({
|
||||
state.bbox.rect.width = width;
|
||||
state.bbox.rect.height = height;
|
||||
} else {
|
||||
state.bbox.aspectRatio = deepClone(initialState.bbox.aspectRatio);
|
||||
state.bbox.aspectRatio = deepClone(DEFAULT_ASPECT_RATIO_CONFIG);
|
||||
state.bbox.rect.width = optimalDimension;
|
||||
state.bbox.rect.height = optimalDimension;
|
||||
}
|
||||
@@ -1975,7 +1931,7 @@ const migrate = (state: any): any => {
|
||||
|
||||
export const canvasPersistConfig: PersistConfig<CanvasState> = {
|
||||
name: canvasSlice.name,
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
@@ -23,6 +24,9 @@ import { assert } from 'tsafe';
|
||||
*/
|
||||
export const selectCanvasSlice = (state: RootState) => state.canvas.present;
|
||||
|
||||
export const createCanvasSelector = <T>(selector: Selector<CanvasState, T>) =>
|
||||
createSelector(selectCanvasSlice, selector);
|
||||
|
||||
/**
|
||||
* Selects the total canvas entity count:
|
||||
* - Regions
|
||||
@@ -33,7 +37,7 @@ export const selectCanvasSlice = (state: RootState) => state.canvas.present;
|
||||
*
|
||||
* All entities are counted, regardless of their state.
|
||||
*/
|
||||
const selectEntityCountAll = createSelector(selectCanvasSlice, (canvas) => {
|
||||
const selectEntityCountAll = createCanvasSelector((canvas) => {
|
||||
return (
|
||||
canvas.regionalGuidance.entities.length +
|
||||
canvas.referenceImages.entities.length +
|
||||
@@ -45,24 +49,29 @@ const selectEntityCountAll = createSelector(selectCanvasSlice, (canvas) => {
|
||||
|
||||
const isVisibleEntity = (entity: CanvasRenderableEntityState) => entity.isEnabled && entity.objects.length > 0;
|
||||
|
||||
export const selectActiveRasterLayerEntities = createSelector(selectCanvasSlice, (canvas) =>
|
||||
canvas.rasterLayers.entities.filter(isVisibleEntity)
|
||||
export const selectRasterLayerEntities = createCanvasSelector((canvas) => canvas.rasterLayers.entities);
|
||||
export const selectActiveRasterLayerEntities = createSelector(selectRasterLayerEntities, (entities) =>
|
||||
entities.filter(isVisibleEntity)
|
||||
);
|
||||
|
||||
export const selectActiveControlLayerEntities = createSelector(selectCanvasSlice, (canvas) =>
|
||||
canvas.controlLayers.entities.filter(isVisibleEntity)
|
||||
export const selectControlLayerEntities = createCanvasSelector((canvas) => canvas.controlLayers.entities);
|
||||
export const selectActiveControlLayerEntities = createSelector(selectControlLayerEntities, (entities) =>
|
||||
entities.filter(isVisibleEntity)
|
||||
);
|
||||
|
||||
export const selectActiveInpaintMaskEntities = createSelector(selectCanvasSlice, (canvas) =>
|
||||
canvas.inpaintMasks.entities.filter(isVisibleEntity)
|
||||
export const selectInpaintMaskEntities = createCanvasSelector((canvas) => canvas.inpaintMasks.entities);
|
||||
export const selectActiveInpaintMaskEntities = createSelector(selectInpaintMaskEntities, (entities) =>
|
||||
entities.filter(isVisibleEntity)
|
||||
);
|
||||
|
||||
export const selectActiveRegionalGuidanceEntities = createSelector(selectCanvasSlice, (canvas) =>
|
||||
canvas.regionalGuidance.entities.filter(isVisibleEntity)
|
||||
export const selectRegionalGuidanceEntities = createCanvasSelector((canvas) => canvas.regionalGuidance.entities);
|
||||
export const selectActiveRegionalGuidanceEntities = createSelector(selectRegionalGuidanceEntities, (entities) =>
|
||||
entities.filter(isVisibleEntity)
|
||||
);
|
||||
|
||||
export const selectActiveReferenceImageEntities = createSelector(selectCanvasSlice, (canvas) =>
|
||||
canvas.referenceImages.entities.filter((e) => e.isEnabled)
|
||||
export const selectReferenceImageEntities = createCanvasSelector((canvas) => canvas.referenceImages.entities);
|
||||
export const selectActiveReferenceImageEntities = createSelector(selectReferenceImageEntities, (entities) =>
|
||||
entities.filter((e) => e.isEnabled)
|
||||
);
|
||||
|
||||
/**
|
||||
@@ -192,14 +201,6 @@ export function selectEntityIdentifierBelowThisOne<T extends CanvasRenderableEnt
|
||||
return entity as Extract<CanvasEntityState, T> | undefined;
|
||||
}
|
||||
|
||||
export const selectRasterLayerEntities = createSelector(selectCanvasSlice, (canvas) => canvas.rasterLayers.entities);
|
||||
export const selectControlLayerEntities = createSelector(selectCanvasSlice, (canvas) => canvas.controlLayers.entities);
|
||||
export const selectInpaintMaskEntities = createSelector(selectCanvasSlice, (canvas) => canvas.inpaintMasks.entities);
|
||||
export const selectRegionalGuidanceEntities = createSelector(
|
||||
selectCanvasSlice,
|
||||
(canvas) => canvas.regionalGuidance.entities
|
||||
);
|
||||
|
||||
/**
|
||||
* Selected an entity from the canvas slice. If the entity is not found, an error is thrown.
|
||||
*
|
||||
@@ -218,7 +219,7 @@ export function selectEntityOrThrow<T extends CanvasEntityIdentifier>(
|
||||
}
|
||||
|
||||
export const selectEntityExists = <T extends CanvasEntityIdentifier>(entityIdentifier: T) => {
|
||||
return createSelector(selectCanvasSlice, (canvas) => Boolean(selectEntity(canvas, entityIdentifier)));
|
||||
return createCanvasSelector((canvas) => Boolean(selectEntity(canvas, entityIdentifier)));
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -299,7 +300,7 @@ export function selectRegionalGuidanceReferenceImage(
|
||||
return entity.referenceImages.find(({ id }) => id === referenceImageId);
|
||||
}
|
||||
|
||||
export const selectBbox = createSelector(selectCanvasSlice, (canvas) => canvas.bbox);
|
||||
export const selectBbox = createCanvasSelector((canvas) => canvas.bbox);
|
||||
|
||||
export const selectSelectedEntityIdentifier = createSelector(
|
||||
selectCanvasSlice,
|
||||
@@ -331,10 +332,10 @@ export const selectSelectedEntityFill = createSelector(
|
||||
}
|
||||
);
|
||||
|
||||
const selectRasterLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.rasterLayers.isHidden);
|
||||
const selectControlLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.controlLayers.isHidden);
|
||||
const selectInpaintMasksIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.inpaintMasks.isHidden);
|
||||
const selectRegionalGuidanceIsHidden = createSelector(selectCanvasSlice, (canvas) => canvas.regionalGuidance.isHidden);
|
||||
const selectRasterLayersIsHidden = createCanvasSelector((canvas) => canvas.rasterLayers.isHidden);
|
||||
const selectControlLayersIsHidden = createCanvasSelector((canvas) => canvas.controlLayers.isHidden);
|
||||
const selectInpaintMasksIsHidden = createCanvasSelector((canvas) => canvas.inpaintMasks.isHidden);
|
||||
const selectRegionalGuidanceIsHidden = createCanvasSelector((canvas) => canvas.regionalGuidance.isHidden);
|
||||
|
||||
/**
|
||||
* Returns the hidden selector for the given entity type.
|
||||
@@ -372,7 +373,7 @@ export const buildSelectIsSelected = (entityIdentifier: CanvasEntityIdentifier)
|
||||
* Other entities are considered empty if they have no objects.
|
||||
*/
|
||||
export const buildSelectHasObjects = (entityIdentifier: CanvasEntityIdentifier) => {
|
||||
return createSelector(selectCanvasSlice, (canvas) => {
|
||||
return createCanvasSelector((canvas) => {
|
||||
const entity = selectEntity(canvas, entityIdentifier);
|
||||
|
||||
if (!entity) {
|
||||
@@ -385,10 +386,10 @@ export const buildSelectHasObjects = (entityIdentifier: CanvasEntityIdentifier)
|
||||
});
|
||||
};
|
||||
|
||||
export const selectWidth = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.width);
|
||||
export const selectHeight = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.height);
|
||||
export const selectAspectRatioID = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.id);
|
||||
export const selectAspectRatioValue = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.value);
|
||||
export const selectWidth = createCanvasSelector((canvas) => canvas.bbox.rect.width);
|
||||
export const selectHeight = createCanvasSelector((canvas) => canvas.bbox.rect.height);
|
||||
export const selectAspectRatioID = createCanvasSelector((canvas) => canvas.bbox.aspectRatio.id);
|
||||
export const selectAspectRatioValue = createCanvasSelector((canvas) => canvas.bbox.aspectRatio.value);
|
||||
export const selectScaledSize = createSelector(selectBbox, (bbox) => bbox.scaledSize);
|
||||
export const selectScaleMethod = createSelector(selectBbox, (bbox) => bbox.scaleMethod);
|
||||
export const selectBboxRect = createSelector(selectBbox, (bbox) => bbox.rect);
|
||||
@@ -407,3 +408,21 @@ export const selectCanvasMetadata = createSelector(
|
||||
return { canvas_v2_metadata };
|
||||
}
|
||||
);
|
||||
|
||||
export const selectIsSessionStarted = createCanvasSelector(({ isSessionStarted }) => isSessionStarted);
|
||||
export const selectIsCanvasEmpty = createCanvasSelector(
|
||||
({ controlLayers, inpaintMasks, rasterLayers, regionalGuidance }) => {
|
||||
// Check it all manually - could use lodash isEqual, but this selector will be called very often!
|
||||
// Also note - we do not care about ref images, as they are technically not part of canvas
|
||||
return (
|
||||
controlLayers.entities.length === 0 &&
|
||||
controlLayers.isHidden === false &&
|
||||
inpaintMasks.entities.length === 0 &&
|
||||
inpaintMasks.isHidden === false &&
|
||||
rasterLayers.entities.length === 0 &&
|
||||
rasterLayers.isHidden === false &&
|
||||
regionalGuidance.entities.length === 0 &&
|
||||
regionalGuidance.isHidden === false
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
|
||||
@@ -438,51 +439,87 @@ export const isFluxKontextAspectRatioID = (v: unknown): v is z.infer<typeof zFlu
|
||||
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
|
||||
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
|
||||
|
||||
const zAspectRatioConfig = z.object({
|
||||
id: zAspectRatioID,
|
||||
value: z.number().gt(0),
|
||||
isLocked: z.boolean(),
|
||||
});
|
||||
type AspectRatioConfig = z.infer<typeof zAspectRatioConfig>;
|
||||
|
||||
export const DEFAULT_ASPECT_RATIO_CONFIG: AspectRatioConfig = {
|
||||
id: '1:1',
|
||||
value: 1,
|
||||
isLocked: false,
|
||||
};
|
||||
|
||||
const zBboxState = z.object({
|
||||
rect: z.object({
|
||||
x: z.number().int(),
|
||||
y: z.number().int(),
|
||||
width: zParameterImageDimension,
|
||||
height: zParameterImageDimension,
|
||||
}),
|
||||
aspectRatio: zAspectRatioConfig,
|
||||
scaledSize: z.object({
|
||||
width: zParameterImageDimension,
|
||||
height: zParameterImageDimension,
|
||||
}),
|
||||
scaleMethod: zBoundingBoxScaleMethod,
|
||||
modelBase: zMainModelBase,
|
||||
});
|
||||
const zCanvasState = z.object({
|
||||
_version: z.literal(3),
|
||||
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
|
||||
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
|
||||
inpaintMasks: z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasInpaintMaskState),
|
||||
}),
|
||||
rasterLayers: z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasRasterLayerState),
|
||||
}),
|
||||
controlLayers: z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasControlLayerState),
|
||||
}),
|
||||
regionalGuidance: z.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasRegionalGuidanceState),
|
||||
}),
|
||||
referenceImages: z.object({
|
||||
entities: z.array(zCanvasReferenceImageState),
|
||||
}),
|
||||
bbox: z.object({
|
||||
rect: z.object({
|
||||
x: z.number().int(),
|
||||
y: z.number().int(),
|
||||
width: zParameterImageDimension,
|
||||
height: zParameterImageDimension,
|
||||
}),
|
||||
aspectRatio: z.object({
|
||||
id: zAspectRatioID,
|
||||
value: z.number().gt(0),
|
||||
isLocked: z.boolean(),
|
||||
}),
|
||||
scaledSize: z.object({
|
||||
width: zParameterImageDimension,
|
||||
height: zParameterImageDimension,
|
||||
}),
|
||||
scaleMethod: zBoundingBoxScaleMethod,
|
||||
modelBase: zMainModelBase,
|
||||
_version: z.literal(3).default(3),
|
||||
isSessionStarted: z.boolean().default(false),
|
||||
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
|
||||
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
|
||||
inpaintMasks: z
|
||||
.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasInpaintMaskState),
|
||||
})
|
||||
.default({ isHidden: false, entities: [] }),
|
||||
rasterLayers: z
|
||||
.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasRasterLayerState),
|
||||
})
|
||||
.default({ isHidden: false, entities: [] }),
|
||||
controlLayers: z
|
||||
.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasControlLayerState),
|
||||
})
|
||||
.default({ isHidden: false, entities: [] }),
|
||||
regionalGuidance: z
|
||||
.object({
|
||||
isHidden: z.boolean(),
|
||||
entities: z.array(zCanvasRegionalGuidanceState),
|
||||
})
|
||||
.default({ isHidden: false, entities: [] }),
|
||||
referenceImages: z
|
||||
.object({
|
||||
entities: z.array(zCanvasReferenceImageState),
|
||||
})
|
||||
.default({ entities: [] }),
|
||||
bbox: zBboxState.default({
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
|
||||
scaleMethod: 'auto',
|
||||
scaledSize: {
|
||||
width: 512,
|
||||
height: 512,
|
||||
},
|
||||
modelBase: 'sd-1',
|
||||
}),
|
||||
});
|
||||
export type CanvasState = z.infer<typeof zCanvasState>;
|
||||
|
||||
/**
|
||||
* Gets a fresh canvas initial state with no references in memory to existing objects.
|
||||
*/
|
||||
const CANVAS_INITIAL_STATE = zCanvasState.parse({});
|
||||
export const getInitialState = () => deepClone(CANVAS_INITIAL_STATE);
|
||||
|
||||
export const zCanvasMetadata = z.object({
|
||||
inpaintMasks: z.array(zCanvasInpaintMaskState),
|
||||
rasterLayers: z.array(zCanvasRasterLayerState),
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { buildZodTypeGuard } from 'common/util/zodUtils';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
|
||||
export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour);
|
||||
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
|
||||
export const isSeedBehaviour = (v: unknown): v is SeedBehaviour => zSeedBehaviour.safeParse(v).success;
|
||||
|
||||
export interface DynamicPromptsState {
|
||||
_version: 1;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { buildZodTypeGuard } from 'common/util/zodUtils';
|
||||
import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -15,21 +16,12 @@ import { z } from 'zod';
|
||||
* simply be the zod schema's safeParse function
|
||||
*/
|
||||
|
||||
/**
|
||||
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.
|
||||
* @param schema The zod schema to create a type guard from.
|
||||
* @returns A type guard function for the schema.
|
||||
*/
|
||||
export const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
|
||||
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper to create a zod schema and a type guard from it.
|
||||
* @param schema The zod schema to create a type guard from.
|
||||
* @returns A tuple containing the zod schema and the type guard function.
|
||||
*/
|
||||
const buildParameter = <T extends z.ZodTypeAny>(schema: T) => [schema, buildTypeGuard(schema)] as const;
|
||||
export const buildParameter = <T extends z.ZodTypeAny>(schema: T) => [schema, buildZodTypeGuard(schema)] as const;
|
||||
|
||||
// #region Positive prompt
|
||||
export const [zParameterPositivePrompt, isParameterPositivePrompt] = buildParameter(z.string());
|
||||
|
||||
Reference in New Issue
Block a user