Compare commits

...

13 Commits

Author SHA1 Message Date
Mary Hipp
f0bfa7f0e0 add button to set default settings in parameters 2024-02-29 16:24:10 -05:00
Mary Hipp
c46b2b6fa6 Merge branch 'maryhipp/default-model-settings' of https://github.com/invoke-ai/InvokeAI into maryhipp/default-model-settings 2024-02-29 14:50:37 -05:00
Mary Hipp
058cc715d4 hook MM default model settings up to API 2024-02-29 14:50:31 -05:00
maryhipp
f69e3ee01c add default_settings to metadata 2024-02-29 14:49:37 -05:00
Mary Hipp
6e0665e3d7 UI for configuring default settings for models' 2024-02-29 13:04:53 -05:00
Mary Hipp
5a35550144 add scheduler and vaePrecision to config 2024-02-29 13:02:46 -05:00
Mary Hipp
8926a1a424 lint fix 2024-02-28 12:04:51 -05:00
Mary Hipp
8566c1c7ff cleanup 2024-02-28 11:49:56 -05:00
Mary Hipp
6eb4c1ccb6 adapt embedding popover to work for trigger phrases also 2024-02-28 11:45:34 -05:00
maryhipp
ef474a3196 what have i done 2024-02-27 15:41:03 -05:00
maryhipp
16b3718d6a cleanup 2024-02-27 15:39:34 -05:00
maryhipp
30228ce2a4 fix merge 2024-02-27 15:35:27 -05:00
Mary Hipp
efb5f2d202 UI in MM to create trigger phrases 2024-02-27 15:33:11 -05:00
42 changed files with 1420 additions and 305 deletions

View File

@@ -5,6 +5,7 @@ import pathlib
import shutil
from hashlib import sha1
from random import randbytes
import traceback
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
@@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@@ -242,6 +245,47 @@ async def get_model_metadata(
return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes")
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
if original_metadata:
if changes.trigger_phrases:
original_metadata.trigger_phrases = changes.trigger_phrases
if changes.default_settings:
original_metadata.default_settings = changes.default_settings
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases, default_settings=changes.default_settings))
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",

View File

@@ -4,9 +4,27 @@ Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from typing import List, Optional, Set, Tuple
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `trigger_phrases`: the list of trigger phrases for this model
- `default_settings`: the user-configured default settings for this model
"""
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
"""The model's list of trigger phrases"""
default_settings: Optional[ModelDefaultSettings] = Field(default=None, description="The user-configured default settings for this model")
"""The user-configured default settings for this model"""
class ModelMetadataStoreBase(ABC):

View File

@@ -115,6 +115,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
except Exception as e:
raise e
return self.get_metadata(model_key)
@@ -179,44 +181,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
"""Update tags for the model referenced by model_key."""
if tags:
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@@ -164,6 +164,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
trigger_phrases=version_json["trainedWords"],
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:

View File

@@ -24,6 +24,7 @@ from pydantic import BaseModel, Field, TypeAdapter
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from typing_extensions import Annotated
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.backend.model_manager import ModelRepoVariant
@@ -68,12 +69,22 @@ class RemoteModelFile(BaseModel):
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(description="default settings for this model", default=None)
class BaseMetadata(ModelMetadataBase):

View File

