mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui, nodes): models
This commit is contained in:
@@ -76,7 +76,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
self.model_name = model["model_name"]
|
||||
self.model = model["model_name"]
|
||||
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
|
||||
@@ -10,10 +10,11 @@ import galleryReducer, {
|
||||
GalleryState,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import resultsReducer, {
|
||||
resultsAdapter,
|
||||
ResultsState,
|
||||
} from 'features/gallery/store/resultsSlice';
|
||||
import uploadsReducer from 'features/gallery/store/uploadsSlice';
|
||||
import uploadsReducer, {
|
||||
UploadsState,
|
||||
} from 'features/gallery/store/uploadsSlice';
|
||||
import lightboxReducer, {
|
||||
LightboxState,
|
||||
} from 'features/lightbox/store/lightboxSlice';
|
||||
@@ -25,12 +26,13 @@ import postprocessingReducer, {
|
||||
} from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer, { SystemState } from 'features/system/store/systemSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import modelsReducer, { ModelsState } from 'features/system/store/modelSlice';
|
||||
import nodesReducer, { NodesState } from 'features/nodes/store/nodesSlice';
|
||||
|
||||
import { socketioMiddleware } from './socketio/middleware';
|
||||
import { socketMiddleware } from 'services/events/middleware';
|
||||
import { CanvasState } from 'features/canvas/store/canvasTypes';
|
||||
import { UIState } from 'features/ui/store/uiTypes';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
@@ -138,21 +140,21 @@ resultsBlacklist.map((blacklistItem) => `results.${blacklistItem}`);
|
||||
*
|
||||
* Currently blacklisting uploads slice entirely, see persist config below
|
||||
*/
|
||||
const uploadsBlacklist: (keyof NodesState)[] = [];
|
||||
const uploadsBlacklist: (keyof UploadsState)[] = [];
|
||||
|
||||
uploadsBlacklist.map((blacklistItem) => `uploads.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* Models slice persist blacklist
|
||||
*/
|
||||
const modelsBlacklist: (keyof NodesState)[] = [];
|
||||
const modelsBlacklist: (keyof ModelsState)[] = ['entities', 'ids'];
|
||||
|
||||
modelsBlacklist.map((blacklistItem) => `models.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* UI slice persist blacklist
|
||||
*/
|
||||
const uiBlacklist: (keyof NodesState)[] = [];
|
||||
const uiBlacklist: (keyof UIState)[] = [];
|
||||
|
||||
uiBlacklist.map((blacklistItem) => `ui.${blacklistItem}`);
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ export const buildImg2ImgNode = (
|
||||
const { generation, system, models } = state;
|
||||
|
||||
const { shouldDisplayInProgressType } = system;
|
||||
const { currentModel: model } = models;
|
||||
const { selectedModelName } = models;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
@@ -38,27 +38,29 @@ export const buildImg2ImgNode = (
|
||||
throw 'no initial image';
|
||||
}
|
||||
|
||||
return {
|
||||
[nodeId]: {
|
||||
id: nodeId,
|
||||
type: 'img2img',
|
||||
prompt,
|
||||
seed: shouldRandomizeSeed ? -1 : seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale: cfgScale,
|
||||
scheduler: sampler as ImageToImageInvocation['scheduler'],
|
||||
seamless,
|
||||
model,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
image: {
|
||||
image_name: initialImage.name,
|
||||
image_type: initialImage.type,
|
||||
},
|
||||
strength,
|
||||
fit,
|
||||
const imageToImageNode: ImageToImageInvocation = {
|
||||
id: nodeId,
|
||||
type: 'img2img',
|
||||
prompt,
|
||||
seed: shouldRandomizeSeed ? -1 : seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale: cfgScale,
|
||||
scheduler: sampler as ImageToImageInvocation['scheduler'],
|
||||
seamless,
|
||||
model: selectedModelName,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
image: {
|
||||
image_name: initialImage.name,
|
||||
image_type: initialImage.type,
|
||||
},
|
||||
strength,
|
||||
fit,
|
||||
};
|
||||
|
||||
return {
|
||||
[nodeId]: imageToImageNode,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ export const buildTxt2ImgNode = (
|
||||
const { generation, system, models } = state;
|
||||
|
||||
const { shouldDisplayInProgressType } = system;
|
||||
const { currentModel: model } = models;
|
||||
const { selectedModelName } = models;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
@@ -24,20 +24,22 @@ export const buildTxt2ImgNode = (
|
||||
} = generation;
|
||||
|
||||
// missing fields in TextToImageInvocation: strength, hires_fix
|
||||
const textToImageNode: TextToImageInvocation = {
|
||||
id: nodeId,
|
||||
type: 'txt2img',
|
||||
prompt,
|
||||
seed: shouldRandomizeSeed ? -1 : seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale,
|
||||
scheduler: sampler as TextToImageInvocation['scheduler'],
|
||||
seamless,
|
||||
model: selectedModelName,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
};
|
||||
|
||||
return {
|
||||
[nodeId]: {
|
||||
id: nodeId,
|
||||
type: 'txt2img',
|
||||
prompt,
|
||||
seed: shouldRandomizeSeed ? -1 : seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale,
|
||||
scheduler: sampler as TextToImageInvocation['scheduler'],
|
||||
seamless,
|
||||
model,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
},
|
||||
[nodeId]: textToImageNode,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -22,18 +22,19 @@ type AdditionalUploadsState = {
|
||||
nextPage: number;
|
||||
};
|
||||
|
||||
export type UploadssState = ReturnType<
|
||||
typeof uploadsAdapter.getInitialState<AdditionalUploadsState>
|
||||
>;
|
||||
|
||||
const uploadsSlice = createSlice({
|
||||
name: 'uploads',
|
||||
initialState: uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||
const initialUploadsState =
|
||||
uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||
page: 0,
|
||||
pages: 0,
|
||||
nextPage: 0,
|
||||
isLoading: false,
|
||||
}),
|
||||
});
|
||||
|
||||
export type UploadsState = typeof initialUploadsState;
|
||||
|
||||
const uploadsSlice = createSlice({
|
||||
name: 'uploads',
|
||||
initialState: initialUploadsState,
|
||||
reducers: {
|
||||
uploadAdded: uploadsAdapter.addOne,
|
||||
},
|
||||
|
||||
@@ -4,14 +4,19 @@ import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { ModelInputField } from 'features/nodes/types';
|
||||
import {
|
||||
selectModelsById,
|
||||
selectModelsIds,
|
||||
} from 'features/system/store/modelSlice';
|
||||
import { isEqual, map } from 'lodash';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const availableModelsSelector = createSelector(
|
||||
(state: RootState) => state.models.modelList,
|
||||
(modelList) => {
|
||||
return map(modelList, (_, name) => name);
|
||||
[selectModelsIds],
|
||||
(allModelNames) => {
|
||||
return { allModelNames };
|
||||
// return map(modelList, (_, name) => name);
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
@@ -27,7 +32,7 @@ export const ModelInputFieldComponent = (
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const availableModels = useAppSelector(availableModelsSelector);
|
||||
const { allModelNames } = useAppSelector(availableModelsSelector);
|
||||
|
||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(
|
||||
@@ -41,7 +46,7 @@ export const ModelInputFieldComponent = (
|
||||
|
||||
return (
|
||||
<Select onChange={handleValueChanged} value={field.value}>
|
||||
{availableModels.map((option) => (
|
||||
{allModelNames.map((option) => (
|
||||
<option key={option}>{option}</option>
|
||||
))}
|
||||
</Select>
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { isEqual, map } from 'lodash';
|
||||
import { isEqual } from 'lodash';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
import { modelSelector } from '../store/modelSelectors';
|
||||
import { setCurrentModel } from '../store/modelSlice';
|
||||
import {
|
||||
modelSelected,
|
||||
selectedModelSelector,
|
||||
selectModelsIds,
|
||||
} from '../store/modelSlice';
|
||||
import { RootState } from 'app/store';
|
||||
|
||||
const selector = createSelector(
|
||||
[modelSelector],
|
||||
(model) => {
|
||||
const { modelList, currentModel } = model;
|
||||
const models = map(modelList, (model, key) => key);
|
||||
return { models, currentModel, modelList };
|
||||
[(state: RootState) => state],
|
||||
(state) => {
|
||||
const selectedModel = selectedModelSelector(state);
|
||||
const allModelNames = selectModelsIds(state);
|
||||
return {
|
||||
allModelNames,
|
||||
selectedModel,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
@@ -26,12 +33,10 @@ const selector = createSelector(
|
||||
const ModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { models, currentModel, modelList } = useAppSelector(selector);
|
||||
const { allModelNames, selectedModel } = useAppSelector(selector);
|
||||
const handleChangeModel = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(setCurrentModel(e.target.value));
|
||||
dispatch(modelSelected(e.target.value));
|
||||
};
|
||||
const currentModelDescription =
|
||||
currentModel && modelList[currentModel].description;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -42,9 +47,9 @@ const ModelSelect = () => {
|
||||
<IAISelect
|
||||
style={{ fontSize: 'sm' }}
|
||||
aria-label={t('accessibility.modelSelect')}
|
||||
tooltip={currentModelDescription}
|
||||
value={currentModel}
|
||||
validValues={models}
|
||||
tooltip={selectedModel?.description || ''}
|
||||
value={selectedModel?.name || undefined}
|
||||
validValues={allModelNames}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
@@ -1,27 +1,37 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { ModelsList } from 'services/api';
|
||||
import { RootState } from 'app/store';
|
||||
import { keys, sample } from 'lodash';
|
||||
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
|
||||
export interface ModelState {
|
||||
modelList: ModelsList['models'];
|
||||
currentModel?: string;
|
||||
}
|
||||
|
||||
const initialModelState: ModelState = {
|
||||
modelList: {},
|
||||
currentModel: undefined,
|
||||
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const modelSlice = createSlice({
|
||||
name: 'model',
|
||||
initialState: initialModelState,
|
||||
export const modelsAdapter = createEntityAdapter<Model>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
type AdditionalModelsState = {
|
||||
selectedModelName: string;
|
||||
};
|
||||
|
||||
export const initialModelsState =
|
||||
modelsAdapter.getInitialState<AdditionalModelsState>({
|
||||
selectedModelName: '',
|
||||
});
|
||||
|
||||
export type ModelsState = typeof initialModelsState;
|
||||
|
||||
export const modelsSlice = createSlice({
|
||||
name: 'models',
|
||||
initialState: initialModelsState,
|
||||
reducers: {
|
||||
setModelList: (state, action: PayloadAction<ModelsList['models']>) => {
|
||||
state.modelList = action.payload;
|
||||
},
|
||||
setCurrentModel: (state, action: PayloadAction<string>) => {
|
||||
state.currentModel = action.payload;
|
||||
modelAdded: modelsAdapter.upsertOne,
|
||||
modelSelected: (state, action: PayloadAction<string>) => {
|
||||
state.selectedModelName = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
@@ -29,12 +39,42 @@ export const modelSlice = createSlice({
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
||||
const models = action.payload.models;
|
||||
state.modelList = models;
|
||||
const models = action.payload;
|
||||
modelsAdapter.setAll(state, models);
|
||||
|
||||
// If the current selected model is `''` or isn't actually in the list of models,
|
||||
// choose a random model
|
||||
if (
|
||||
!state.selectedModelName ||
|
||||
!keys(models).includes(state.selectedModelName)
|
||||
) {
|
||||
const randomModel = sample(models);
|
||||
|
||||
if (randomModel) {
|
||||
state.selectedModelName = randomModel.name;
|
||||
} else {
|
||||
state.selectedModelName = '';
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const { setModelList, setCurrentModel } = modelSlice.actions;
|
||||
export const selectedModelSelector = (state: RootState) => {
|
||||
const { selectedModelName } = state.models;
|
||||
const selectedModel = selectModelsById(state, selectedModelName);
|
||||
|
||||
export default modelSlice.reducer;
|
||||
return selectedModel ?? null;
|
||||
};
|
||||
|
||||
export const {
|
||||
selectAll: selectModelsAll,
|
||||
selectById: selectModelsById,
|
||||
selectEntities: selectModelsEntities,
|
||||
selectIds: selectModelsIds,
|
||||
selectTotal: selectModelsTotal,
|
||||
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
|
||||
|
||||
export const { modelAdded, modelSelected } = modelsSlice.actions;
|
||||
|
||||
export default modelsSlice.reducer;
|
||||
|
||||
@@ -101,7 +101,7 @@ export const socketMiddleware = () => {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
|
||||
if (!models.modelList.length) {
|
||||
if (!models.ids.length) {
|
||||
dispatch(receivedModels());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { createAppAsyncThunk } from 'app/storeUtils';
|
||||
import { Model } from 'features/system/store/modelSlice';
|
||||
import { reduce } from 'lodash';
|
||||
import { ModelsService } from 'services/api';
|
||||
|
||||
export const IMAGES_PER_PAGE = 20;
|
||||
@@ -7,7 +9,16 @@ export const receivedModels = createAppAsyncThunk(
|
||||
'models/receivedModels',
|
||||
async (_arg) => {
|
||||
const response = await ModelsService.listModels();
|
||||
const deserializedModels = reduce(
|
||||
response.models,
|
||||
(modelsAccumulator, model, modelName) => {
|
||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
||||
|
||||
return response;
|
||||
return modelsAccumulator;
|
||||
},
|
||||
{} as Record<string, Model>
|
||||
);
|
||||
|
||||
return deserializedModels;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -10,8 +10,6 @@ export const deserializeImageResponse = (
|
||||
const { image_name, image_type, image_url, metadata, thumbnail_url } =
|
||||
imageResponse;
|
||||
|
||||
const { width, height, timestamp, invokeai } = metadata;
|
||||
|
||||
// TODO: parse metadata - just leaving it as-is for now
|
||||
|
||||
return {
|
||||
@@ -19,7 +17,6 @@ export const deserializeImageResponse = (
|
||||
type: image_type,
|
||||
url: image_url,
|
||||
thumbnail: thumbnail_url,
|
||||
timestamp,
|
||||
metadata: invokeai,
|
||||
metadata,
|
||||
};
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user