mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-16 19:05:03 -05:00
Merge branch 'main' into nodepromptsize
This commit is contained in:
@@ -102,8 +102,7 @@
|
||||
"openInNewTab": "Open in New Tab",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"areYouSure": "Are you sure?",
|
||||
"imagePrompt": "Image Prompt",
|
||||
"clearNodes": "Are you sure you want to clear all nodes?"
|
||||
"imagePrompt": "Image Prompt"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generations",
|
||||
@@ -615,6 +614,11 @@
|
||||
"initialImageNotSetDesc": "Could not load initial image",
|
||||
"nodesSaved": "Nodes Saved",
|
||||
"nodesLoaded": "Nodes Loaded",
|
||||
"nodesNotValidGraph": "Not a valid InvokeAI Node Graph",
|
||||
"nodesNotValidJSON": "Not a valid JSON",
|
||||
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
|
||||
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
|
||||
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
|
||||
"nodesLoadedFailed": "Failed To Load Nodes",
|
||||
"nodesCleared": "Nodes Cleared"
|
||||
},
|
||||
@@ -700,9 +704,10 @@
|
||||
},
|
||||
"nodes": {
|
||||
"reloadSchema": "Reload Schema",
|
||||
"saveNodes": "Save Nodes",
|
||||
"loadNodes": "Load Nodes",
|
||||
"clearNodes": "Clear Nodes",
|
||||
"saveGraph": "Save Graph",
|
||||
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
|
||||
"clearGraph": "Clear Graph",
|
||||
"clearGraphDesc": "Are you sure you want to clear all nodes?",
|
||||
"zoomInNodes": "Zoom In",
|
||||
"zoomOutNodes": "Zoom Out",
|
||||
"fitViewportNodes": "Fit View",
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
export const NUMPY_RAND_MIN = 0;
|
||||
export const NUMPY_RAND_MAX = 2147483647;
|
||||
export const NUMPY_RAND_MAX = 4294967295;
|
||||
|
||||
@@ -65,11 +65,14 @@ import { addGeneratorProgressEventListener as addGeneratorProgressListener } fro
|
||||
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
|
||||
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
|
||||
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
||||
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
|
||||
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
|
||||
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||
import { addTabChangedListener } from './listeners/tabChanged';
|
||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
@@ -153,6 +156,8 @@ addSocketDisconnectedListener();
|
||||
addSocketSubscribedListener();
|
||||
addSocketUnsubscribedListener();
|
||||
addModelLoadEventListener();
|
||||
addSessionRetrievalErrorEventListener();
|
||||
addInvocationRetrievalErrorEventListener();
|
||||
|
||||
// Session Created
|
||||
addSessionCreatedPendingListener();
|
||||
@@ -197,3 +202,6 @@ addFirstListImagesListener();
|
||||
|
||||
// Ad-hoc upscale workflwo
|
||||
addUpscaleRequestedListener();
|
||||
|
||||
// Tab Change
|
||||
addTabChangedListener();
|
||||
|
||||
@@ -9,13 +9,19 @@ import {
|
||||
zMainModel,
|
||||
zVaeModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setShouldUseSDXLRefiner,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addModelsLoadedListener = () => {
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
|
||||
predicate: (state, action) =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
const log = logger('models');
|
||||
@@ -59,6 +65,54 @@ export const addModelsLoadedListener = () => {
|
||||
dispatch(modelChanged(result.data));
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
predicate: (state, action) =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
const log = logger('models');
|
||||
log.info(
|
||||
{ models: action.payload.entities },
|
||||
`SDXL Refiner models loaded (${action.payload.ids.length})`
|
||||
);
|
||||
|
||||
const currentModel = getState().sdxl.refinerModel;
|
||||
|
||||
const isCurrentModelAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) =>
|
||||
m?.model_name === currentModel?.model_name &&
|
||||
m?.base_model === currentModel?.base_model
|
||||
);
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModelId = action.payload.ids[0];
|
||||
const firstModel = action.payload.entities[firstModelId];
|
||||
|
||||
if (!firstModel) {
|
||||
// No models loaded at all
|
||||
dispatch(refinerModelChanged(null));
|
||||
dispatch(setShouldUseSDXLRefiner(false));
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zMainModel.safeParse(firstModel);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{ error: result.error.format() },
|
||||
'Failed to parse SDXL Refiner Model'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(refinerModelChanged(result.data));
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
|
||||
@@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => {
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
if (action.payload) {
|
||||
const { error } = action.payload;
|
||||
const { error, status } = action.payload;
|
||||
const graph = parseify(action.meta.arg);
|
||||
const stringifiedError = JSON.stringify(error);
|
||||
log.error(
|
||||
{ graph, error: serializeError(error) },
|
||||
`Problem creating session: ${stringifiedError}`
|
||||
{ graph, status, error: serializeError(error) },
|
||||
`Problem creating session`
|
||||
);
|
||||
}
|
||||
},
|
||||
|
||||
@@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => {
|
||||
const { session_id } = action.meta.arg;
|
||||
if (action.payload) {
|
||||
const { error } = action.payload;
|
||||
const stringifiedError = JSON.stringify(error);
|
||||
log.error(
|
||||
{
|
||||
session_id,
|
||||
error: serializeError(error),
|
||||
},
|
||||
`Problem invoking session: ${stringifiedError}`
|
||||
`Problem invoking session`
|
||||
);
|
||||
}
|
||||
},
|
||||
|
||||
@@ -3,6 +3,11 @@ import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
import {
|
||||
ALL_BASE_MODELS,
|
||||
NON_REFINER_BASE_MODELS,
|
||||
REFINER_BASE_MODELS,
|
||||
} from 'services/api/constants';
|
||||
|
||||
export const addSocketConnectedEventListener = () => {
|
||||
startAppListening({
|
||||
@@ -24,7 +29,11 @@ export const addSocketConnectedEventListener = () => {
|
||||
dispatch(appSocketConnected(action.payload));
|
||||
|
||||
// update all server state
|
||||
dispatch(modelsApi.endpoints.getMainModels.initiate());
|
||||
dispatch(modelsApi.endpoints.getMainModels.initiate(REFINER_BASE_MODELS));
|
||||
dispatch(
|
||||
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
|
||||
);
|
||||
dispatch(modelsApi.endpoints.getMainModels.initiate(ALL_BASE_MODELS));
|
||||
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
|
||||
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
|
||||
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import {
|
||||
appSocketInvocationRetrievalError,
|
||||
socketInvocationRetrievalError,
|
||||
} from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
export const addInvocationRetrievalErrorEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationRetrievalError,
|
||||
effect: (action, { dispatch }) => {
|
||||
const log = logger('socketio');
|
||||
log.error(
|
||||
action.payload,
|
||||
`Invocation retrieval error (${action.payload.data.graph_execution_state_id})`
|
||||
);
|
||||
dispatch(appSocketInvocationRetrievalError(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -21,7 +21,10 @@ export const addInvocationStartedEventListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
log.debug(action.payload, 'Invocation started');
|
||||
log.debug(
|
||||
action.payload,
|
||||
`Invocation started (${action.payload.data.node.type})`
|
||||
);
|
||||
dispatch(appSocketInvocationStarted(action.payload));
|
||||
},
|
||||
});
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import {
|
||||
appSocketSessionRetrievalError,
|
||||
socketSessionRetrievalError,
|
||||
} from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
export const addSessionRetrievalErrorEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketSessionRetrievalError,
|
||||
effect: (action, { dispatch }) => {
|
||||
const log = logger('socketio');
|
||||
log.error(
|
||||
action.payload,
|
||||
`Session retrieval error (${action.payload.data.graph_execution_state_id})`
|
||||
);
|
||||
dispatch(appSocketSessionRetrievalError(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,56 @@
|
||||
import { modelChanged } from 'features/parameters/store/generationSlice';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
modelsApi,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addTabChangedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: setActiveTab,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const activeTabName = action.payload;
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
// grab the models from RTK Query cache
|
||||
const { data } = modelsApi.endpoints.getMainModels.select(
|
||||
NON_REFINER_BASE_MODELS
|
||||
)(getState());
|
||||
|
||||
if (!data) {
|
||||
// no models yet, so we can't do anything
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
// need to filter out all the invalid canvas models (currently, this is just sdxl)
|
||||
const validCanvasModels: MainModelConfigEntity[] = [];
|
||||
|
||||
forEach(data.entities, (entity) => {
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
|
||||
validCanvasModels.push(entity);
|
||||
}
|
||||
});
|
||||
|
||||
// this could still be undefined even tho TS doesn't say so
|
||||
const firstValidCanvasModel = validCanvasModels[0];
|
||||
|
||||
if (!firstValidCanvasModel) {
|
||||
// uh oh, we have no models that are valid for canvas
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
// only store the model name and base model in redux
|
||||
const { base_model, model_name } = firstValidCanvasModel;
|
||||
|
||||
dispatch(modelChanged({ base_model, model_name }));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -39,8 +39,22 @@ export const addUserInvokedCanvasListener = () => {
|
||||
|
||||
const state = getState();
|
||||
|
||||
const {
|
||||
layerState,
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
isMaskEnabled,
|
||||
shouldPreserveMaskedArea,
|
||||
} = state.canvas;
|
||||
|
||||
// Build canvas blobs
|
||||
const canvasBlobsAndImageData = await getCanvasData(state);
|
||||
const canvasBlobsAndImageData = await getCanvasData(
|
||||
layerState,
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
isMaskEnabled,
|
||||
shouldPreserveMaskedArea
|
||||
);
|
||||
|
||||
if (!canvasBlobsAndImageData) {
|
||||
log.error('Unable to create canvas data');
|
||||
|
||||
@@ -3,6 +3,7 @@ import { userInvoked } from 'app/store/actions';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import { startAppListening } from '..';
|
||||
@@ -14,8 +15,16 @@ export const addUserInvokedImageToImageListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const log = logger('session');
|
||||
const state = getState();
|
||||
const model = state.generation.model;
|
||||
|
||||
let graph;
|
||||
|
||||
if (model && model.base_model === 'sdxl') {
|
||||
graph = buildLinearSDXLImageToImageGraph(state);
|
||||
} else {
|
||||
graph = buildLinearImageToImageGraph(state);
|
||||
}
|
||||
|
||||
const graph = buildLinearImageToImageGraph(state);
|
||||
dispatch(imageToImageGraphBuilt(graph));
|
||||
log.debug({ graph: parseify(graph) }, 'Image to Image graph built');
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
@@ -14,8 +15,15 @@ export const addUserInvokedTextToImageListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const log = logger('session');
|
||||
const state = getState();
|
||||
const model = state.generation.model;
|
||||
|
||||
const graph = buildLinearTextToImageGraph(state);
|
||||
let graph;
|
||||
|
||||
if (model && model.base_model === 'sdxl') {
|
||||
graph = buildLinearSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildLinearTextToImageGraph(state);
|
||||
}
|
||||
|
||||
dispatch(textToImageGraphBuilt(graph));
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import loraReducer from 'features/lora/store/loraSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import sdxlReducer from 'features/sdxl/store/sdxlSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
||||
@@ -47,6 +48,7 @@ const allReducers = {
|
||||
imageDeletion: imageDeletionReducer,
|
||||
lora: loraReducer,
|
||||
modelmanager: modelmanagerReducer,
|
||||
sdxl: sdxlReducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
@@ -58,6 +60,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'canvas',
|
||||
'gallery',
|
||||
'generation',
|
||||
'sdxl',
|
||||
'nodes',
|
||||
'postprocessing',
|
||||
'system',
|
||||
|
||||
@@ -95,7 +95,8 @@ export type AppFeature =
|
||||
| 'localization'
|
||||
| 'consoleLogging'
|
||||
| 'dynamicPrompting'
|
||||
| 'batches';
|
||||
| 'batches'
|
||||
| 'syncModels';
|
||||
|
||||
/**
|
||||
* A disable-able Stable Diffusion feature
|
||||
|
||||
@@ -6,6 +6,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { modelsApi } from '../../services/api/endpoints/models';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
@@ -24,7 +25,7 @@ const readinessSelector = createSelector(
|
||||
}
|
||||
|
||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||
modelsApi.endpoints.getMainModels.select()(state);
|
||||
modelsApi.endpoints.getMainModels.select(ALL_BASE_MODELS)(state);
|
||||
if (!mainModelsSuccessfullyLoaded) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('Models are not loaded');
|
||||
|
||||
@@ -2,8 +2,8 @@ import { Box, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import GenerationModeStatusText from 'features/parameters/components/Parameters/Canvas/GenerationModeStatusText';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import roundToHundreth from '../util/roundToHundreth';
|
||||
import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos';
|
||||
@@ -110,6 +110,7 @@ const IAICanvasStatusText = () => {
|
||||
},
|
||||
}}
|
||||
>
|
||||
<GenerationModeStatusText />
|
||||
<Box
|
||||
style={{
|
||||
color: activeLayerColor,
|
||||
|
||||
@@ -2,7 +2,15 @@ import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { useSingleAndDoubleClick } from 'common/hooks/useSingleAndDoubleClick';
|
||||
import {
|
||||
canvasCopiedToClipboard,
|
||||
canvasDownloadedAsImage,
|
||||
canvasMerged,
|
||||
canvasSavedToGallery,
|
||||
} from 'features/canvas/store/actions';
|
||||
import {
|
||||
canvasSelector,
|
||||
isStagingSelector,
|
||||
@@ -21,16 +29,8 @@ import {
|
||||
} from 'features/canvas/store/canvasTypes';
|
||||
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import {
|
||||
canvasCopiedToClipboard,
|
||||
canvasDownloadedAsImage,
|
||||
canvasMerged,
|
||||
canvasSavedToGallery,
|
||||
} from 'features/canvas/store/actions';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -48,7 +48,6 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
|
||||
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
|
||||
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
|
||||
import IAICanvasUndoButton from './IAICanvasUndoButton';
|
||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||
|
||||
export const selector = createSelector(
|
||||
[systemSelector, canvasSelector, isStagingSelector],
|
||||
@@ -220,7 +219,7 @@ const IAICanvasToolbar = () => {
|
||||
}}
|
||||
>
|
||||
<Box w={24}>
|
||||
<IAIMantineSearchableSelect
|
||||
<IAIMantineSelect
|
||||
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
|
||||
value={layer}
|
||||
data={LAYER_NAMES_DICT}
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { GenerationMode } from 'features/canvas/store/canvasTypes';
|
||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useDebounce } from 'react-use';
|
||||
|
||||
export const useCanvasGenerationMode = () => {
|
||||
const layerState = useAppSelector((state) => state.canvas.layerState);
|
||||
|
||||
const boundingBoxCoordinates = useAppSelector(
|
||||
(state) => state.canvas.boundingBoxCoordinates
|
||||
);
|
||||
const boundingBoxDimensions = useAppSelector(
|
||||
(state) => state.canvas.boundingBoxDimensions
|
||||
);
|
||||
const isMaskEnabled = useAppSelector((state) => state.canvas.isMaskEnabled);
|
||||
|
||||
const shouldPreserveMaskedArea = useAppSelector(
|
||||
(state) => state.canvas.shouldPreserveMaskedArea
|
||||
);
|
||||
const [generationMode, setGenerationMode] = useState<
|
||||
GenerationMode | undefined
|
||||
>();
|
||||
|
||||
useEffect(() => {
|
||||
setGenerationMode(undefined);
|
||||
}, [
|
||||
layerState,
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
isMaskEnabled,
|
||||
shouldPreserveMaskedArea,
|
||||
]);
|
||||
|
||||
useDebounce(
|
||||
async () => {
|
||||
// Build canvas blobs
|
||||
const canvasBlobsAndImageData = await getCanvasData(
|
||||
layerState,
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
isMaskEnabled,
|
||||
shouldPreserveMaskedArea
|
||||
);
|
||||
|
||||
if (!canvasBlobsAndImageData) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { baseImageData, maskImageData } = canvasBlobsAndImageData;
|
||||
|
||||
// Determine the generation mode
|
||||
const generationMode = getCanvasGenerationMode(
|
||||
baseImageData,
|
||||
maskImageData
|
||||
);
|
||||
|
||||
setGenerationMode(generationMode);
|
||||
},
|
||||
1000,
|
||||
[
|
||||
layerState,
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
isMaskEnabled,
|
||||
shouldPreserveMaskedArea,
|
||||
]
|
||||
);
|
||||
|
||||
return generationMode;
|
||||
};
|
||||
@@ -168,4 +168,7 @@ export interface CanvasState {
|
||||
stageDimensions: Dimensions;
|
||||
stageScale: number;
|
||||
tool: CanvasTool;
|
||||
generationMode?: GenerationMode;
|
||||
}
|
||||
|
||||
export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { isCanvasMaskLine } from '../store/canvasTypes';
|
||||
import { Vector2d } from 'konva/lib/types';
|
||||
import {
|
||||
CanvasLayerState,
|
||||
Dimensions,
|
||||
isCanvasMaskLine,
|
||||
} from '../store/canvasTypes';
|
||||
import createMaskStage from './createMaskStage';
|
||||
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
|
||||
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
||||
@@ -9,7 +13,13 @@ import { konvaNodeToImageData } from './konvaNodeToImageData';
|
||||
/**
|
||||
* Gets Blob and ImageData objects for the base and mask layers
|
||||
*/
|
||||
export const getCanvasData = async (state: RootState) => {
|
||||
export const getCanvasData = async (
|
||||
layerState: CanvasLayerState,
|
||||
boundingBoxCoordinates: Vector2d,
|
||||
boundingBoxDimensions: Dimensions,
|
||||
isMaskEnabled: boolean,
|
||||
shouldPreserveMaskedArea: boolean
|
||||
) => {
|
||||
const log = logger('canvas');
|
||||
|
||||
const canvasBaseLayer = getCanvasBaseLayer();
|
||||
@@ -20,14 +30,6 @@ export const getCanvasData = async (state: RootState) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
layerState: { objects },
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
isMaskEnabled,
|
||||
shouldPreserveMaskedArea,
|
||||
} = state.canvas;
|
||||
|
||||
const boundingBox = {
|
||||
...boundingBoxCoordinates,
|
||||
...boundingBoxDimensions,
|
||||
@@ -58,7 +60,7 @@ export const getCanvasData = async (state: RootState) => {
|
||||
|
||||
// For the mask layer, use the normal boundingBox
|
||||
const maskStage = await createMaskStage(
|
||||
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
|
||||
isMaskEnabled ? layerState.objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
|
||||
boundingBox,
|
||||
shouldPreserveMaskedArea
|
||||
);
|
||||
|
||||
@@ -2,11 +2,12 @@ import {
|
||||
areAnyPixelsBlack,
|
||||
getImageDataTransparency,
|
||||
} from 'common/util/arrayBuffer';
|
||||
import { GenerationMode } from '../store/canvasTypes';
|
||||
|
||||
export const getCanvasGenerationMode = (
|
||||
baseImageData: ImageData,
|
||||
maskImageData: ImageData
|
||||
) => {
|
||||
): GenerationMode => {
|
||||
const {
|
||||
isPartiallyTransparent: baseIsPartiallyTransparent,
|
||||
isFullyTransparent: baseIsFullyTransparent,
|
||||
|
||||
@@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
@@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'refiner_model' && template.type === 'refiner_model') {
|
||||
return (
|
||||
<RefinerModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||
return (
|
||||
<VaeModelInputFieldComponent
|
||||
|
||||
@@ -14,8 +14,10 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
props: FieldComponentProps<MainModelInputFieldValue, ModelInputFieldTemplate>
|
||||
@@ -24,8 +26,11 @@ const ModelInputFieldComponent = (
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||
NON_REFINER_BASE_MODELS
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
@@ -103,9 +108,11 @@ const ModelInputFieldComponent = (
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
RefinerModelInputFieldTemplate,
|
||||
RefinerModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
|
||||
const RefinerModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
RefinerModelInputFieldValue,
|
||||
RefinerModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
const { data: refinerModels, isLoading } =
|
||||
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!refinerModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(refinerModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [refinerModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
refinerModels?.entities[
|
||||
`${field.value?.base_model}/main/${field.value?.model_name}`
|
||||
] ?? null,
|
||||
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newModel = modelIdToMainModelParam(v);
|
||||
|
||||
if (!newModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: newModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('modelManager.model')}
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<Flex w="100%" alignItems="center" gap={2}>
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RefinerModelInputFieldComponent);
|
||||
@@ -2,11 +2,11 @@ import { HStack } from '@chakra-ui/react';
|
||||
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import LoadNodesButton from '../ui/LoadNodesButton';
|
||||
import ClearGraphButton from '../ui/ClearGraphButton';
|
||||
import LoadGraphButton from '../ui/LoadGraphButton';
|
||||
import NodeInvokeButton from '../ui/NodeInvokeButton';
|
||||
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
|
||||
import SaveNodesButton from '../ui/SaveNodesButton';
|
||||
import ClearNodesButton from '../ui/ClearNodesButton';
|
||||
import SaveGraphButton from '../ui/SaveGraphButton';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
return (
|
||||
@@ -15,9 +15,9 @@ const TopCenterPanel = () => {
|
||||
<NodeInvokeButton />
|
||||
<CancelButton />
|
||||
<ReloadSchemaButton />
|
||||
<SaveNodesButton />
|
||||
<LoadNodesButton />
|
||||
<ClearNodesButton />
|
||||
<SaveGraphButton />
|
||||
<LoadGraphButton />
|
||||
<ClearGraphButton />
|
||||
</HStack>
|
||||
</Panel>
|
||||
);
|
||||
|
||||
@@ -9,17 +9,17 @@ import {
|
||||
Text,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { memo, useRef, useCallback } from 'react';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
|
||||
const ClearNodesButton = () => {
|
||||
const ClearGraphButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
@@ -46,8 +46,8 @@ const ClearNodesButton = () => {
|
||||
<>
|
||||
<IAIIconButton
|
||||
icon={<FaTrash />}
|
||||
tooltip={t('nodes.clearNodes')}
|
||||
aria-label={t('nodes.clearNodes')}
|
||||
tooltip={t('nodes.clearGraph')}
|
||||
aria-label={t('nodes.clearGraph')}
|
||||
onClick={onOpen}
|
||||
isDisabled={nodes.length === 0}
|
||||
/>
|
||||
@@ -62,11 +62,11 @@ const ClearNodesButton = () => {
|
||||
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
{t('nodes.clearNodes')}
|
||||
{t('nodes.clearGraph')}
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
<Text>{t('common.clearNodes')}</Text>
|
||||
<Text>{t('nodes.clearGraphDesc')}</Text>
|
||||
</AlertDialogBody>
|
||||
|
||||
<AlertDialogFooter>
|
||||
@@ -83,4 +83,4 @@ const ClearNodesButton = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ClearNodesButton);
|
||||
export default memo(ClearGraphButton);
|
||||
@@ -0,0 +1,161 @@
|
||||
import { FileButton } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import i18n from 'i18n';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
interface JsonFile {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
function sanityCheckInvokeAIGraph(jsonFile: JsonFile): {
|
||||
isValid: boolean;
|
||||
message: string;
|
||||
} {
|
||||
// Check if primary keys exist
|
||||
const keys = ['nodes', 'edges', 'viewport'];
|
||||
for (const key of keys) {
|
||||
if (!(key in jsonFile)) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesNotValidGraph'),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Check if nodes and edges are arrays
|
||||
if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesNotValidGraph'),
|
||||
};
|
||||
}
|
||||
|
||||
// Check if data is present in nodes
|
||||
const nodeKeys = ['data', 'type'];
|
||||
const nodeTypes = ['invocation', 'progress_image'];
|
||||
if (jsonFile.nodes.length > 0) {
|
||||
for (const node of jsonFile.nodes) {
|
||||
for (const nodeKey of nodeKeys) {
|
||||
if (!(nodeKey in node)) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesNotValidGraph'),
|
||||
};
|
||||
}
|
||||
if (nodeKey === 'type' && !nodeTypes.includes(node[nodeKey])) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesUnrecognizedTypes'),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check Edge Object
|
||||
const edgeKeys = ['source', 'sourceHandle', 'target', 'targetHandle'];
|
||||
if (jsonFile.edges.length > 0) {
|
||||
for (const edge of jsonFile.edges) {
|
||||
for (const edgeKey of edgeKeys) {
|
||||
if (!(edgeKey in edge)) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesBrokenConnections'),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: true,
|
||||
message: i18n.t('toast.nodesLoaded'),
|
||||
};
|
||||
}
|
||||
|
||||
const LoadGraphButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { fitView } = useReactFlow();
|
||||
|
||||
const uploadedFileRef = useRef<() => void>(null);
|
||||
|
||||
const restoreJSONToEditor = useCallback(
|
||||
(v: File | null) => {
|
||||
if (!v) return;
|
||||
const reader = new FileReader();
|
||||
reader.onload = async () => {
|
||||
const json = reader.result;
|
||||
|
||||
try {
|
||||
const retrievedNodeTree = await JSON.parse(String(json));
|
||||
const { isValid, message } =
|
||||
sanityCheckInvokeAIGraph(retrievedNodeTree);
|
||||
|
||||
if (isValid) {
|
||||
dispatch(loadFileNodes(retrievedNodeTree.nodes));
|
||||
dispatch(loadFileEdges(retrievedNodeTree.edges));
|
||||
fitView();
|
||||
|
||||
dispatch(
|
||||
addToast(makeToast({ title: message, status: 'success' }))
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: message,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
// Cleanup
|
||||
reader.abort();
|
||||
} catch (error) {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.nodesNotValidJSON'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
reader.readAsText(v);
|
||||
|
||||
// Cleanup
|
||||
uploadedFileRef.current?.();
|
||||
},
|
||||
[fitView, dispatch, t]
|
||||
);
|
||||
return (
|
||||
<FileButton
|
||||
resetRef={uploadedFileRef}
|
||||
accept="application/json"
|
||||
onChange={restoreJSONToEditor}
|
||||
>
|
||||
{(props) => (
|
||||
<IAIIconButton
|
||||
icon={<FaUpload />}
|
||||
tooltip={t('nodes.loadGraph')}
|
||||
aria-label={t('nodes.loadGraph')}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
</FileButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoadGraphButton);
|
||||
@@ -1,79 +0,0 @@
|
||||
import { FileButton } from '@mantine/core';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
const LoadNodesButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { fitView } = useReactFlow();
|
||||
|
||||
const uploadedFileRef = useRef<() => void>(null);
|
||||
|
||||
const restoreJSONToEditor = useCallback(
|
||||
(v: File | null) => {
|
||||
if (!v) return;
|
||||
const reader = new FileReader();
|
||||
reader.onload = async () => {
|
||||
const json = reader.result;
|
||||
const retrievedNodeTree = await JSON.parse(String(json));
|
||||
|
||||
if (!retrievedNodeTree) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.nodesLoadedFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (retrievedNodeTree) {
|
||||
dispatch(loadFileNodes(retrievedNodeTree.nodes));
|
||||
dispatch(loadFileEdges(retrievedNodeTree.edges));
|
||||
fitView();
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({ title: t('toast.nodesLoaded'), status: 'success' })
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
reader.abort();
|
||||
};
|
||||
|
||||
reader.readAsText(v);
|
||||
|
||||
// Cleanup
|
||||
uploadedFileRef.current?.();
|
||||
},
|
||||
[fitView, dispatch, t]
|
||||
);
|
||||
return (
|
||||
<FileButton
|
||||
resetRef={uploadedFileRef}
|
||||
accept="application/json"
|
||||
onChange={restoreJSONToEditor}
|
||||
>
|
||||
{(props) => (
|
||||
<IAIIconButton
|
||||
icon={<FaUpload />}
|
||||
tooltip={t('nodes.loadNodes')}
|
||||
aria-label={t('nodes.loadNodes')}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
</FileButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoadNodesButton);
|
||||
@@ -6,7 +6,7 @@ import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSave } from 'react-icons/fa';
|
||||
|
||||
const SaveNodesButton = () => {
|
||||
const SaveGraphButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const editorInstance = useAppSelector(
|
||||
(state: RootState) => state.nodes.editorInstance
|
||||
@@ -37,12 +37,12 @@ const SaveNodesButton = () => {
|
||||
<IAIIconButton
|
||||
icon={<FaSave />}
|
||||
fontSize={18}
|
||||
tooltip={t('nodes.saveNodes')}
|
||||
aria-label={t('nodes.saveNodes')}
|
||||
tooltip={t('nodes.saveGraph')}
|
||||
aria-label={t('nodes.saveGraph')}
|
||||
onClick={saveEditorToJSON}
|
||||
isDisabled={nodes.length === 0}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SaveNodesButton);
|
||||
export default memo(SaveGraphButton);
|
||||
@@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
ClipField: 'clip',
|
||||
VaeField: 'vae',
|
||||
model: 'model',
|
||||
refiner_model: 'refiner_model',
|
||||
vae_model: 'vae_model',
|
||||
lora_model: 'lora_model',
|
||||
controlnet_model: 'controlnet_model',
|
||||
@@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
refiner_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'Refiner Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
vae_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
|
||||
@@ -70,6 +70,7 @@ export type FieldType =
|
||||
| 'vae'
|
||||
| 'control'
|
||||
| 'model'
|
||||
| 'refiner_model'
|
||||
| 'vae_model'
|
||||
| 'lora_model'
|
||||
| 'controlnet_model'
|
||||
@@ -100,6 +101,7 @@ export type InputFieldValue =
|
||||
| ControlInputFieldValue
|
||||
| EnumInputFieldValue
|
||||
| MainModelInputFieldValue
|
||||
| RefinerModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| LoRAModelInputFieldValue
|
||||
| ControlNetModelInputFieldValue
|
||||
@@ -128,6 +130,7 @@ export type InputFieldTemplate =
|
||||
| ControlInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| ModelInputFieldTemplate
|
||||
| RefinerModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
@@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
|
||||
value?: MainModelParam;
|
||||
};
|
||||
|
||||
export type RefinerModelInputFieldValue = FieldValueBase & {
|
||||
type: 'refiner_model';
|
||||
value?: MainModelParam;
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||
type: 'vae_model';
|
||||
value?: VaeModelParam;
|
||||
@@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'model';
|
||||
};
|
||||
|
||||
export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'refiner_model';
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'vae_model';
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
LoRAModelInputFieldTemplate,
|
||||
ModelInputFieldTemplate,
|
||||
OutputFieldTemplate,
|
||||
RefinerModelInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
TypeHints,
|
||||
UNetInputFieldTemplate,
|
||||
@@ -178,6 +179,21 @@ const buildModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildRefinerModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
|
||||
const template: RefinerModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'refiner_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVaeModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -492,6 +508,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['model'].includes(fieldType)) {
|
||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['refiner_model'].includes(fieldType)) {
|
||||
return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['vae_model'].includes(fieldType)) {
|
||||
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
||||
@@ -76,6 +76,10 @@ export const buildInputFieldValue = (
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'refiner_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'vae_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import {
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
SDXL_LATENTS_TO_LATENTS,
|
||||
SDXL_MODEL_LOADER,
|
||||
SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
SDXL_REFINER_MODEL_LOADER,
|
||||
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
} from './constants';
|
||||
|
||||
export const addSDXLRefinerToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): void => {
|
||||
const { positivePrompt, negativePrompt } = state.generation;
|
||||
const {
|
||||
refinerModel,
|
||||
refinerAestheticScore,
|
||||
positiveStylePrompt,
|
||||
negativeStylePrompt,
|
||||
refinerSteps,
|
||||
refinerScheduler,
|
||||
refinerCFGScale,
|
||||
refinerStart,
|
||||
} = state.sdxl;
|
||||
|
||||
if (!refinerModel) return;
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.refiner_model = refinerModel;
|
||||
metadataAccumulator.refiner_aesthetic_store = refinerAestheticScore;
|
||||
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
||||
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
||||
metadataAccumulator.refiner_start = refinerStart;
|
||||
metadataAccumulator.refiner_steps = refinerSteps;
|
||||
}
|
||||
|
||||
// Unplug SDXL Latents Generation To Latents To Image
|
||||
graph.edges = graph.edges.filter(
|
||||
(e) =>
|
||||
!(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field))
|
||||
);
|
||||
|
||||
graph.edges = graph.edges.filter(
|
||||
(e) =>
|
||||
!(
|
||||
e.source.node_id === SDXL_MODEL_LOADER &&
|
||||
['vae'].includes(e.source.field)
|
||||
)
|
||||
);
|
||||
|
||||
// connect the VAE back to the i2l, which we just removed in the filter
|
||||
// but only if we are doing l2l
|
||||
if (baseNodeId === SDXL_LATENTS_TO_LATENTS) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
|
||||
type: 'sdxl_refiner_model_loader',
|
||||
id: SDXL_REFINER_MODEL_LOADER,
|
||||
model: refinerModel,
|
||||
};
|
||||
graph.nodes[SDXL_REFINER_POSITIVE_CONDITIONING] = {
|
||||
type: 'sdxl_refiner_compel_prompt',
|
||||
id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
style: `${positivePrompt} ${positiveStylePrompt}`,
|
||||
aesthetic_score: refinerAestheticScore,
|
||||
};
|
||||
graph.nodes[SDXL_REFINER_NEGATIVE_CONDITIONING] = {
|
||||
type: 'sdxl_refiner_compel_prompt',
|
||||
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
style: `${negativePrompt} ${negativeStylePrompt}`,
|
||||
aesthetic_score: refinerAestheticScore,
|
||||
};
|
||||
graph.nodes[SDXL_REFINER_LATENTS_TO_LATENTS] = {
|
||||
type: 'l2l_sdxl',
|
||||
id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
cfg_scale: refinerCFGScale,
|
||||
steps: refinerSteps / (1 - Math.min(refinerStart, 0.99)),
|
||||
scheduler: refinerScheduler,
|
||||
denoising_start: refinerStart,
|
||||
denoising_end: 1,
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: baseNodeId,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
}
|
||||
);
|
||||
};
|
||||
@@ -46,6 +46,7 @@ export const buildLinearImageToImageGraph = (
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
shouldUseNoiseSettings,
|
||||
vaePrecision,
|
||||
} = state.generation;
|
||||
|
||||
// TODO: add batch functionality
|
||||
@@ -113,6 +114,7 @@ export const buildLinearImageToImageGraph = (
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[LATENTS_TO_LATENTS]: {
|
||||
type: 'l2l',
|
||||
@@ -129,6 +131,7 @@ export const buildLinearImageToImageGraph = (
|
||||
// image: {
|
||||
// image_name: initialImage.image_name,
|
||||
// },
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
||||
@@ -0,0 +1,369 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
} from 'services/api/types';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import {
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RESIZE,
|
||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_LATENTS_TO_LATENTS,
|
||||
SDXL_MODEL_LOADER,
|
||||
} from './constants';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
*/
|
||||
export const buildLinearSDXLImageToImageGraph = (
|
||||
state: RootState
|
||||
): NonNullableGraph => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
initialImage,
|
||||
shouldFitToWidthHeight,
|
||||
width,
|
||||
height,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
shouldUseNoiseSettings,
|
||||
vaePrecision,
|
||||
} = state.generation;
|
||||
|
||||
const {
|
||||
positiveStylePrompt,
|
||||
negativeStylePrompt,
|
||||
shouldUseSDXLRefiner,
|
||||
refinerStart,
|
||||
sdxlImg2ImgDenoisingStrength: strength,
|
||||
} = state.sdxl;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
if (!initialImage) {
|
||||
log.error('No initial image found in state');
|
||||
throw new Error('No initial image found in state');
|
||||
}
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
: initialGenerationState.shouldUseCpuNoise;
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[SDXL_MODEL_LOADER]: {
|
||||
type: 'sdxl_model_loader',
|
||||
id: SDXL_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
style: positiveStylePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
style: negativeStylePrompt,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
use_cpu,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[SDXL_LATENTS_TO_LATENTS]: {
|
||||
type: 'l2l_sdxl',
|
||||
id: SDXL_LATENTS_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_start: shouldUseSDXLRefiner
|
||||
? Math.min(refinerStart, 1 - strength)
|
||||
: 1 - strength,
|
||||
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
|
||||
},
|
||||
[IMAGE_TO_LATENTS]: {
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||
// image: {
|
||||
// image_name: initialImage.image_name,
|
||||
// },
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// handle `fit`
|
||||
if (
|
||||
shouldFitToWidthHeight &&
|
||||
(initialImage.width !== width || initialImage.height !== height)
|
||||
) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
const resizeNode: ImageResizeInvocation = {
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: initialImage.imageName,
|
||||
},
|
||||
is_intermediate: true,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[RESIZE] = resizeNode;
|
||||
|
||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'image' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
|
||||
image_name: initialImage.imageName,
|
||||
};
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
generation_mode: 'sdxl_img2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
seed: 0, // set in addDynamicPromptsToGraph
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined,
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
clip_skip: clipSkip,
|
||||
strength: strength,
|
||||
init_image: initialImage.imageName,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (shouldUseSDXLRefiner) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);
|
||||
}
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
@@ -0,0 +1,251 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
SDXL_MODEL_LOADER,
|
||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
SDXL_TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
|
||||
export const buildLinearSDXLTextToImageGraph = (
|
||||
state: RootState
|
||||
): NonNullableGraph => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
shouldUseNoiseSettings,
|
||||
vaePrecision,
|
||||
} = state.generation;
|
||||
|
||||
const {
|
||||
positiveStylePrompt,
|
||||
negativeStylePrompt,
|
||||
shouldUseSDXLRefiner,
|
||||
refinerStart,
|
||||
} = state.sdxl;
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
: initialGenerationState.shouldUseCpuNoise;
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[SDXL_MODEL_LOADER]: {
|
||||
type: 'sdxl_model_loader',
|
||||
id: SDXL_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
style: positiveStylePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
style: negativeStylePrompt,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
width,
|
||||
height,
|
||||
use_cpu,
|
||||
},
|
||||
[SDXL_TEXT_TO_LATENTS]: {
|
||||
type: 't2l_sdxl',
|
||||
id: SDXL_TEXT_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
generation_mode: 'sdxl_txt2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
seed: 0, // set in addDynamicPromptsToGraph
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined,
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
clip_skip: clipSkip,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (shouldUseSDXLRefiner) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);
|
||||
}
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
@@ -34,6 +34,7 @@ export const buildLinearTextToImageGraph = (
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
shouldUseNoiseSettings,
|
||||
vaePrecision,
|
||||
} = state.generation;
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
@@ -95,6 +96,7 @@ export const buildLinearTextToImageGraph = (
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
||||
@@ -23,8 +23,19 @@ export const METADATA_ACCUMULATOR = 'metadata_accumulator';
|
||||
export const REALESRGAN = 'esrgan';
|
||||
export const DIVIDE = 'divide';
|
||||
export const SCALE = 'scale_image';
|
||||
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
|
||||
export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl';
|
||||
export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl';
|
||||
export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
|
||||
export const SDXL_REFINER_POSITIVE_CONDITIONING =
|
||||
'sdxl_refiner_positive_conditioning';
|
||||
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
|
||||
'sdxl_refiner_negative_conditioning';
|
||||
export const SDXL_REFINER_LATENTS_TO_LATENTS = 'l2l_sdxl_refiner';
|
||||
|
||||
// friendly graph ids
|
||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||
export const SDXL_TEXT_TO_IMAGE_GRAPH = 'sdxl_text_to_image_graph';
|
||||
export const SDXL_IMAGE_TO_IMAGE_GRAPH = 'sxdl_image_to_image_graph';
|
||||
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
||||
export const INPAINT_GRAPH = 'inpaint_graph';
|
||||
|
||||
@@ -13,7 +13,12 @@ import {
|
||||
buildOutputFieldTemplates,
|
||||
} from './fieldTemplateBuilders';
|
||||
|
||||
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'metadata'];
|
||||
const getReservedFieldNames = (type: string): string[] => {
|
||||
if (type === 'l2i') {
|
||||
return ['id', 'type', 'metadata'];
|
||||
}
|
||||
return ['id', 'type', 'is_intermediate', 'metadata'];
|
||||
};
|
||||
|
||||
const invocationDenylist = [
|
||||
'Graph',
|
||||
@@ -21,11 +26,11 @@ const invocationDenylist = [
|
||||
'MetadataAccumulatorInvocation',
|
||||
];
|
||||
|
||||
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||
// filter out non-invocation schemas, plus some tricky invocations for now
|
||||
export const parseSchema = (
|
||||
openAPI: OpenAPIV3.Document
|
||||
): Record<string, InvocationTemplate> => {
|
||||
const filteredSchemas = filter(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
openAPI.components!.schemas,
|
||||
openAPI.components?.schemas,
|
||||
(schema, key) =>
|
||||
key.includes('Invocation') &&
|
||||
!key.includes('InvocationOutput') &&
|
||||
@@ -35,21 +40,17 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||
const invocations = filteredSchemas.reduce<
|
||||
Record<string, InvocationTemplate>
|
||||
>((acc, schema) => {
|
||||
// only want SchemaObjects
|
||||
if (isInvocationSchemaObject(schema)) {
|
||||
const type = schema.properties.type.default;
|
||||
const RESERVED_FIELD_NAMES = getReservedFieldNames(type);
|
||||
|
||||
const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
|
||||
|
||||
const typeHints = schema.ui?.type_hints;
|
||||
|
||||
const inputs: Record<string, InputFieldTemplate> = {};
|
||||
|
||||
if (type === 'collect') {
|
||||
const itemProperty = schema.properties[
|
||||
'item'
|
||||
] as InvocationSchemaObject;
|
||||
// Handle the special Collect node
|
||||
const itemProperty = schema.properties.item as InvocationSchemaObject;
|
||||
inputs.item = {
|
||||
type: 'item',
|
||||
name: 'item',
|
||||
@@ -60,10 +61,8 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||
default: undefined,
|
||||
};
|
||||
} else if (type === 'iterate') {
|
||||
const itemProperty = schema.properties[
|
||||
'collection'
|
||||
] as InvocationSchemaObject;
|
||||
|
||||
const itemProperty = schema.properties
|
||||
.collection as InvocationSchemaObject;
|
||||
inputs.collection = {
|
||||
type: 'array',
|
||||
name: 'collection',
|
||||
@@ -74,18 +73,18 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||
inputKind: 'connection',
|
||||
};
|
||||
} else {
|
||||
// All other nodes
|
||||
reduce(
|
||||
schema.properties,
|
||||
(inputsAccumulator, property, propertyName) => {
|
||||
if (
|
||||
// `type` and `id` are not valid inputs/outputs
|
||||
!RESERVED_FIELD_NAMES.includes(propertyName) &&
|
||||
isSchemaObject(property)
|
||||
) {
|
||||
const field: InputFieldTemplate | undefined =
|
||||
buildInputFieldTemplate(property, propertyName, typeHints);
|
||||
|
||||
const field = buildInputFieldTemplate(
|
||||
property,
|
||||
propertyName,
|
||||
typeHints
|
||||
);
|
||||
if (field) {
|
||||
inputsAccumulator[propertyName] = field;
|
||||
}
|
||||
@@ -97,22 +96,17 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||
}
|
||||
|
||||
const rawOutput = (schema as InvocationSchemaObject).output;
|
||||
|
||||
let outputs: Record<string, OutputFieldTemplate>;
|
||||
|
||||
// some special handling is needed for collect, iterate and range nodes
|
||||
if (type === 'iterate') {
|
||||
// this is guaranteed to be a SchemaObject
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
const iterationOutput = openAPI.components!.schemas![
|
||||
const iterationOutput = openAPI.components?.schemas?.[
|
||||
'IterateInvocationOutput'
|
||||
] as OpenAPIV3.SchemaObject;
|
||||
|
||||
outputs = {
|
||||
item: {
|
||||
name: 'item',
|
||||
title: iterationOutput.title ?? '',
|
||||
description: iterationOutput.description ?? '',
|
||||
title: iterationOutput?.title ?? '',
|
||||
description: iterationOutput?.description ?? '',
|
||||
type: 'array',
|
||||
},
|
||||
};
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { useCanvasGenerationMode } from 'features/canvas/hooks/useCanvasGenerationMode';
|
||||
|
||||
const GENERATION_MODE_NAME_MAP = {
|
||||
txt2img: 'Text to Image',
|
||||
img2img: 'Image to Image',
|
||||
inpaint: 'Inpaint',
|
||||
outpaint: 'Inpaint',
|
||||
};
|
||||
|
||||
const GenerationModeStatusText = () => {
|
||||
const generationMode = useCanvasGenerationMode();
|
||||
|
||||
return (
|
||||
<Box>
|
||||
Mode: {generationMode ? GENERATION_MODE_NAME_MAP[generationMode] : '...'}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default GenerationModeStatusText;
|
||||
@@ -4,6 +4,7 @@ import { memo } from 'react';
|
||||
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
|
||||
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
|
||||
import ParamScheduler from './ParamScheduler';
|
||||
import ParamVAEPrecision from '../VAEModel/ParamVAEPrecision';
|
||||
|
||||
const ParamModelandVAEandScheduler = () => {
|
||||
const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled;
|
||||
@@ -13,16 +14,15 @@ const ParamModelandVAEandScheduler = () => {
|
||||
<Box w="full">
|
||||
<ParamMainModelSelect />
|
||||
</Box>
|
||||
<Flex gap={3} w="full">
|
||||
{isVaeEnabled && (
|
||||
<Box w="full">
|
||||
<ParamVAEModelSelect />
|
||||
</Box>
|
||||
)}
|
||||
<Box w="full">
|
||||
<ParamScheduler />
|
||||
</Box>
|
||||
</Flex>
|
||||
<Box w="full">
|
||||
<ParamScheduler />
|
||||
</Box>
|
||||
{isVaeEnabled && (
|
||||
<Flex w="full" gap={3}>
|
||||
<ParamVAEModelSelect />
|
||||
<ParamVAEPrecision />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -13,8 +13,11 @@ import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
@@ -28,7 +31,12 @@ const ParamMainModelSelect = () => {
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||
NON_REFINER_BASE_MODELS
|
||||
);
|
||||
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
@@ -38,7 +46,10 @@ const ParamMainModelSelect = () => {
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) {
|
||||
if (
|
||||
!model ||
|
||||
(activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl')
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -50,7 +61,7 @@ const ParamMainModelSelect = () => {
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [mainModels]);
|
||||
}, [mainModels, activeTabName]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
@@ -86,7 +97,7 @@ const ParamMainModelSelect = () => {
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<Flex w="100%" alignItems="center" gap={2}>
|
||||
<Flex w="100%" alignItems="center" gap={3}>
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={t('modelManager.model')}
|
||||
@@ -98,9 +109,11 @@ const ParamMainModelSelect = () => {
|
||||
onChange={handleChangeModel}
|
||||
w="100%"
|
||||
/>
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -32,11 +32,6 @@ export default function ParamSeed() {
|
||||
isInvalid={seed < 0 && shouldGenerateVariations}
|
||||
onChange={handleChangeSeed}
|
||||
value={seed}
|
||||
formControlProps={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 3, // really this should work with 2 but seems to need to be 3 to match gap 2?
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import ParamSeedRandomize from './ParamSeedRandomize';
|
||||
|
||||
const ParamSeedFull = () => {
|
||||
return (
|
||||
<Flex sx={{ gap: 4, alignItems: 'center' }}>
|
||||
<Flex sx={{ gap: 3, alignItems: 'flex-end' }}>
|
||||
<ParamSeed />
|
||||
<ParamSeedShuffle />
|
||||
<ParamSeedRandomize />
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { vaePrecisionChanged } from 'features/parameters/store/generationSlice';
|
||||
import { PrecisionParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ generation }) => {
|
||||
const { vaePrecision } = generation;
|
||||
return { vaePrecision };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const DATA = ['fp16', 'fp32'];
|
||||
|
||||
const ParamVAEModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { vaePrecision } = useAppSelector(selector);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaePrecisionChanged(v as PrecisionParam));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
label="VAE Precision"
|
||||
value={vaePrecision}
|
||||
data={DATA}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamVAEModelSelect);
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
MainModelParam,
|
||||
NegativePromptParam,
|
||||
PositivePromptParam,
|
||||
PrecisionParam,
|
||||
SchedulerParam,
|
||||
SeedParam,
|
||||
StepsParam,
|
||||
@@ -51,6 +52,7 @@ export interface GenerationState {
|
||||
verticalSymmetrySteps: number;
|
||||
model: MainModelField | null;
|
||||
vae: VaeModelParam | null;
|
||||
vaePrecision: PrecisionParam;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
clipSkip: number;
|
||||
@@ -89,6 +91,7 @@ export const initialGenerationState: GenerationState = {
|
||||
verticalSymmetrySteps: 0,
|
||||
model: null,
|
||||
vae: null,
|
||||
vaePrecision: 'fp32',
|
||||
seamlessXAxis: false,
|
||||
seamlessYAxis: false,
|
||||
clipSkip: 0,
|
||||
@@ -241,6 +244,9 @@ export const generationSlice = createSlice({
|
||||
// null is a valid VAE!
|
||||
state.vae = action.payload;
|
||||
},
|
||||
vaePrecisionChanged: (state, action: PayloadAction<PrecisionParam>) => {
|
||||
state.vaePrecision = action.payload;
|
||||
},
|
||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||
state.clipSkip = action.payload;
|
||||
},
|
||||
@@ -327,6 +333,7 @@ export const {
|
||||
shouldUseCpuNoiseChanged,
|
||||
setShouldShowAdvancedOptions,
|
||||
setAspectRatio,
|
||||
vaePrecisionChanged,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export default generationSlice.reducer;
|
||||
|
||||
@@ -42,6 +42,42 @@ export const isValidNegativePrompt = (
|
||||
val: unknown
|
||||
): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for SDXL positive style prompt parameter
|
||||
*/
|
||||
export const zPositiveStylePromptSDXL = z.string();
|
||||
/**
|
||||
* Type alias for SDXL positive style prompt parameter, inferred from its zod schema
|
||||
*/
|
||||
export type PositiveStylePromptSDXLParam = z.infer<
|
||||
typeof zPositiveStylePromptSDXL
|
||||
>;
|
||||
/**
|
||||
* Validates/type-guards a value as a SDXL positive style prompt parameter
|
||||
*/
|
||||
export const isValidSDXLPositiveStylePrompt = (
|
||||
val: unknown
|
||||
): val is PositiveStylePromptSDXLParam =>
|
||||
zPositiveStylePromptSDXL.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for SDXL negative style prompt parameter
|
||||
*/
|
||||
export const zNegativeStylePromptSDXL = z.string();
|
||||
/**
|
||||
* Type alias for SDXL negative style prompt parameter, inferred from its zod schema
|
||||
*/
|
||||
export type NegativeStylePromptSDXLParam = z.infer<
|
||||
typeof zNegativeStylePromptSDXL
|
||||
>;
|
||||
/**
|
||||
* Validates/type-guards a value as a SDXL negative style prompt parameter
|
||||
*/
|
||||
export const isValidSDXLNegativeStylePrompt = (
|
||||
val: unknown
|
||||
): val is NegativeStylePromptSDXLParam =>
|
||||
zNegativeStylePromptSDXL.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for steps parameter
|
||||
*/
|
||||
@@ -260,6 +296,20 @@ export type StrengthParam = z.infer<typeof zStrength>;
|
||||
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
||||
zStrength.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for a precision parameter
|
||||
*/
|
||||
export const zPrecision = z.enum(['fp16', 'fp32']);
|
||||
/**
|
||||
* Type alias for precision parameter, inferred from its zod schema
|
||||
*/
|
||||
export type PrecisionParam = z.infer<typeof zPrecision>;
|
||||
/**
|
||||
* Validates/type-guards a value as a precision parameter
|
||||
*/
|
||||
export const isValidPrecision = (val: unknown): val is PrecisionParam =>
|
||||
zPrecision.safeParse(val).success;
|
||||
|
||||
// /**
|
||||
// * Zod schema for BaseModelType
|
||||
// */
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { setSDXLImg2ImgDenoisingStrength } from '../store/sdxlSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl }) => {
|
||||
const { sdxlImg2ImgDenoisingStrength } = sdxl;
|
||||
|
||||
return {
|
||||
sdxlImg2ImgDenoisingStrength,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLImg2ImgDenoisingStrength = () => {
|
||||
const { sdxlImg2ImgDenoisingStrength } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setSDXLImg2ImgDenoisingStrength(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
dispatch(setSDXLImg2ImgDenoisingStrength(0.7));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={`${t('parameters.denoisingStrength')}`}
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={sdxlImg2ImgDenoisingStrength}
|
||||
isInteger={false}
|
||||
withInput
|
||||
withSliderMarks
|
||||
withReset
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLImg2ImgDenoisingStrength);
|
||||
@@ -0,0 +1,149 @@
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { setNegativeStylePromptSDXL } from '../store/sdxlSlice';
|
||||
|
||||
const promptInputSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ sdxl }, activeTabName) => {
|
||||
return {
|
||||
prompt: sdxl.negativeStylePrompt,
|
||||
activeTabName,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Prompt input text area.
|
||||
*/
|
||||
const ParamSDXLNegativeStyleConditioning = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
|
||||
const handleChangePrompt = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setNegativeStylePromptSDXL(e.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleSelectEmbedding = useCallback(
|
||||
(v: string) => {
|
||||
if (!promptRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
const caret = promptRef.current.selectionStart;
|
||||
|
||||
if (caret === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
|
||||
// must flush dom updates else selection gets reset
|
||||
flushSync(() => {
|
||||
dispatch(setNegativeStylePromptSDXL(newPrompt));
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
promptRef.current.selectionStart = finalCaretPos;
|
||||
promptRef.current.selectionEnd = finalCaretPos;
|
||||
onClose();
|
||||
},
|
||||
[dispatch, onClose, prompt]
|
||||
);
|
||||
|
||||
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||
e.preventDefault();
|
||||
dispatch(clampSymmetrySteps());
|
||||
dispatch(userInvoked(activeTabName));
|
||||
}
|
||||
if (isEmbeddingEnabled && e.key === '<') {
|
||||
onOpen();
|
||||
}
|
||||
},
|
||||
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
|
||||
);
|
||||
|
||||
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||
// const target = e.target as HTMLTextAreaElement;
|
||||
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||
// };
|
||||
|
||||
return (
|
||||
<Box position="relative">
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
ref={promptRef}
|
||||
value={prompt}
|
||||
placeholder="Negative Style Prompt"
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
fontSize="sm"
|
||||
minH={16}
|
||||
/>
|
||||
</ParamEmbeddingPopover>
|
||||
</FormControl>
|
||||
{!isOpen && isEmbeddingEnabled && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamSDXLNegativeStyleConditioning;
|
||||
@@ -0,0 +1,148 @@
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { setPositiveStylePromptSDXL } from '../store/sdxlSlice';
|
||||
|
||||
const promptInputSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ sdxl }, activeTabName) => {
|
||||
return {
|
||||
prompt: sdxl.positiveStylePrompt,
|
||||
activeTabName,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Prompt input text area.
|
||||
*/
|
||||
const ParamSDXLPositiveStyleConditioning = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
|
||||
const handleChangePrompt = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setPositiveStylePromptSDXL(e.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleSelectEmbedding = useCallback(
|
||||
(v: string) => {
|
||||
if (!promptRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
const caret = promptRef.current.selectionStart;
|
||||
|
||||
if (caret === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
|
||||
// must flush dom updates else selection gets reset
|
||||
flushSync(() => {
|
||||
dispatch(setPositiveStylePromptSDXL(newPrompt));
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
promptRef.current.selectionStart = finalCaretPos;
|
||||
promptRef.current.selectionEnd = finalCaretPos;
|
||||
onClose();
|
||||
},
|
||||
[dispatch, onClose, prompt]
|
||||
);
|
||||
|
||||
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||
e.preventDefault();
|
||||
dispatch(clampSymmetrySteps());
|
||||
dispatch(userInvoked(activeTabName));
|
||||
}
|
||||
if (isEmbeddingEnabled && e.key === '<') {
|
||||
onOpen();
|
||||
}
|
||||
},
|
||||
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
|
||||
);
|
||||
|
||||
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||
// const target = e.target as HTMLTextAreaElement;
|
||||
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||
// };
|
||||
|
||||
return (
|
||||
<Box position="relative">
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
ref={promptRef}
|
||||
value={prompt}
|
||||
placeholder="Positive Style Prompt"
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
minH={16}
|
||||
/>
|
||||
</ParamEmbeddingPopover>
|
||||
</FormControl>
|
||||
{!isOpen && isEmbeddingEnabled && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamSDXLPositiveStyleConditioning;
|
||||
@@ -0,0 +1,48 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import ParamSDXLRefinerAestheticScore from './SDXLRefiner/ParamSDXLRefinerAestheticScore';
|
||||
import ParamSDXLRefinerCFGScale from './SDXLRefiner/ParamSDXLRefinerCFGScale';
|
||||
import ParamSDXLRefinerModelSelect from './SDXLRefiner/ParamSDXLRefinerModelSelect';
|
||||
import ParamSDXLRefinerScheduler from './SDXLRefiner/ParamSDXLRefinerScheduler';
|
||||
import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
|
||||
import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
|
||||
import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const { shouldUseSDXLRefiner } = state.sdxl;
|
||||
const { shouldUseSliders } = state.ui;
|
||||
return {
|
||||
activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined,
|
||||
shouldUseSliders,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerCollapse = () => {
|
||||
const { activeLabel, shouldUseSliders } = useAppSelector(selector);
|
||||
|
||||
return (
|
||||
<IAICollapse label="Refiner" activeLabel={activeLabel}>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<ParamUseSDXLRefiner />
|
||||
<ParamSDXLRefinerModelSelect />
|
||||
<Flex gap={2} flexDirection={shouldUseSliders ? 'column' : 'row'}>
|
||||
<ParamSDXLRefinerSteps />
|
||||
<ParamSDXLRefinerCFGScale />
|
||||
</Flex>
|
||||
<ParamSDXLRefinerScheduler />
|
||||
<ParamSDXLRefinerAestheticScore />
|
||||
<ParamSDXLRefinerStart />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamSDXLRefinerCollapse;
|
||||
@@ -0,0 +1,78 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||
import ParamModelandVAEandScheduler from 'features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler';
|
||||
import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize';
|
||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
|
||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
import ParamSDXLImg2ImgDenoisingStrength from './ParamSDXLImg2ImgDenoisingStrength';
|
||||
|
||||
const selector = createSelector(
|
||||
[uiSelector, generationSelector],
|
||||
(ui, generation) => {
|
||||
const { shouldUseSliders } = ui;
|
||||
const { shouldRandomizeSeed } = generation;
|
||||
|
||||
const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
|
||||
|
||||
return { shouldUseSliders, activeLabel };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const SDXLImageToImageTabCoreParameters = () => {
|
||||
const { shouldUseSliders, activeLabel } = useAppSelector(selector);
|
||||
|
||||
return (
|
||||
<IAICollapse
|
||||
label={'General'}
|
||||
activeLabel={activeLabel}
|
||||
defaultIsOpen={true}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
gap: 3,
|
||||
}}
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
<ParamModelandVAEandScheduler />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
<ParamSize />
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Flex gap={3}>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
<ParamModelandVAEandScheduler />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
<ParamSize />
|
||||
</>
|
||||
)}
|
||||
<ParamSDXLImg2ImgDenoisingStrength />
|
||||
<ImageToImageFit />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SDXLImageToImageTabCoreParameters);
|
||||
@@ -0,0 +1,28 @@
|
||||
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
||||
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
||||
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
|
||||
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
|
||||
|
||||
const SDXLImageToImageTabParameters = () => {
|
||||
return (
|
||||
<>
|
||||
<ParamPositiveConditioning />
|
||||
<ParamSDXLPositiveStyleConditioning />
|
||||
<ParamNegativeConditioning />
|
||||
<ParamSDXLNegativeStyleConditioning />
|
||||
<ProcessButtons />
|
||||
<SDXLImageToImageTabCoreParameters />
|
||||
<ParamSDXLRefinerCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDXLImageToImageTabParameters;
|
||||
@@ -0,0 +1,60 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, hotkeys }) => {
|
||||
const { refinerAestheticScore } = sdxl;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
refinerAestheticScore,
|
||||
shift,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerAestheticScore = () => {
|
||||
const { refinerAestheticScore, shift } = useAppSelector(selector);
|
||||
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setRefinerAestheticScore(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(
|
||||
() => dispatch(setRefinerAestheticScore(6)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label="Aesthetic Score"
|
||||
step={shift ? 0.1 : 0.5}
|
||||
min={1}
|
||||
max={10}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={refinerAestheticScore}
|
||||
sliderNumberInputProps={{ max: 10 }}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
isInteger={false}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerAestheticScore);
|
||||
@@ -0,0 +1,75 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, ui, hotkeys }) => {
|
||||
const { refinerCFGScale } = sdxl;
|
||||
const { shouldUseSliders } = ui;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
refinerCFGScale,
|
||||
shouldUseSliders,
|
||||
shift,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerCFGScale = () => {
|
||||
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setRefinerCFGScale(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(
|
||||
() => dispatch(setRefinerCFGScale(7)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return shouldUseSliders ? (
|
||||
<IAISlider
|
||||
label={t('parameters.cfgScale')}
|
||||
step={shift ? 0.1 : 0.5}
|
||||
min={1}
|
||||
max={20}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={refinerCFGScale}
|
||||
sliderNumberInputProps={{ max: 200 }}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
isInteger={false}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
) : (
|
||||
<IAINumberInput
|
||||
label={t('parameters.cfgScale')}
|
||||
step={0.5}
|
||||
min={1}
|
||||
max={200}
|
||||
onChange={handleChange}
|
||||
value={refinerCFGScale}
|
||||
isInteger={false}
|
||||
numberInputFieldProps={{ textAlign: 'center' }}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerCFGScale);
|
||||
@@ -0,0 +1,111 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => ({ model: state.sdxl.refinerModel }),
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: refinerModels, isLoading } =
|
||||
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!refinerModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(refinerModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [refinerModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
refinerModels?.entities[
|
||||
`${model?.base_model}/main/${model?.model_name}`
|
||||
] ?? null,
|
||||
[refinerModels?.entities, model]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newModel = modelIdToMainModelParam(v);
|
||||
|
||||
if (!newModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(refinerModelChanged(newModel));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSearchableSelect
|
||||
label="Refiner Model"
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<Flex w="100%" alignItems="center" gap={2}>
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label="Refiner Model"
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
w="100%"
|
||||
/>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerModelSelect);
|
||||
@@ -0,0 +1,65 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import {
|
||||
SCHEDULER_LABEL_MAP,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ ui, sdxl }) => {
|
||||
const { refinerScheduler } = sdxl;
|
||||
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||
|
||||
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||
value: name,
|
||||
label: label,
|
||||
group: enabledSchedulers.includes(name as SchedulerParam)
|
||||
? 'Favorites'
|
||||
: undefined,
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return {
|
||||
refinerScheduler,
|
||||
data,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerScheduler = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { refinerScheduler, data } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(setRefinerScheduler(v as SchedulerParam));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
w="100%"
|
||||
label={t('parameters.scheduler')}
|
||||
value={refinerScheduler}
|
||||
data={data}
|
||||
onChange={handleChange}
|
||||
disabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerScheduler);
|
||||
@@ -0,0 +1,53 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl }) => {
|
||||
const { refinerStart } = sdxl;
|
||||
return {
|
||||
refinerStart,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerStart = () => {
|
||||
const { refinerStart } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setRefinerStart(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(
|
||||
() => dispatch(setRefinerStart(0.7)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label="Refiner Start"
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={refinerStart}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
isInteger={false}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerStart);
|
||||
@@ -0,0 +1,72 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, ui }) => {
|
||||
const { refinerSteps } = sdxl;
|
||||
const { shouldUseSliders } = ui;
|
||||
|
||||
return {
|
||||
refinerSteps,
|
||||
shouldUseSliders,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerSteps = () => {
|
||||
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setRefinerSteps(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const handleReset = useCallback(() => {
|
||||
dispatch(setRefinerSteps(20));
|
||||
}, [dispatch]);
|
||||
|
||||
return shouldUseSliders ? (
|
||||
<IAISlider
|
||||
label={t('parameters.steps')}
|
||||
min={1}
|
||||
max={100}
|
||||
step={1}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={refinerSteps}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
sliderNumberInputProps={{ max: 500 }}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
) : (
|
||||
<IAINumberInput
|
||||
label={t('parameters.steps')}
|
||||
min={1}
|
||||
max={500}
|
||||
step={1}
|
||||
onChange={handleChange}
|
||||
value={refinerSteps}
|
||||
numberInputFieldProps={{ textAlign: 'center' }}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerSteps);
|
||||
@@ -0,0 +1,28 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
export default function ParamUseSDXLRefiner() {
|
||||
const shouldUseSDXLRefiner = useAppSelector(
|
||||
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
|
||||
);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleUseSDXLRefinerChange = (e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(setShouldUseSDXLRefiner(e.target.checked));
|
||||
};
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Use Refiner"
|
||||
isChecked={shouldUseSDXLRefiner}
|
||||
onChange={handleUseSDXLRefinerChange}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
|
||||
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
||||
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
||||
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
|
||||
|
||||
const SDXLTextToImageTabParameters = () => {
|
||||
return (
|
||||
<>
|
||||
<ParamPositiveConditioning />
|
||||
<ParamSDXLPositiveStyleConditioning />
|
||||
<ParamNegativeConditioning />
|
||||
<ParamSDXLNegativeStyleConditioning />
|
||||
<ProcessButtons />
|
||||
<TextToImageTabCoreParameters />
|
||||
<ParamSDXLRefinerCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDXLTextToImageTabParameters;
|
||||
89
invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
Normal file
89
invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
Normal file
@@ -0,0 +1,89 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import {
|
||||
MainModelParam,
|
||||
NegativeStylePromptSDXLParam,
|
||||
PositiveStylePromptSDXLParam,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { MainModelField } from 'services/api/types';
|
||||
|
||||
type SDXLInitialState = {
|
||||
positiveStylePrompt: PositiveStylePromptSDXLParam;
|
||||
negativeStylePrompt: NegativeStylePromptSDXLParam;
|
||||
shouldUseSDXLRefiner: boolean;
|
||||
sdxlImg2ImgDenoisingStrength: number;
|
||||
refinerModel: MainModelField | null;
|
||||
refinerSteps: number;
|
||||
refinerCFGScale: number;
|
||||
refinerScheduler: SchedulerParam;
|
||||
refinerAestheticScore: number;
|
||||
refinerStart: number;
|
||||
};
|
||||
|
||||
const sdxlInitialState: SDXLInitialState = {
|
||||
positiveStylePrompt: '',
|
||||
negativeStylePrompt: '',
|
||||
shouldUseSDXLRefiner: false,
|
||||
sdxlImg2ImgDenoisingStrength: 0.7,
|
||||
refinerModel: null,
|
||||
refinerSteps: 20,
|
||||
refinerCFGScale: 7.5,
|
||||
refinerScheduler: 'euler',
|
||||
refinerAestheticScore: 6,
|
||||
refinerStart: 0.7,
|
||||
};
|
||||
|
||||
const sdxlSlice = createSlice({
|
||||
name: 'sdxl',
|
||||
initialState: sdxlInitialState,
|
||||
reducers: {
|
||||
setPositiveStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||
state.positiveStylePrompt = action.payload;
|
||||
},
|
||||
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||
state.negativeStylePrompt = action.payload;
|
||||
},
|
||||
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldUseSDXLRefiner = action.payload;
|
||||
},
|
||||
setSDXLImg2ImgDenoisingStrength: (state, action: PayloadAction<number>) => {
|
||||
state.sdxlImg2ImgDenoisingStrength = action.payload;
|
||||
},
|
||||
refinerModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<MainModelParam | null>
|
||||
) => {
|
||||
state.refinerModel = action.payload;
|
||||
},
|
||||
setRefinerSteps: (state, action: PayloadAction<number>) => {
|
||||
state.refinerSteps = action.payload;
|
||||
},
|
||||
setRefinerCFGScale: (state, action: PayloadAction<number>) => {
|
||||
state.refinerCFGScale = action.payload;
|
||||
},
|
||||
setRefinerScheduler: (state, action: PayloadAction<SchedulerParam>) => {
|
||||
state.refinerScheduler = action.payload;
|
||||
},
|
||||
setRefinerAestheticScore: (state, action: PayloadAction<number>) => {
|
||||
state.refinerAestheticScore = action.payload;
|
||||
},
|
||||
setRefinerStart: (state, action: PayloadAction<number>) => {
|
||||
state.refinerStart = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
setPositiveStylePromptSDXL,
|
||||
setNegativeStylePromptSDXL,
|
||||
setShouldUseSDXLRefiner,
|
||||
setSDXLImg2ImgDenoisingStrength,
|
||||
refinerModelChanged,
|
||||
setRefinerSteps,
|
||||
setRefinerCFGScale,
|
||||
setRefinerScheduler,
|
||||
setRefinerAestheticScore,
|
||||
setRefinerStart,
|
||||
} = sdxlSlice.actions;
|
||||
|
||||
export default sdxlSlice.reducer;
|
||||
@@ -1,5 +1,5 @@
|
||||
import { UseToastOptions } from '@chakra-ui/react';
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { InvokeLogLevel } from 'app/logging/logger';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
@@ -16,13 +16,16 @@ import {
|
||||
appSocketGraphExecutionStateComplete,
|
||||
appSocketInvocationComplete,
|
||||
appSocketInvocationError,
|
||||
appSocketInvocationRetrievalError,
|
||||
appSocketInvocationStarted,
|
||||
appSocketSessionRetrievalError,
|
||||
appSocketSubscribed,
|
||||
appSocketUnsubscribed,
|
||||
} from 'services/events/actions';
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { makeToast } from '../util/makeToast';
|
||||
import { LANGUAGES } from './constants';
|
||||
import { startCase } from 'lodash-es';
|
||||
|
||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||
|
||||
@@ -288,25 +291,6 @@ export const systemSlice = createSlice({
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Error
|
||||
*/
|
||||
builder.addCase(appSocketInvocationError, (state) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = true;
|
||||
// state.currentIteration = 0;
|
||||
// state.totalIterations = 0;
|
||||
state.currentStatusHasSteps = false;
|
||||
state.currentStep = 0;
|
||||
state.totalSteps = 0;
|
||||
state.statusTranslationKey = 'common.statusError';
|
||||
state.progressImage = null;
|
||||
|
||||
state.toastQueue.push(
|
||||
makeToast({ title: t('toast.serverError'), status: 'error' })
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Graph Execution State Complete
|
||||
*/
|
||||
@@ -362,7 +346,7 @@ export const systemSlice = createSlice({
|
||||
* Session Invoked - REJECTED
|
||||
* Session Created - REJECTED
|
||||
*/
|
||||
builder.addMatcher(isAnySessionRejected, (state) => {
|
||||
builder.addMatcher(isAnySessionRejected, (state, action) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = false;
|
||||
state.isCancelScheduled = false;
|
||||
@@ -372,7 +356,35 @@ export const systemSlice = createSlice({
|
||||
state.progressImage = null;
|
||||
|
||||
state.toastQueue.push(
|
||||
makeToast({ title: t('toast.serverError'), status: 'error' })
|
||||
makeToast({
|
||||
title: t('toast.serverError'),
|
||||
status: 'error',
|
||||
description:
|
||||
action.payload?.status === 422 ? 'Validation Error' : undefined,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Any server error
|
||||
*/
|
||||
builder.addMatcher(isAnyServerError, (state, action) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = true;
|
||||
// state.currentIteration = 0;
|
||||
// state.totalIterations = 0;
|
||||
state.currentStatusHasSteps = false;
|
||||
state.currentStep = 0;
|
||||
state.totalSteps = 0;
|
||||
state.statusTranslationKey = 'common.statusError';
|
||||
state.progressImage = null;
|
||||
|
||||
state.toastQueue.push(
|
||||
makeToast({
|
||||
title: t('toast.serverError'),
|
||||
status: 'error',
|
||||
description: startCase(action.payload.data.error_type),
|
||||
})
|
||||
);
|
||||
});
|
||||
},
|
||||
@@ -400,3 +412,9 @@ export const {
|
||||
} = systemSlice.actions;
|
||||
|
||||
export default systemSlice.reducer;
|
||||
|
||||
const isAnyServerError = isAnyOf(
|
||||
appSocketInvocationError,
|
||||
appSocketSessionRetrievalError,
|
||||
appSocketInvocationRetrievalError
|
||||
);
|
||||
|
||||
@@ -16,7 +16,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
@@ -172,13 +172,22 @@ const InvokeTabs = () => {
|
||||
const { ref: galleryPanelRef, minSizePct: galleryMinSizePct } =
|
||||
useMinimumPanelSize(MIN_GALLERY_WIDTH, DEFAULT_GALLERY_PCT, 'app');
|
||||
|
||||
const handleTabChange = useCallback(
|
||||
(index: number) => {
|
||||
const activeTabName = tabMap[index];
|
||||
if (!activeTabName) {
|
||||
return;
|
||||
}
|
||||
dispatch(setActiveTab(activeTabName));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Tabs
|
||||
defaultIndex={activeTab}
|
||||
index={activeTab}
|
||||
onChange={(index: number) => {
|
||||
dispatch(setActiveTab(index));
|
||||
}}
|
||||
onChange={handleTabChange}
|
||||
sx={{
|
||||
flexGrow: 1,
|
||||
gap: 4,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
|
||||
import SDXLImageToImageTabParameters from 'features/sdxl/components/SDXLImageToImageTabParameters';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import {
|
||||
ImperativePanelGroupHandle,
|
||||
@@ -16,6 +18,7 @@ import ImageToImageTabParameters from './ImageToImageTabParameters';
|
||||
const ImageToImageTab = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||
const model = useAppSelector((state: RootState) => state.generation.model);
|
||||
|
||||
const handleDoubleClickHandle = useCallback(() => {
|
||||
if (!panelGroupRef.current) {
|
||||
@@ -28,7 +31,11 @@ const ImageToImageTab = () => {
|
||||
return (
|
||||
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
|
||||
<ParametersPinnedWrapper>
|
||||
<ImageToImageTabParameters />
|
||||
{model && model.base_model === 'sdxl' ? (
|
||||
<SDXLImageToImageTabParameters />
|
||||
) : (
|
||||
<ImageToImageTabParameters />
|
||||
)}
|
||||
</ParametersPinnedWrapper>
|
||||
<Box sx={{ w: 'full', h: 'full' }}>
|
||||
<PanelGroup
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
useImportMainModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
export default function FoundModelsList() {
|
||||
const searchFolder = useAppSelector(
|
||||
@@ -24,7 +25,7 @@ export default function FoundModelsList() {
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
|
||||
// Get paths of models that are already installed
|
||||
const { data: installedModels } = useGetMainModelsQuery();
|
||||
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
// Get all model paths from a given directory
|
||||
const { foundModels, alreadyInstalled, filteredModels } =
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
@@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { pickBy } from 'lodash-es';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
useGetMainModelsQuery,
|
||||
useMergeMainModelsMutation,
|
||||
@@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useGetMainModelsQuery();
|
||||
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
||||
|
||||
|
||||
@@ -8,10 +8,11 @@ import {
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
export default function ModelManagerPanel() {
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||
const { model } = useGetMainModelsQuery(undefined, {
|
||||
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
}),
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
useGetMainModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
@@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
|
||||
const [modelFormatFilter, setModelFormatFilter] =
|
||||
useState<ModelFormat>('images');
|
||||
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
||||
}),
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import TextToImageSDXLTabParameters from 'features/sdxl/components/SDXLTextToImageTabParameters';
|
||||
import { memo } from 'react';
|
||||
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
|
||||
import TextToImageTabMain from './TextToImageTabMain';
|
||||
import TextToImageTabParameters from './TextToImageTabParameters';
|
||||
|
||||
const TextToImageTab = () => {
|
||||
const model = useAppSelector((state: RootState) => state.generation.model);
|
||||
return (
|
||||
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
|
||||
<ParametersPinnedWrapper>
|
||||
<TextToImageTabParameters />
|
||||
{model && model.base_model === 'sdxl' ? (
|
||||
<TextToImageSDXLTabParameters />
|
||||
) : (
|
||||
<TextToImageTabParameters />
|
||||
)}
|
||||
</ParametersPinnedWrapper>
|
||||
<TextToImageTabMain />
|
||||
</Flex>
|
||||
|
||||
@@ -26,7 +26,7 @@ export const uiSlice = createSlice({
|
||||
name: 'ui',
|
||||
initialState: initialUIState,
|
||||
reducers: {
|
||||
setActiveTab: (state, action: PayloadAction<number | InvokeTabName>) => {
|
||||
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
|
||||
setActiveTabReducer(state, action.payload);
|
||||
},
|
||||
setShouldPinParametersPanel: (state, action: PayloadAction<boolean>) => {
|
||||
|
||||
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { BaseModelType } from './types';
|
||||
|
||||
export const ALL_BASE_MODELS: BaseModelType[] = [
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
'sdxl-refiner',
|
||||
];
|
||||
|
||||
export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
];
|
||||
|
||||
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];
|
||||
@@ -144,8 +144,19 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
|
||||
getMainModels: build.query<
|
||||
EntityState<MainModelConfigEntity>,
|
||||
BaseModelType[]
|
||||
>({
|
||||
query: (base_models) => {
|
||||
const params = {
|
||||
model_type: 'main',
|
||||
base_models,
|
||||
};
|
||||
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return `models/?${query}`;
|
||||
},
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
@@ -187,7 +198,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
importMainModels: build.mutation<
|
||||
ImportMainModelResponse,
|
||||
@@ -200,7 +214,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||
query: ({ body }) => {
|
||||
@@ -210,7 +227,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
deleteMainModels: build.mutation<
|
||||
DeleteMainModelResponse,
|
||||
@@ -222,7 +242,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
convertMainModels: build.mutation<
|
||||
ConvertMainModelResponse,
|
||||
@@ -235,7 +258,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
params: params,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||
query: ({ base_model, body }) => {
|
||||
@@ -245,7 +271,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||
query: () => {
|
||||
@@ -254,7 +283,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
method: 'POST',
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const useIsRefinerAvailable = () => {
|
||||
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||
}),
|
||||
});
|
||||
|
||||
return isRefinerAvailable;
|
||||
};
|
||||
@@ -1014,6 +1014,11 @@ export type components = {
|
||||
* @description The LoRAs used for inference
|
||||
*/
|
||||
loras: (components["schemas"]["LoRAMetadataField"])[];
|
||||
/**
|
||||
* Vae
|
||||
* @description The VAE used for decoding, if the main model's default was not used
|
||||
*/
|
||||
vae?: components["schemas"]["VAEModelField"];
|
||||
/**
|
||||
* Strength
|
||||
* @description The strength used for latents-to-latents
|
||||
@@ -1025,10 +1030,45 @@ export type components = {
|
||||
*/
|
||||
init_image?: string;
|
||||
/**
|
||||
* Vae
|
||||
* @description The VAE used for decoding, if the main model's default was not used
|
||||
* Positive Style Prompt
|
||||
* @description The positive style prompt parameter
|
||||
*/
|
||||
vae?: components["schemas"]["VAEModelField"];
|
||||
positive_style_prompt?: string;
|
||||
/**
|
||||
* Negative Style Prompt
|
||||
* @description The negative style prompt parameter
|
||||
*/
|
||||
negative_style_prompt?: string;
|
||||
/**
|
||||
* Refiner Model
|
||||
* @description The SDXL Refiner model used
|
||||
*/
|
||||
refiner_model?: components["schemas"]["MainModelField"];
|
||||
/**
|
||||
* Refiner Cfg Scale
|
||||
* @description The classifier-free guidance scale parameter used for the refiner
|
||||
*/
|
||||
refiner_cfg_scale?: number;
|
||||
/**
|
||||
* Refiner Steps
|
||||
* @description The number of steps used for the refiner
|
||||
*/
|
||||
refiner_steps?: number;
|
||||
/**
|
||||
* Refiner Scheduler
|
||||
* @description The scheduler used for the refiner
|
||||
*/
|
||||
refiner_scheduler?: string;
|
||||
/**
|
||||
* Refiner Aesthetic Store
|
||||
* @description The aesthetic score used for the refiner
|
||||
*/
|
||||
refiner_aesthetic_store?: number;
|
||||
/**
|
||||
* Refiner Start
|
||||
* @description The start value used for refiner denoising
|
||||
*/
|
||||
refiner_start?: number;
|
||||
};
|
||||
/**
|
||||
* CvInpaintInvocation
|
||||
@@ -3268,6 +3308,46 @@ export type components = {
|
||||
* @description The VAE used for decoding, if the main model's default was not used
|
||||
*/
|
||||
vae?: components["schemas"]["VAEModelField"];
|
||||
/**
|
||||
* Positive Style Prompt
|
||||
* @description The positive style prompt parameter
|
||||
*/
|
||||
positive_style_prompt?: string;
|
||||
/**
|
||||
* Negative Style Prompt
|
||||
* @description The negative style prompt parameter
|
||||
*/
|
||||
negative_style_prompt?: string;
|
||||
/**
|
||||
* Refiner Model
|
||||
* @description The SDXL Refiner model used
|
||||
*/
|
||||
refiner_model?: components["schemas"]["MainModelField"];
|
||||
/**
|
||||
* Refiner Cfg Scale
|
||||
* @description The classifier-free guidance scale parameter used for the refiner
|
||||
*/
|
||||
refiner_cfg_scale?: number;
|
||||
/**
|
||||
* Refiner Steps
|
||||
* @description The number of steps used for the refiner
|
||||
*/
|
||||
refiner_steps?: number;
|
||||
/**
|
||||
* Refiner Scheduler
|
||||
* @description The scheduler used for the refiner
|
||||
*/
|
||||
refiner_scheduler?: string;
|
||||
/**
|
||||
* Refiner Aesthetic Store
|
||||
* @description The aesthetic score used for the refiner
|
||||
*/
|
||||
refiner_aesthetic_store?: number;
|
||||
/**
|
||||
* Refiner Start
|
||||
* @description The start value used for refiner denoising
|
||||
*/
|
||||
refiner_start?: number;
|
||||
};
|
||||
/**
|
||||
* MetadataAccumulatorOutput
|
||||
@@ -5355,6 +5435,12 @@ export type components = {
|
||||
*/
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
@@ -5367,12 +5453,6 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
|
||||
@@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required<
|
||||
>;
|
||||
|
||||
type CreateSessionThunkConfig = {
|
||||
rejectValue: { arg: CreateSessionArg; error: unknown };
|
||||
rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk<
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, error });
|
||||
return rejectWithValue({ arg, status: response.status, error });
|
||||
}
|
||||
|
||||
return data;
|
||||
@@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = {
|
||||
rejectValue: {
|
||||
arg: InvokedSessionArg;
|
||||
error: unknown;
|
||||
status: number;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -78,9 +79,13 @@ export const sessionInvoked = createAsyncThunk<
|
||||
|
||||
if (error) {
|
||||
if (isErrorWithStatus(error) && error.status === 403) {
|
||||
return rejectWithValue({ arg, error: (error as any).body.detail });
|
||||
return rejectWithValue({
|
||||
arg,
|
||||
status: response.status,
|
||||
error: (error as any).body.detail,
|
||||
});
|
||||
}
|
||||
return rejectWithValue({ arg, error });
|
||||
return rejectWithValue({ arg, status: response.status, error });
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -4,9 +4,11 @@ import {
|
||||
GraphExecutionStateCompleteEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationRetrievalErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelLoadCompletedEvent,
|
||||
ModelLoadStartedEvent,
|
||||
SessionRetrievalErrorEvent,
|
||||
} from 'services/events/types';
|
||||
|
||||
// Create actions for each socket
|
||||
@@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{
|
||||
export const appSocketModelLoadCompleted = createAction<{
|
||||
data: ModelLoadCompletedEvent;
|
||||
}>('socket/appSocketModelLoadCompleted');
|
||||
|
||||
/**
|
||||
* Socket.IO Session Retrieval Error
|
||||
*
|
||||
* Do not use. Only for use in middleware.
|
||||
*/
|
||||
export const socketSessionRetrievalError = createAction<{
|
||||
data: SessionRetrievalErrorEvent;
|
||||
}>('socket/socketSessionRetrievalError');
|
||||
|
||||
/**
|
||||
* App-level Session Retrieval Error
|
||||
*/
|
||||
export const appSocketSessionRetrievalError = createAction<{
|
||||
data: SessionRetrievalErrorEvent;
|
||||
}>('socket/appSocketSessionRetrievalError');
|
||||
|
||||
/**
|
||||
* Socket.IO Invocation Retrieval Error
|
||||
*
|
||||
* Do not use. Only for use in middleware.
|
||||
*/
|
||||
export const socketInvocationRetrievalError = createAction<{
|
||||
data: InvocationRetrievalErrorEvent;
|
||||
}>('socket/socketInvocationRetrievalError');
|
||||
|
||||
/**
|
||||
* App-level Invocation Retrieval Error
|
||||
*/
|
||||
export const appSocketInvocationRetrievalError = createAction<{
|
||||
data: InvocationRetrievalErrorEvent;
|
||||
}>('socket/appSocketInvocationRetrievalError');
|
||||
|
||||
@@ -87,6 +87,7 @@ export type InvocationErrorEvent = {
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
@@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = {
|
||||
graph_execution_state_id: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `session_retrieval_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type SessionRetrievalErrorEvent = {
|
||||
graph_execution_state_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `invocation_retrieval_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type InvocationRetrievalErrorEvent = {
|
||||
graph_execution_state_id: string;
|
||||
node_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
export type ClientEmitSubscribe = {
|
||||
session: string;
|
||||
};
|
||||
@@ -128,6 +152,8 @@ export type ServerToClientEvents = {
|
||||
) => void;
|
||||
model_load_started: (payload: ModelLoadStartedEvent) => void;
|
||||
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
|
||||
session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
|
||||
invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
|
||||
};
|
||||
|
||||
export type ClientToServerEvents = {
|
||||
|
||||
@@ -11,9 +11,11 @@ import {
|
||||
socketGraphExecutionStateComplete,
|
||||
socketInvocationComplete,
|
||||
socketInvocationError,
|
||||
socketInvocationRetrievalError,
|
||||
socketInvocationStarted,
|
||||
socketModelLoadCompleted,
|
||||
socketModelLoadStarted,
|
||||
socketSessionRetrievalError,
|
||||
socketSubscribed,
|
||||
} from '../actions';
|
||||
import { ClientToServerEvents, ServerToClientEvents } from '../types';
|
||||
@@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Session retrieval error
|
||||
*/
|
||||
socket.on('session_retrieval_error', (data) => {
|
||||
dispatch(
|
||||
socketSessionRetrievalError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation retrieval error
|
||||
*/
|
||||
socket.on('invocation_retrieval_error', (data) => {
|
||||
dispatch(
|
||||
socketInvocationRetrievalError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user