@@ -78,6 +78,7 @@
"aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power",
"accept": "Accept",
"add": "Add",
"advanced": "Advanced",
"advancedOptions": "Advanced Options",
"ai": "ai",
@@ -303,6 +304,12 @@
"method": "High Resolution Fix Method"
}
},
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noPromptTriggers": "No triggers available",
"noMatchingTriggers": "No matching triggers"
},
"embedding": {
"addEmbedding": "Add Embedding",
"incompatibleModel": "Incompatible base model:",
@@ -734,6 +741,8 @@
"customConfig": "Custom Config",
"customConfigFileLocation": "Custom Config File Location",
"customSaveLocation": "Custom Save Location",
"defaultSettings": "Default Settings",
"defaultSettingsSaved": "Default Settings Saved",
"delete": "Delete",
"deleteConfig": "Delete Config",
"deleteModel": "Delete Model",
@@ -768,6 +777,7 @@
"mergedModelName": "Merged Model Name",
"mergedModelSaveLocation": "Save Location",
"mergeModels": "Merge Models",
"metadata": "Metadata",
"model": "Model",
"modelAdded": "Model Added",
"modelConversionFailed": "Model Conversion Failed",
@@ -839,9 +849,12 @@
"statusConverting": "Converting",
"syncModels": "Sync Models",
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
"triggerPhrases": "Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention",
"updateModel": "Update Model",
"useCustomConfig": "Use Custom Config",
"useDefaultSettings": "Use Default Settings",
"v1": "v1",
"v2_768": "v2 (768px)",
"v2_base": "v2 (512px)",

View File

@@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);
// Dynamic prompts
addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening)

View File

@@ -0,0 +1,88 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import { setCfgRescaleMultiplier, setCfgScale, setScheduler, setSteps, vaePrecisionChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { isParameterCFGRescaleMultiplier, isParameterCFGScale, isParameterPrecision, isParameterScheduler, isParameterSteps, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: setDefaultSettings,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const currentModel = state.generation.model;
if (!currentModel) {
return
}
const metadata = await dispatch(
modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)
).unwrap();
console.log({ metadata })
if (!metadata || !metadata.default_settings) {
return;
}
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings
if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === "default") {
dispatch(vaeSelected(null))
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state)
const vaeArray = map(data?.entities)
const validVae = vaeArray.find(model => model.key === vae)
const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
}
}
if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}
if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}
if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}
if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: "Default settings" }) })))
},
});
};

View File

@@ -1,4 +1,5 @@
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt';
@@ -82,6 +83,8 @@ export type AppConfig = {
guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler,
vaePrecision?: ParameterPrecision
// Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model

View File

@@ -1,21 +0,0 @@
import type { Meta, StoryObj } from '@storybook/react';
import { EmbeddingSelect } from './EmbeddingSelect';
import type { EmbeddingSelectProps } from './types';
const meta: Meta<typeof EmbeddingSelect> = {
title: 'Feature/Prompt/EmbeddingSelect',
tags: ['autodocs'],
component: EmbeddingSelect,
};
export default meta;
type Story = StoryObj<typeof EmbeddingSelect>;
const Component = (props: EmbeddingSelectProps) => {
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
};
export const Default: Story = {
render: Component,
};

View File

@@ -1,67 +0,0 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import type { EmbeddingSelectProps } from 'features/embedding/types';
import { t } from 'i18next';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = useCallback(
(embedding: TextualInversionModelConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { data, isLoading } = useGetTextualInversionModelsQuery();
const _onChange = useCallback(
(embedding: TextualInversionModelConfig | null) => {
if (!embedding) {
return;
}
onSelect(embedding.name);
},
[onSelect]
);
const { options, onChange } = useGroupedModelCombobox({
modelEntities: data,
getIsDisabled,
onChange: _onChange,
});
return (
<FormControl>
<Combobox
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
defaultMenuIsOpen
autoFocus
value={null}
options={options}
noOptionsMessage={noOptionsMessage}
onChange={onChange}
onMenuClose={onClose}
data-testid="add-embedding"
sx={selectStyles}
/>
</FormControl>
);
});
EmbeddingSelect.displayName = 'EmbeddingSelect';
const selectStyles: ChakraProps['sx'] = {
w: 'full',
};

View File

@@ -8,7 +8,7 @@ export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model /> : <ImportModels />}
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
</Box>
);
};

View File

