feat(ui, nodes): models

This commit is contained in:
psychedelicious
2023-04-14 15:07:16 +10:00
parent 65f2a7ea31
commit 53a1a3eb61
11 changed files with 164 additions and 99 deletions

View File

@@ -76,7 +76,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model_name = model["model_name"]
self.model = model["model_name"]
outputs = Txt2Img(model).generate(
prompt=self.prompt,

View File

@@ -10,10 +10,11 @@ import galleryReducer, {
GalleryState,
} from 'features/gallery/store/gallerySlice';
import resultsReducer, {
resultsAdapter,
ResultsState,
} from 'features/gallery/store/resultsSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import uploadsReducer, {
UploadsState,
} from 'features/gallery/store/uploadsSlice';
import lightboxReducer, {
LightboxState,
} from 'features/lightbox/store/lightboxSlice';
@@ -25,12 +26,13 @@ import postprocessingReducer, {
} from 'features/parameters/store/postprocessingSlice';
import systemReducer, { SystemState } from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import modelsReducer from 'features/system/store/modelSlice';
import modelsReducer, { ModelsState } from 'features/system/store/modelSlice';
import nodesReducer, { NodesState } from 'features/nodes/store/nodesSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketMiddleware } from 'services/events/middleware';
import { CanvasState } from 'features/canvas/store/canvasTypes';
import { UIState } from 'features/ui/store/uiTypes';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@@ -138,21 +140,21 @@ resultsBlacklist.map((blacklistItem) => `results.${blacklistItem}`);
*
* Currently blacklisting uploads slice entirely, see persist config below
*/
const uploadsBlacklist: (keyof NodesState)[] = [];
const uploadsBlacklist: (keyof UploadsState)[] = [];
uploadsBlacklist.map((blacklistItem) => `uploads.${blacklistItem}`);
/**
* Models slice persist blacklist
*/
const modelsBlacklist: (keyof NodesState)[] = [];
const modelsBlacklist: (keyof ModelsState)[] = ['entities', 'ids'];
modelsBlacklist.map((blacklistItem) => `models.${blacklistItem}`);
/**
* UI slice persist blacklist
*/
const uiBlacklist: (keyof NodesState)[] = [];
const uiBlacklist: (keyof UIState)[] = [];
uiBlacklist.map((blacklistItem) => `ui.${blacklistItem}`);

View File

