refactor(ui): params state zodification

This commit is contained in:
psychedelicious
2025-05-22 13:54:57 +10:00
parent a0b0c30be9
commit 02e4a3aa82
6 changed files with 65 additions and 53 deletions

View File

@@ -26,6 +26,7 @@ import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasT
import { Transform } from 'features/controlLayers/components/Transform/Transform';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectIsSessionStarted } from 'features/controlLayers/store/selectors';
import { memo, useCallback } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
@@ -36,7 +37,7 @@ const FOCUS_REGION_STYLES: SystemStyleObject = {
height: 'full',
};
const MenuContent = () => {
const MenuContent = memo(() => {
return (
<CanvasManagerProviderGate>
<MenuList>
@@ -45,9 +46,31 @@ const MenuContent = () => {
</MenuList>
</CanvasManagerProviderGate>
);
};
});
MenuContent.displayName = 'MenuContent';
export const CanvasMainPanelContent = memo(() => {
const isSessionStarted = useAppSelector(selectIsSessionStarted);
if (!isSessionStarted) {
return <CanvasNoSession />;
}
return <CanvasActiveSession />;
});
CanvasMainPanelContent.displayName = 'CanvasMainPanelContent';
const CanvasNoSession = memo(() => {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
FRESH CANVAS is fresh when: - No control layers - No inpaint masks - No regions - No Raster Layers
</Flex>
);
});
CanvasNoSession.displayName = 'CanvasNoSession';
const CanvasActiveSession = memo(() => {
const dynamicGrid = useAppSelector(selectDynamicGrid);
const showHUD = useAppSelector(selectShowHUD);
@@ -134,5 +157,4 @@ export const CanvasMainPanelContent = memo(() => {
</FocusRegionWrapper>
);
});
CanvasMainPanelContent.displayName = 'CanvasMainPanelContent';
CanvasActiveSession.displayName = 'ActiveCanvasContent';

View File

@@ -540,48 +540,40 @@ export type ParamsState = z.infer<typeof zParamsState>;
const INITIAL_PARAMS_STATE = zParamsState.parse({});
export const getInitialParamsState = () => deepClone(INITIAL_PARAMS_STATE);
const zInpaintMasks = z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasInpaintMaskState),
});
const zRasterLayers = z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRasterLayerState),
});
const zControlLayers = z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasControlLayerState),
});
const zRegionalGuidance = z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
});
const zReferenceImages = z.object({
entities: z.array(zCanvasReferenceImageState),
});
const zCanvasState = z.object({
_version: z.literal(3).default(3),
isSessionStarted: z.boolean().default(false),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
inpaintMasks: z
.object({
isHidden: z.boolean(),
entities: z.array(zCanvasInpaintMaskState),
})
.default({ isHidden: false, entities: [] }),
rasterLayers: z
.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRasterLayerState),
})
.default({ isHidden: false, entities: [] }),
controlLayers: z
.object({
isHidden: z.boolean(),
entities: z.array(zCanvasControlLayerState),
})
.default({ isHidden: false, entities: [] }),
regionalGuidance: z
.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
})
.default({ isHidden: false, entities: [] }),
referenceImages: z
.object({
entities: z.array(zCanvasReferenceImageState),
})
.default({ entities: [] }),
inpaintMasks: zInpaintMasks.default({ isHidden: false, entities: [] }),
rasterLayers: zRasterLayers.default({ isHidden: false, entities: [] }),
controlLayers: zControlLayers.default({ isHidden: false, entities: [] }),
regionalGuidance: zRegionalGuidance.default({ isHidden: false, entities: [] }),
referenceImages: zReferenceImages.default({ entities: [] }),
bbox: zBboxState.default({
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
scaleMethod: 'auto',
scaledSize: {
width: 512,
height: 512,
},
scaledSize: { width: 512, height: 512 },
modelBase: 'sd-1',
}),
});

View File

@@ -1,6 +1,6 @@
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import type { Dimensions } from 'features/controlLayers/store/types';
import type { MainModelBase } from 'features/nodes/types/common';
import type { BaseModelType } from 'features/nodes/types/common';
import {
getGridSize,
getOptimalDimension,
@@ -11,16 +11,16 @@ import {
* Scales the bounding box dimensions to the optimal dimension. The optimal dimensions should be the trained dimension
* for the model. For example, 1024 for SDXL or 512 for SD1.5.
* @param dimensions The un-scaled bbox dimensions
* @param modelBase The base model
* @param base The base model
*/
export const getScaledBoundingBoxDimensions = (dimensions: Dimensions, modelBase: MainModelBase): Dimensions => {
export const getScaledBoundingBoxDimensions = (dimensions: Dimensions, base?: BaseModelType): Dimensions => {
// Special cases: Return original if SDXL and in training dimensions
if (modelBase === 'sdxl' && isInSDXLTrainingDimensions(dimensions.width, dimensions.height)) {
if (base === 'sdxl' && isInSDXLTrainingDimensions(dimensions.width, dimensions.height)) {
return { ...dimensions };
}
const optimalDimension = getOptimalDimension(modelBase);
const gridSize = getGridSize(modelBase);
const optimalDimension = getOptimalDimension(base);
const gridSize = getGridSize(base);
const width = roundToMultiple(dimensions.width, gridSize);
const height = roundToMultiple(dimensions.height, gridSize);
@@ -56,13 +56,13 @@ export const getScaledBoundingBoxDimensions = (dimensions: Dimensions, modelBase
* Calculate the new width and height that will fit the given aspect ratio, retaining the input area
* @param ratio The aspect ratio to calculate the new size for
* @param area The input area
* @param modelBase The base model
* @param base The base model
* @returns The width and height that will fit the given aspect ratio, retaining the input area
*/
export const calculateNewSize = (ratio: number, area: number, modelBase: MainModelBase): Dimensions => {
export const calculateNewSize = (ratio: number, area: number, base?: BaseModelType): Dimensions => {
const exactWidth = Math.sqrt(area * ratio);
const exactHeight = exactWidth / ratio;
const gridSize = getGridSize(modelBase);
const gridSize = getGridSize(base);
return {
width: roundToMultiple(exactWidth, gridSize),

View File

@@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { type ParamsState, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { CanvasState } from 'features/controlLayers/store/types';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { CanvasState, ParamsState } from 'features/controlLayers/store/types';
import type { BoardField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { buildPresetModifiedPrompt } from 'features/stylePresets/hooks/usePresetModifiedPrompts';

View File

@@ -1,4 +1,3 @@
import type { MainModelBase } from 'features/nodes/types/common';
import type { BaseModelType } from 'services/api/types';
/**
@@ -117,7 +116,7 @@ export const getIsSizeTooLarge = (width: number, height: number, optimalDimensio
* @param optimalDimension The optimal dimension
* @returns Whether the current width and height needs to be resized to the optimal dimension
*/
export const getIsSizeOptimal = (width: number, height: number, modelBase: MainModelBase): boolean => {
const optimalDimension = getOptimalDimension(modelBase);
export const getIsSizeOptimal = (width: number, height: number, base?: BaseModelType): boolean => {
const optimalDimension = getOptimalDimension(base);
return !getIsSizeTooSmall(width, height, optimalDimension) && !getIsSizeTooLarge(width, height, optimalDimension);
};

View File

@@ -8,10 +8,9 @@ import { useAppSelector } from 'app/store/storeHooks';
import type { AppConfig } from 'app/types/invokeai';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasState } from 'features/controlLayers/store/types';
import type { CanvasState, ParamsState } from 'features/controlLayers/store/types';
import {
getControlLayerWarnings,
getGlobalReferenceImageWarnings,