@@ -0,0 +1,66 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import Loading from 'common/components/Loading/Loading';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
return {
initialSteps: steps.initial,
initialCfg: guidance.initial,
initialScheduler: scheduler,
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
initialVaePrecision: vaePrecision,
};
});
export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector);
const defaultSettingsDefaults = useMemo(() => {
return {
vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
vaePrecision: {
isEnabled: !isNil(data?.default_settings?.vae_precision),
value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
},
scheduler: {
isEnabled: !isNil(data?.default_settings?.scheduler),
value: data?.default_settings?.scheduler || initialScheduler || 'euler',
},
steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
cfgScale: {
isEnabled: !isNil(data?.default_settings?.cfg_scale),
value: data?.default_settings?.cfg_scale || initialCfg,
},
cfgRescaleMultiplier: {
isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
},
};
}, [
data?.default_settings,
initialSteps,
initialCfg,
initialScheduler,
initialCfgRescaleMultiplier,
initialVaePrecision,
]);
if (isLoading) {
return <Loading />;
}
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
};

View File

@@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback,useMemo } from 'react';
import type {UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultCfgRescaleMultiplierType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultCfgRescaleMultiplierType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramCFGRescaleMultiplier">
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback,useMemo } from 'react';
import type {UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultCfgType = DefaultSettingsFormData['cfgScale'];
export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultCfgType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultCfgType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultCfgType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramCFGScale">
<FormLabel>{t('parameters.cfgScale')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@@ -0,0 +1,50 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { useCallback, useMemo } from 'react';
import type {UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];
export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParameterScheduler(v?.value)) {
return;
}
const updatedValue = {
...(field.value as DefaultSchedulerType),
value: v.value,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(
() => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
[field]
);
const isDisabled = useMemo(() => {
return !(field.value as DefaultSchedulerType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramScheduler">
<FormLabel>{t('parameters.scheduler')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
</FormControl>
);
}

View File

@@ -0,0 +1,147 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
import { DefaultScheduler } from './DefaultScheduler';
import { DefaultSteps } from './DefaultSteps';
import { DefaultVae } from './DefaultVae';
import { DefaultVaePrecision } from './DefaultVaePrecision';
import { SettingToggle } from './SettingToggle';
export interface FormField<T> {
value: T;
isEnabled: boolean;
}
export type DefaultSettingsFormData = {
vae: FormField<string>;
vaePrecision: FormField<string>;
scheduler: FormField<ParameterScheduler>;
steps: FormField<number>;
cfgScale: FormField<number>;
cfgRescaleMultiplier: FormField<number>;
};
export const DefaultSettingsForm = ({
defaultSettingsDefaults,
}: {
defaultSettingsDefaults: DefaultSettingsFormData;
}) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
});
const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
(data) => {
if (!selectedModelKey) {
return;
}
const body = {
vae: data.vae.isEnabled ? data.vae.value : null,
vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
steps: data.steps.isEnabled ? data.steps.value : null,
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
};
editModelMetadata({
key: selectedModelKey,
body: { default_settings: body },
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.defaultSettingsSaved'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
},
[selectedModelKey, dispatch, editModelMetadata, t]
);
return (
<>
<Flex gap="2" justifyContent="space-between" w="full" mb={5}>
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
<Button
size="sm"
leftIcon={<IoPencil />}
colorScheme="invokeYellow"
isDisabled={!formState.isDirty}
onClick={handleSubmit(onSubmit)}
type="submit"
isLoading={isLoading}
>
{t('common.save')}
</Button>
</Flex>
<Flex flexDir="column" gap={8}>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="vae" />
<DefaultVae control={control} name="vae" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="vaePrecision" />
<DefaultVaePrecision control={control} name="vaePrecision" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="scheduler" />
<DefaultScheduler control={control} name="scheduler" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="steps" />
<DefaultSteps control={control} name="steps" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="cfgScale" />
<DefaultCfgScale control={control} name="cfgScale" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="cfgRescaleMultiplier" />
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
</Flex>
</Flex>
</Flex>
</>
);
};

View File

@@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback,useMemo } from 'react';
import type {UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultSteps = DefaultSettingsFormData['steps'];
export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultSteps),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultSteps).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultSteps).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramSteps">
<FormLabel>{t('parameters.steps')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@@ -0,0 +1,65 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import type {UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultVaeType = DefaultSettingsFormData['vae'];
export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
selectFromResult: ({ data }) => {
const modelArray = map(data?.entities);
const compatibleOptions = modelArray
.filter((vae) => vae.base === modelData?.base)
.map((vae) => ({ label: vae.name, value: vae.key }));
const defaultOption = { label: 'Default VAE', value: 'default' };
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
},
});
const onChange = useCallback<ComboboxOnChange>(
(v) => {
const newValue = !v?.value ? 'default' : v.value;
const updatedValue = {
...(field.value as DefaultVaeType),
value: newValue,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
}, [compatibleOptions, field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultVaeType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramVAE">
<FormLabel>{t('modelManager.vae')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
</FormControl>
);
}

View File

@@ -0,0 +1,51 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
import { useCallback, useMemo } from 'react';
import type {UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
const options = [
{ label: 'FP16', value: 'fp16' },
{ label: 'FP32', value: 'fp32' },
];
type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];
export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParameterPrecision(v?.value)) {
return;
}
const updatedValue = {
...(field.value as DefaultVaePrecisionType),
value: v.value,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultVaePrecisionType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramVAEPrecision">
<FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
</FormControl>
);
}