@@ -15,7 +15,7 @@ export const buildImg2ImgNode = (
const { generation, system, models } = state;
const { shouldDisplayInProgressType } = system;
const { currentModel: model } = models;
const { selectedModelName } = models;
const {
prompt,
@@ -38,27 +38,29 @@ export const buildImg2ImgNode = (
throw 'no initial image';
}
return {
[nodeId]: {
id: nodeId,
type: 'img2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler: sampler as ImageToImageInvocation['scheduler'],
seamless,
model,
progress_images: shouldDisplayInProgressType === 'full-res',
image: {
image_name: initialImage.name,
image_type: initialImage.type,
},
strength,
fit,
const imageToImageNode: ImageToImageInvocation = {
id: nodeId,
type: 'img2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler: sampler as ImageToImageInvocation['scheduler'],
seamless,
model: selectedModelName,
progress_images: shouldDisplayInProgressType === 'full-res',
image: {
image_name: initialImage.name,
image_type: initialImage.type,
},
strength,
fit,
};
return {
[nodeId]: imageToImageNode,
};
};

View File

@@ -9,7 +9,7 @@ export const buildTxt2ImgNode = (
const { generation, system, models } = state;
const { shouldDisplayInProgressType } = system;
const { currentModel: model } = models;
const { selectedModelName } = models;
const {
prompt,
@@ -24,20 +24,22 @@ export const buildTxt2ImgNode = (
} = generation;
// missing fields in TextToImageInvocation: strength, hires_fix
const textToImageNode: TextToImageInvocation = {
id: nodeId,
type: 'txt2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale,
scheduler: sampler as TextToImageInvocation['scheduler'],
seamless,
model: selectedModelName,
progress_images: shouldDisplayInProgressType === 'full-res',
};
return {
[nodeId]: {
id: nodeId,
type: 'txt2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale,
scheduler: sampler as TextToImageInvocation['scheduler'],
seamless,
model,
progress_images: shouldDisplayInProgressType === 'full-res',
},
[nodeId]: textToImageNode,
};
};

View File

@@ -22,18 +22,19 @@ type AdditionalUploadsState = {
nextPage: number;
};
export type UploadssState = ReturnType<
typeof uploadsAdapter.getInitialState<AdditionalUploadsState>
>;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: uploadsAdapter.getInitialState<AdditionalUploadsState>({
const initialUploadsState =
uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
nextPage: 0,
isLoading: false,
}),
});
export type UploadsState = typeof initialUploadsState;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: initialUploadsState,
reducers: {
uploadAdded: uploadsAdapter.addOne,
},

View File

@@ -4,14 +4,19 @@ import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { ModelInputField } from 'features/nodes/types';
import {
selectModelsById,
selectModelsIds,
} from 'features/system/store/modelSlice';
import { isEqual, map } from 'lodash';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector(
(state: RootState) => state.models.modelList,
(modelList) => {
return map(modelList, (_, name) => name);
[selectModelsIds],
(allModelNames) => {
return { allModelNames };
// return map(modelList, (_, name) => name);
},
{
memoizeOptions: {
@@ -27,7 +32,7 @@ export const ModelInputFieldComponent = (
const dispatch = useAppDispatch();
const availableModels = useAppSelector(availableModelsSelector);
const { allModelNames } = useAppSelector(availableModelsSelector);
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
@@ -41,7 +46,7 @@ export const ModelInputFieldComponent = (
return (
<Select onChange={handleValueChanged} value={field.value}>
{availableModels.map((option) => (
{allModelNames.map((option) => (
<option key={option}>{option}</option>
))}
</Select>

View File

@@ -1,20 +1,27 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { ChangeEvent } from 'react';
import { isEqual, map } from 'lodash';
import { isEqual } from 'lodash';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISelect from 'common/components/IAISelect';
import { modelSelector } from '../store/modelSelectors';
import { setCurrentModel } from '../store/modelSlice';
import {
modelSelected,
selectedModelSelector,
selectModelsIds,
} from '../store/modelSlice';
import { RootState } from 'app/store';
const selector = createSelector(
[modelSelector],
(model) => {
const { modelList, currentModel } = model;
const models = map(modelList, (model, key) => key);
return { models, currentModel, modelList };
[(state: RootState) => state],
(state) => {
const selectedModel = selectedModelSelector(state);
const allModelNames = selectModelsIds(state);
return {
allModelNames,
selectedModel,
};
},
{
memoizeOptions: {
@@ -26,12 +33,10 @@ const selector = createSelector(
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { models, currentModel, modelList } = useAppSelector(selector);
const { allModelNames, selectedModel } = useAppSelector(selector);
const handleChangeModel = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(setCurrentModel(e.target.value));
dispatch(modelSelected(e.target.value));
};
const currentModelDescription =
currentModel && modelList[currentModel].description;
return (
<Flex
@@ -42,9 +47,9 @@ const ModelSelect = () => {
<IAISelect
style={{ fontSize: 'sm' }}
aria-label={t('accessibility.modelSelect')}
tooltip={currentModelDescription}
value={currentModel}
validValues={models}
tooltip={selectedModel?.description || ''}
value={selectedModel?.name || undefined}
validValues={allModelNames}
onChange={handleChangeModel}
/>
</Flex>

View File

@@ -1,27 +1,37 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createEntityAdapter, PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { ModelsList } from 'services/api';
import { RootState } from 'app/store';
import { keys, sample } from 'lodash';
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
import { receivedModels } from 'services/thunks/model';
export interface ModelState {
modelList: ModelsList['models'];
currentModel?: string;
}
const initialModelState: ModelState = {
modelList: {},
currentModel: undefined,
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
name: string;
};
export const modelSlice = createSlice({
name: 'model',
initialState: initialModelState,
export const modelsAdapter = createEntityAdapter<Model>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
type AdditionalModelsState = {
selectedModelName: string;
};
export const initialModelsState =
modelsAdapter.getInitialState<AdditionalModelsState>({
selectedModelName: '',
});
export type ModelsState = typeof initialModelsState;
export const modelsSlice = createSlice({
name: 'models',
initialState: initialModelsState,
reducers: {
setModelList: (state, action: PayloadAction<ModelsList['models']>) => {
state.modelList = action.payload;
},
setCurrentModel: (state, action: PayloadAction<string>) => {
state.currentModel = action.payload;
modelAdded: modelsAdapter.upsertOne,
modelSelected: (state, action: PayloadAction<string>) => {
state.selectedModelName = action.payload;
},
},
extraReducers(builder) {
@@ -29,12 +39,42 @@ export const modelSlice = createSlice({
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
const models = action.payload.models;
state.modelList = models;
const models = action.payload;
modelsAdapter.setAll(state, models);
// If the current selected model is `''` or isn't actually in the list of models,
// choose a random model
if (
!state.selectedModelName ||
!keys(models).includes(state.selectedModelName)
) {
const randomModel = sample(models);
if (randomModel) {
state.selectedModelName = randomModel.name;
} else {
state.selectedModelName = '';
}
}
});
},
});
export const { setModelList, setCurrentModel } = modelSlice.actions;
export const selectedModelSelector = (state: RootState) => {
const { selectedModelName } = state.models;
const selectedModel = selectModelsById(state, selectedModelName);
export default modelSlice.reducer;
return selectedModel ?? null;
};
export const {
selectAll: selectModelsAll,
selectById: selectModelsById,
selectEntities: selectModelsEntities,
selectIds: selectModelsIds,
selectTotal: selectModelsTotal,
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
export const { modelAdded, modelSelected } = modelsSlice.actions;
export default modelsSlice.reducer;

View File

@@ -101,7 +101,7 @@ export const socketMiddleware = () => {
dispatch(receivedUploadImagesPage());
}
if (!models.modelList.length) {
if (!models.ids.length) {
dispatch(receivedModels());
}

View File

@@ -1,4 +1,6 @@
import { createAppAsyncThunk } from 'app/storeUtils';
import { Model } from 'features/system/store/modelSlice';
import { reduce } from 'lodash';
import { ModelsService } from 'services/api';
export const IMAGES_PER_PAGE = 20;
@@ -7,7 +9,16 @@ export const receivedModels = createAppAsyncThunk(
'models/receivedModels',
async (_arg) => {
const response = await ModelsService.listModels();
const deserializedModels = reduce(
response.models,
(modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName };
return response;
return modelsAccumulator;
},
{} as Record<string, Model>
);
return deserializedModels;
}
);

View File

@@ -10,8 +10,6 @@ export const deserializeImageResponse = (
const { image_name, image_type, image_url, metadata, thumbnail_url } =
imageResponse;
const { width, height, timestamp, invokeai } = metadata;
// TODO: parse metadata - just leaving it as-is for now
return {
@@ -19,7 +17,6 @@ export const deserializeImageResponse = (
type: image_type,
url: image_url,
thumbnail: thumbnail_url,
timestamp,
metadata: invokeai,
metadata,
};
};