View File

@@ -0,0 +1,32 @@
import { Switch } from '@invoke-ai/ui-library';
import type { ChangeEvent } from 'react';
import { useCallback , useMemo } from 'react';
import type { UseControllerProps} from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';
interface Props<T> extends UseControllerProps<DefaultSettingsFormData> {
name: keyof DefaultSettingsFormData;
}
export function SettingToggle<T>(props: Props<T>) {
const { field } = useController(props);
const value = useMemo(() => {
return !!(field.value as FormField<T>).isEnabled;
}, [field.value]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
const updatedValue: FormField<T> = {
...(field.value as FormField<T>),
isEnabled: e.target.checked,
};
field.onChange(updatedValue);
},
[field]
);
return <Switch isChecked={value} onChange={onChange} />;
}

View File

@@ -0,0 +1,21 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { TriggerPhrases } from './TriggerPhrases';
export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
return (
<Flex flexDir="column" height="full" gap="3">
<Box layerStyle="second" borderRadius="base" p={3}>
<TriggerPhrases />
</Box>
<DataViewer label="metadata" data={metadata || {}} />
</Flex>
);
};

View File

@@ -0,0 +1,106 @@
import {
Button,
Flex,
FormControl,
FormErrorMessage,
Input,
Tag,
TagCloseButton,
TagLabel,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelMetadataQuery, useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
export const TriggerPhrases = () => {
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const [phrase, setPhrase] = useState('');
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setPhrase(e.target.value);
}, []);
const triggerPhrases = useMemo(() => {
return metadata?.trigger_phrases || [];
}, [metadata?.trigger_phrases]);
const errors = useMemo(() => {
const errors = [];
if (phrase.length && triggerPhrases.includes(phrase)) {
errors.push('Phrase is already in list');
}
return errors;
}, [phrase, triggerPhrases]);
const addTriggerPhrase = useCallback(async () => {
if (!selectedModelKey) {
return;
}
if (!phrase.length || triggerPhrases.includes(phrase)) {
return;
}
await editModelMetadata({
key: selectedModelKey,
body: { trigger_phrases: [...triggerPhrases, phrase] },
}).unwrap();
setPhrase('');
}, [editModelMetadata, selectedModelKey, phrase, triggerPhrases]);
const removeTriggerPhrase = useCallback(
async (phraseToRemove: string) => {
if (!selectedModelKey) {
return;
}
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
await editModelMetadata({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
},
[editModelMetadata, selectedModelKey, triggerPhrases]
);
return (
<Flex flexDir="column" w="full" gap="5">
<ModelListHeader title={t('modelManager.triggerPhrases')} />
<form>
<FormControl w="full" isInvalid={Boolean(errors.length)}>
<Flex flexDir="column" w="full">
<Flex gap="3" alignItems="center" w="full">
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
<Button
type="submit"
onClick={addTriggerPhrase}
isDisabled={Boolean(errors.length)}
isLoading={isLoading}
>
{t('common.add')}
</Button>
</Flex>
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
</Flex>
</FormControl>
</form>
<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index}>
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
</Flex>
);
};

View File

@@ -1,9 +1,58 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { ModelMetadata } from './Metadata/ModelMetadata';
import { ModelAttrView } from './ModelAttrView';
import { ModelEdit } from './ModelEdit';
import { ModelView } from './ModelView';
export const Model = () => {
const { t } = useTranslation();
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
if (isLoading) {
return <Text>{t('common.loading')}</Text>;
}
if (!data) {
return <Text>{t('common.somethingWentWrong')}</Text>;
}
return (
<>
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>
<Tabs mt="4" h="100%">
<TabList>
<Tab>{t('modelManager.settings')}</Tab>
<Tab>{t('modelManager.metadata')}</Tab>
</TabList>
<TabPanels h="100%">
<TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
<TabPanel h="full">
<ModelMetadata />
</TabPanel>
</TabPanels>
</Tabs>
</>
);
};

View File

@@ -1,12 +1,11 @@
import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type {
CheckpointModelConfig,
ControlNetModelConfig,
@@ -18,6 +17,7 @@ import type {
VAEModelConfig,
} from 'services/api/types';
import { DefaultSettings } from './DefaultSettings';
import { ModelAttrView } from './ModelAttrView';
import { ModelConvert } from './ModelConvert';
@@ -26,7 +26,6 @@ export const ModelView = () => {
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const modelData = useMemo(() => {
if (!data) {
@@ -73,85 +72,56 @@ export const ModelView = () => {
return <Text>{t('common.somethingWentWrong')}</Text>;
}
return (
<Flex flexDir="column" h="full">
<Flex w="full" justifyContent="space-between">
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{modelData.name}
</Heading>
{modelData.source && (
<Text variant="subtext">
{t('modelManager.source')}: {modelData.source}
</Text>
)}
</Flex>
<Flex gap={2}>
<Flex flexDir="column" h="full" gap="2">
<Box layerStyle="second" borderRadius="base" p={3}>
<Flex gap="2" justifyContent="flex-end" w="full">
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
{t('modelManager.edit')}
</Button>
{modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
</Flex>
</Flex>
<Flex flexDir="column" p={2} gap={3}>
<Flex>
<ModelAttrView label="Description" value={modelData.description} />
</Flex>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.modelSettings')}
</Heading>
<Box layerStyle="second" borderRadius="base" p={3}>
<Flex flexDir="column" gap={3}>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex flexDir="column" gap={3}>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
</Flex>
)}
</Flex>
</Box>
</Flex>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}
{metadata && (
<>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.modelMetadata')}
</Heading>
<Flex h="full" w="full" p={2}>
<DataViewer label="metadata" data={metadata} />
</Flex>
</>
)}
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
</Flex>
)}
</Flex>
</Box>
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
</Flex>
);
};

View File

@@ -1,10 +1,10 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
import { usePrompt } from 'features/embedding/usePrompt';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
@@ -19,19 +19,14 @@ export const ParamNegativePrompt = memo(() => {
},
[dispatch]
);
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
prompt,
textareaRef,
onChange: _onChange,
});
return (
<EmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={onSelectEmbedding}
width={textareaRef.current?.clientWidth}
>
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="negativePrompt"
@@ -45,10 +40,10 @@ export const ParamNegativePrompt = memo(() => {
variant="darkFilled"
/>
<PromptOverlayButtonWrapper>
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
</Box>
</EmbeddingPopover>
</PromptPopover>
);
});

View File

@@ -1,11 +1,11 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
import { usePrompt } from 'features/embedding/usePrompt';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { setPositivePrompt } from 'features/parameters/store/generationSlice';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
@@ -25,7 +25,7 @@ export const ParamPositivePrompt = memo(() => {
},
[dispatch]
);
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
prompt,
textareaRef: textareaRef,
onChange: handleChange,
@@ -42,12 +42,7 @@ export const ParamPositivePrompt = memo(() => {
useHotkeys('alt+a', focus, []);
return (
<EmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={onSelectEmbedding}
width={textareaRef.current?.clientWidth}
>
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="prompt"
@@ -61,12 +56,12 @@ export const ParamPositivePrompt = memo(() => {
variant="darkFilled"
/>
<PromptOverlayButtonWrapper>
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
{baseModel === 'sdxl' && <SDXLConcatButton />}
<ShowDynamicPromptsPreviewButton />
</PromptOverlayButtonWrapper>
</Box>
</EmbeddingPopover>
</PromptPopover>
);
});

View File

@@ -0,0 +1,28 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setDefaultSettings } from 'features/parameters/store/actions';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { RiSparklingFill } from 'react-icons/ri';
export const UseDefaultSettingsButton = () => {
const model = useAppSelector((s) => s.generation.model);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleClickDefaultSettings = useCallback(() => {
dispatch(setDefaultSettings());
}, [dispatch]);
return (
<IconButton
icon={<RiSparklingFill />}
tooltip={t('modelManager.useDefaultSettings')}
aria-label={t('modelManager.useDefaultSettings')}
isDisabled={!model}
onClick={handleClickDefaultSettings}
size="sm"
variant="ghost"
/>
);
};

View File

@@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
export const setDefaultSettings = createAction('generation/setDefaultSettings');

View File

@@ -230,6 +230,12 @@ export const generationSlice = createSlice({
state.height = optimalDimension;
}
}
if (action.payload.sd?.scheduler) {
state.scheduler = action.payload.sd.scheduler;
}
if (action.payload.sd?.vaePrecision) {
state.vaePrecision = action.payload.sd.vaePrecision;
}
});
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling

View File

@@ -8,15 +8,15 @@ type Props = {
onOpen: () => void;
};
export const AddEmbeddingButton = memo((props: Props) => {
export const AddPromptTriggerButton = memo((props: Props) => {
const { onOpen, isOpen } = props;
const { t } = useTranslation();
return (
<Tooltip label={t('embedding.addEmbedding')}>
<Tooltip label={t('prompt.addPromptTrigger')}>
<IconButton
variant="promptOverlay"
isDisabled={isOpen}
aria-label={t('embedding.addEmbedding')}
aria-label={t('prompt.addPromptTrigger')}
icon={<PiCodeBold />}
onClick={onOpen}
/>
@@ -24,4 +24,4 @@ export const AddEmbeddingButton = memo((props: Props) => {
);
});
AddEmbeddingButton.displayName = 'AddEmbeddingButton';
AddPromptTriggerButton.displayName = 'AddPromptTriggerButton';

View File

@@ -1,9 +1,9 @@
import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
import type { EmbeddingPopoverProps } from 'features/embedding/types';
import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect';
import type { PromptPopoverProps } from 'features/prompt/types';
import { memo } from 'react';
export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
export const PromptPopover = memo((props: PromptPopoverProps) => {
const { onSelect, isOpen, onClose, width, children } = props;
return (
@@ -14,7 +14,7 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
openDelay={0}
closeDelay={0}
closeOnBlur={true}
returnFocusOnClose={true}
returnFocusOnClose={false}
isLazy
>
<PopoverAnchor>{children}</PopoverAnchor>
@@ -27,11 +27,11 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
borderStyle="solid"
>
<PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
<EmbeddingSelect onClose={onClose} onSelect={onSelect} />
<PromptTriggerSelect onClose={onClose} onSelect={onSelect} />
</PopoverBody>
</PopoverContent>
</Popover>
);
});
EmbeddingPopover.displayName = 'EmbeddingPopover';
PromptPopover.displayName = 'PromptPopover';

View File

@@ -0,0 +1,21 @@
import type { Meta, StoryObj } from '@storybook/react';
import { PromptTriggerSelect } from './PromptTriggerSelect';
import type { PromptTriggerSelectProps } from './types';
const meta: Meta<typeof PromptTriggerSelect> = {
title: 'Feature/Prompt/PromptTriggerSelect',
tags: ['autodocs'],
component: PromptTriggerSelect,
};
export default meta;
type Story = StoryObj<typeof PromptTriggerSelect>;
const Component = (props: PromptTriggerSelectProps) => {
return <PromptTriggerSelect {...props}>Invoke</PromptTriggerSelect>;
};
export const Default: Story = {
render: Component,
};

View File

@@ -0,0 +1,86 @@
import type { ChakraProps, ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import type { PromptTriggerSelectProps } from 'features/prompt/types';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelMetadataQuery, useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentModelKey = useAppSelector((s) => s.generation.model?.key);
const { data, isLoading } = useGetTextualInversionModelsQuery();
const { data: metadata } = useGetModelMetadataQuery(currentModelKey ?? skipToken);
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v) {
onSelect('');
return;
}
onSelect(v.value);
},
[onSelect]
);
const embeddingOptions = useMemo(() => {
if (!data) {
return [];
}
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
return [
{
label: t('prompt.compatibleEmbeddings'),
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
},
];
}, [data, currentBaseModel, t]);
const options = useMemo(() => {
if (!metadata || !metadata.trigger_phrases) {
return [...embeddingOptions];
}
const metadataOptions = [
{
label: t('modelManager.triggerPhrases'),
options: metadata.trigger_phrases.map((phrase) => ({ label: phrase, value: phrase })),
},
];
return [...metadataOptions, ...embeddingOptions];
}, [embeddingOptions, metadata, t]);
return (
<FormControl>
<Combobox
placeholder={isLoading ? t('common.loading') : t('prompt.addPromptTrigger')}
defaultMenuIsOpen
autoFocus
value={null}
options={options}
noOptionsMessage={noOptionsMessage}
onChange={_onChange}
onMenuClose={onClose}
data-testid="add-prompt-trigger"
sx={selectStyles}
/>
</FormControl>
);
});
PromptTriggerSelect.displayName = 'PromptTriggerSelect';
const selectStyles: ChakraProps['sx'] = {
w: 'full',
};

View File

@@ -1,12 +1,12 @@
import type { PropsWithChildren } from 'react';
export type EmbeddingSelectProps = {
export type PromptTriggerSelectProps = {
onSelect: (v: string) => void;
onClose: () => void;
};
export type EmbeddingPopoverProps = PropsWithChildren &
EmbeddingSelectProps & {
export type PromptPopoverProps = PropsWithChildren &
PromptTriggerSelectProps & {
isOpen: boolean;
width?: number | string;
};

View File

@@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
import { useCallback } from 'react';
import { flushSync } from 'react-dom';
type UseInsertEmbeddingArg = {
type UseInsertTriggerArg = {
prompt: string;
textareaRef: RefObject<HTMLTextAreaElement>;
onChange: (v: string) => void;
};
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
const { isOpen, onClose, onOpen } = useDisclosure();
const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
@@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
[_onChange]
);
const insertEmbedding = useCallback(
const insertTrigger = useCallback(
(v: string) => {
if (!textareaRef.current) {
return;
}
// this is where we insert the TI trigger
// this is where we insert the trigger
const caret = textareaRef.current.selectionStart;
if (isNil(caret)) {
@@ -35,13 +35,9 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}`;
newPrompt += `${v}>`;
// we insert the cursor after the `>`
// we insert the cursor after the end of trigger
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
@@ -51,7 +47,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
_onChange(newPrompt);
});
// set the caret position to just after the TI trigger
// set the cursor position to just after the trigger
textareaRef.current.selectionStart = finalCaretPos;
textareaRef.current.selectionEnd = finalCaretPos;
},
@@ -62,17 +58,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
textareaRef.current?.focus();
}, [textareaRef]);
const handleClose = useCallback(() => {
const handleClosePopover = useCallback(() => {
onClose();
onFocus();
}, [onFocus, onClose]);
const onSelectEmbedding = useCallback(
const onSelect = useCallback(
(v: string) => {
insertEmbedding(v);
handleClose();
insertTrigger(v);
handleClosePopover();
},
[handleClose, insertEmbedding]
[handleClosePopover, insertTrigger]
);
const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
@@ -90,7 +86,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
isOpen,
onClose,
onOpen,
onSelectEmbedding,
onSelect,
onKeyDown,
onFocus,
};

View File

@@ -1,9 +1,9 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
import { usePrompt } from 'features/embedding/usePrompt';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { setNegativeStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
@@ -20,7 +20,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
},
[dispatch]
);
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
prompt,
textareaRef: textareaRef,
onChange: handleChange,
@@ -29,12 +29,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
useHotkeys('alt+a', onFocus, []);
return (
<EmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={onSelectEmbedding}
width={textareaRef.current?.clientWidth}
>
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="prompt"
@@ -48,10 +43,10 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
variant="darkFilled"
/>
<PromptOverlayButtonWrapper>
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
</Box>
</EmbeddingPopover>
</PromptPopover>
);
});

View File

@@ -1,9 +1,9 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
import { usePrompt } from 'features/embedding/usePrompt';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { setPositiveStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
@@ -19,19 +19,14 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
},
[dispatch]
);
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
prompt,
textareaRef: textareaRef,
onChange: handleChange,
});
return (
<EmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={onSelectEmbedding}
width={textareaRef.current?.clientWidth}
>
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="prompt"
@@ -45,10 +40,10 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
variant="darkFilled"
/>
<PromptOverlayButtonWrapper>
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
</Box>
</EmbeddingPopover>
</PromptPopover>
);
});

View File

@@ -21,6 +21,7 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { filter } from 'lodash-es';
@@ -71,7 +72,10 @@ export const GenerationSettingsAccordion = memo(() => {
<TabPanel overflow="visible" px={4} pt={4}>
<Flex gap={4} alignItems="center">
<ParamMainModelSelect />
<SyncModelsIconButton />
<Flex>
<UseDefaultSettingsButton />
<SyncModelsIconButton />
</Flex>
</Flex>
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} flexDir="column" pb={4}>

View File

@@ -41,6 +41,8 @@ const initialConfigState: AppConfig = {
boundingBoxHeight: { ...baseDimensionConfig },
scaledBoundingBoxWidth: { ...baseDimensionConfig },
scaledBoundingBoxHeight: { ...baseDimensionConfig },
scheduler: "euler",
vaePrecision: "fp32",
steps: {
initial: 30,
sliderMin: 1,

View File

@@ -24,12 +24,21 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelMetadataArg = {
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type DeleteMainModelArg = {
@@ -108,25 +117,25 @@ const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefi
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
/**
* Builds an endpoint URL for the models router
@@ -172,6 +181,16 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`i/${key}/metadata`),
method: 'PATCH',
body: body,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, config, access_token }) => {
return {
@@ -351,6 +370,7 @@ export const {
useGetModelMetadataQuery,
useDeleteModelImportMutation,
usePruneModelImportsMutation,
useUpdateModelMetadataMutation,
} = modelsApi;
const upsertModelConfigs = (

File diff suppressed because one or more lines are too long