mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 23:18:08 -05:00
Compare commits
13 Commits
next-test-
...
maryhipp/d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0bfa7f0e0 | ||
|
|
c46b2b6fa6 | ||
|
|
058cc715d4 | ||
|
|
f69e3ee01c | ||
|
|
6e0665e3d7 | ||
|
|
5a35550144 | ||
|
|
8926a1a424 | ||
|
|
8566c1c7ff | ||
|
|
6eb4c1ccb6 | ||
|
|
ef474a3196 | ||
|
|
16b3718d6a | ||
|
|
30228ce2a4 | ||
|
|
efb5f2d202 |
@@ -5,6 +5,7 @@ import pathlib
|
||||
import shutil
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
@@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
@@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@@ -242,6 +245,47 @@ async def get_model_metadata(
|
||||
|
||||
return result
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="update_model_metadata",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model metadata was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def update_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
changes: ModelMetadataChanges = Body(description="The changes")
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Updates or creates a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
|
||||
|
||||
try:
|
||||
original_metadata = record_store.get_metadata(key)
|
||||
if original_metadata:
|
||||
if changes.trigger_phrases:
|
||||
original_metadata.trigger_phrases = changes.trigger_phrases
|
||||
|
||||
if changes.default_settings:
|
||||
original_metadata.default_settings = changes.default_settings
|
||||
|
||||
metadata_store.update_metadata(key, original_metadata)
|
||||
else:
|
||||
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases, default_settings=changes.default_settings))
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An error occurred while updating the model metadata: {e}",
|
||||
)
|
||||
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
|
||||
@@ -4,9 +4,27 @@ Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
|
||||
|
||||
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
||||
"""A set of changes to apply to model metadata.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `trigger_phrases`: the list of trigger phrases for this model
|
||||
- `default_settings`: the user-configured default settings for this model
|
||||
"""
|
||||
|
||||
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
|
||||
"""The model's list of trigger phrases"""
|
||||
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(default=None, description="The user-configured default settings for this model")
|
||||
"""The user-configured default settings for this model"""
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
|
||||
@@ -115,6 +115,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
@@ -179,44 +181,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
if tags:
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
|
||||
@@ -164,6 +164,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
trigger_phrases=version_json["trainedWords"],
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
|
||||
@@ -24,6 +24,7 @@ from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
@@ -68,12 +69,22 @@ class RemoteModelFile(BaseModel):
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
|
||||
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(description="default settings for this model", default=None)
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
|
||||
@@ -78,6 +78,7 @@
|
||||
"aboutDesc": "Using Invoke for work? Check out:",
|
||||
"aboutHeading": "Own Your Creative Power",
|
||||
"accept": "Accept",
|
||||
"add": "Add",
|
||||
"advanced": "Advanced",
|
||||
"advancedOptions": "Advanced Options",
|
||||
"ai": "ai",
|
||||
@@ -303,6 +304,12 @@
|
||||
"method": "High Resolution Fix Method"
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Add Prompt Trigger",
|
||||
"compatibleEmbeddings": "Compatible Embeddings",
|
||||
"noPromptTriggers": "No triggers available",
|
||||
"noMatchingTriggers": "No matching triggers"
|
||||
},
|
||||
"embedding": {
|
||||
"addEmbedding": "Add Embedding",
|
||||
"incompatibleModel": "Incompatible base model:",
|
||||
@@ -734,6 +741,8 @@
|
||||
"customConfig": "Custom Config",
|
||||
"customConfigFileLocation": "Custom Config File Location",
|
||||
"customSaveLocation": "Custom Save Location",
|
||||
"defaultSettings": "Default Settings",
|
||||
"defaultSettingsSaved": "Default Settings Saved",
|
||||
"delete": "Delete",
|
||||
"deleteConfig": "Delete Config",
|
||||
"deleteModel": "Delete Model",
|
||||
@@ -768,6 +777,7 @@
|
||||
"mergedModelName": "Merged Model Name",
|
||||
"mergedModelSaveLocation": "Save Location",
|
||||
"mergeModels": "Merge Models",
|
||||
"metadata": "Metadata",
|
||||
"model": "Model",
|
||||
"modelAdded": "Model Added",
|
||||
"modelConversionFailed": "Model Conversion Failed",
|
||||
@@ -839,9 +849,12 @@
|
||||
"statusConverting": "Converting",
|
||||
"syncModels": "Sync Models",
|
||||
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"updateModel": "Update Model",
|
||||
"useCustomConfig": "Use Custom Config",
|
||||
"useDefaultSettings": "Use Default Settings",
|
||||
"v1": "v1",
|
||||
"v2_768": "v2 (768px)",
|
||||
"v2_base": "v2 (512px)",
|
||||
|
||||
@@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening)
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import { setCfgRescaleMultiplier, setCfgScale, setScheduler, setSteps, vaePrecisionChanged, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { isParameterCFGRescaleMultiplier, isParameterCFGScale, isParameterPrecision, isParameterScheduler, isParameterSteps, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
|
||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: setDefaultSettings,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
|
||||
if (!currentModel) {
|
||||
return
|
||||
}
|
||||
|
||||
const metadata = await dispatch(
|
||||
modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)
|
||||
).unwrap();
|
||||
|
||||
console.log({ metadata })
|
||||
|
||||
|
||||
if (!metadata || !metadata.default_settings) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings
|
||||
|
||||
if (vae) {
|
||||
// we store this as "default" within default settings
|
||||
// to distinguish it from no default set
|
||||
if (vae === "default") {
|
||||
dispatch(vaeSelected(null))
|
||||
} else {
|
||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state)
|
||||
const vaeArray = map(data?.entities)
|
||||
const validVae = vaeArray.find(model => model.key === vae)
|
||||
|
||||
const result = zParameterVAEModel.safeParse(validVae);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
dispatch(vaeSelected(result.data));
|
||||
}
|
||||
}
|
||||
|
||||
if (vae_precision) {
|
||||
if (isParameterPrecision(vae_precision)) {
|
||||
dispatch(vaePrecisionChanged(vae_precision));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_scale) {
|
||||
if (isParameterCFGScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_rescale_multiplier) {
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
}
|
||||
|
||||
if (steps) {
|
||||
if (isParameterSteps(steps)) {
|
||||
dispatch(setSteps(steps));
|
||||
}
|
||||
}
|
||||
|
||||
if (scheduler) {
|
||||
if (isParameterScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: "Default settings" }) })))
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
@@ -82,6 +83,8 @@ export type AppConfig = {
|
||||
guidance: NumericalParameterConfig;
|
||||
cfgRescaleMultiplier: NumericalParameterConfig;
|
||||
img2imgStrength: NumericalParameterConfig;
|
||||
scheduler?: ParameterScheduler,
|
||||
vaePrecision?: ParameterPrecision
|
||||
// Canvas
|
||||
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
|
||||
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
|
||||
import { EmbeddingSelect } from './EmbeddingSelect';
|
||||
import type { EmbeddingSelectProps } from './types';
|
||||
|
||||
const meta: Meta<typeof EmbeddingSelect> = {
|
||||
title: 'Feature/Prompt/EmbeddingSelect',
|
||||
tags: ['autodocs'],
|
||||
component: EmbeddingSelect,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof EmbeddingSelect>;
|
||||
|
||||
const Component = (props: EmbeddingSelectProps) => {
|
||||
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
||||
@@ -1,67 +0,0 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import type { EmbeddingSelectProps } from 'features/embedding/types';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { TextualInversionModelConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
||||
|
||||
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(embedding: TextualInversionModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === embedding.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(embedding: TextualInversionModelConfig | null) => {
|
||||
if (!embedding) {
|
||||
return;
|
||||
}
|
||||
onSelect(embedding.name);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const { options, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
getIsDisabled,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-embedding"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingSelect.displayName = 'EmbeddingSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
||||
@@ -8,7 +8,7 @@ export const ModelPane = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
|
||||
{selectedModelKey ? <Model /> : <ImportModels />}
|
||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
|
||||
|
||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
|
||||
|
||||
return {
|
||||
initialSteps: steps.initial,
|
||||
initialCfg: guidance.initial,
|
||||
initialScheduler: scheduler,
|
||||
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
|
||||
initialVaePrecision: vaePrecision,
|
||||
};
|
||||
});
|
||||
|
||||
export const DefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
||||
useAppSelector(initialStatesSelector);
|
||||
|
||||
const defaultSettingsDefaults = useMemo(() => {
|
||||
return {
|
||||
vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
|
||||
vaePrecision: {
|
||||
isEnabled: !isNil(data?.default_settings?.vae_precision),
|
||||
value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
|
||||
},
|
||||
scheduler: {
|
||||
isEnabled: !isNil(data?.default_settings?.scheduler),
|
||||
value: data?.default_settings?.scheduler || initialScheduler || 'euler',
|
||||
},
|
||||
steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
|
||||
cfgScale: {
|
||||
isEnabled: !isNil(data?.default_settings?.cfg_scale),
|
||||
value: data?.default_settings?.cfg_scale || initialCfg,
|
||||
},
|
||||
cfgRescaleMultiplier: {
|
||||
isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
|
||||
value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
|
||||
},
|
||||
};
|
||||
}, [
|
||||
data?.default_settings,
|
||||
initialSteps,
|
||||
initialCfg,
|
||||
initialScheduler,
|
||||
initialCfgRescaleMultiplier,
|
||||
initialVaePrecision,
|
||||
]);
|
||||
|
||||
if (isLoading) {
|
||||
return <Loading />;
|
||||
}
|
||||
|
||||
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
|
||||
};
|
||||
@@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback,useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||
|
||||
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgRescaleMultiplierType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgRescaleMultiplierType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGRescaleMultiplier">
|
||||
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback,useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgType = DefaultSettingsFormData['cfgScale'];
|
||||
|
||||
export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGScale">
|
||||
<FormLabel>{t('parameters.cfgScale')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];
|
||||
|
||||
export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterScheduler(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSchedulerType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(
|
||||
() => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
|
||||
[field]
|
||||
);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSchedulerType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramScheduler">
|
||||
<FormLabel>{t('parameters.scheduler')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
import { DefaultScheduler } from './DefaultScheduler';
|
||||
import { DefaultSteps } from './DefaultSteps';
|
||||
import { DefaultVae } from './DefaultVae';
|
||||
import { DefaultVaePrecision } from './DefaultVaePrecision';
|
||||
import { SettingToggle } from './SettingToggle';
|
||||
|
||||
export interface FormField<T> {
|
||||
value: T;
|
||||
isEnabled: boolean;
|
||||
}
|
||||
|
||||
export type DefaultSettingsFormData = {
|
||||
vae: FormField<string>;
|
||||
vaePrecision: FormField<string>;
|
||||
scheduler: FormField<ParameterScheduler>;
|
||||
steps: FormField<number>;
|
||||
cfgScale: FormField<number>;
|
||||
cfgRescaleMultiplier: FormField<number>;
|
||||
};
|
||||
|
||||
export const DefaultSettingsForm = ({
|
||||
defaultSettingsDefaults,
|
||||
}: {
|
||||
defaultSettingsDefaults: DefaultSettingsFormData;
|
||||
}) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
||||
|
||||
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
||||
defaultValues: defaultSettingsDefaults,
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
|
||||
(data) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const body = {
|
||||
vae: data.vae.isEnabled ? data.vae.value : null,
|
||||
vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
|
||||
cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
|
||||
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
|
||||
steps: data.steps.isEnabled ? data.steps.value : null,
|
||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||
};
|
||||
|
||||
editModelMetadata({
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.defaultSettingsSaved'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[selectedModelKey, dispatch, editModelMetadata, t]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="2" justifyContent="space-between" w="full" mb={5}>
|
||||
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
|
||||
<Button
|
||||
size="sm"
|
||||
leftIcon={<IoPencil />}
|
||||
colorScheme="invokeYellow"
|
||||
isDisabled={!formState.isDirty}
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
type="submit"
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vae" />
|
||||
<DefaultVae control={control} name="vae" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vaePrecision" />
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="scheduler" />
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="steps" />
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgScale" />
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgRescaleMultiplier" />
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex,FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback,useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSteps = DefaultSettingsFormData['steps'];
|
||||
|
||||
export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSteps),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultSteps).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSteps).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramSteps">
|
||||
<FormLabel>{t('parameters.steps')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultVaeType = DefaultSettingsFormData['vae'];
|
||||
|
||||
export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
const modelArray = map(data?.entities);
|
||||
const compatibleOptions = modelArray
|
||||
.filter((vae) => vae.base === modelData?.base)
|
||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||
|
||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||
|
||||
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
|
||||
},
|
||||
});
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
const newValue = !v?.value ? 'default' : v.value;
|
||||
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaeType),
|
||||
value: newValue,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
|
||||
}, [compatibleOptions, field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaeType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAE">
|
||||
<FormLabel>{t('modelManager.vae')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type {UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
const options = [
|
||||
{ label: 'FP16', value: 'fp16' },
|
||||
{ label: 'FP32', value: 'fp32' },
|
||||
];
|
||||
|
||||
type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];
|
||||
|
||||
export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterPrecision(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaePrecisionType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaePrecisionType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAEPrecision">
|
||||
<FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
import { Switch } from '@invoke-ai/ui-library';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback , useMemo } from 'react';
|
||||
import type { UseControllerProps} from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
|
||||
import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';
|
||||
|
||||
interface Props<T> extends UseControllerProps<DefaultSettingsFormData> {
|
||||
name: keyof DefaultSettingsFormData;
|
||||
}
|
||||
|
||||
export function SettingToggle<T>(props: Props<T>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return !!(field.value as FormField<T>).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
const updatedValue: FormField<T> = {
|
||||
...(field.value as FormField<T>),
|
||||
isEnabled: e.target.checked,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
return <Switch isChecked={value} onChange={onChange} />;
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { TriggerPhrases } from './TriggerPhrases';
|
||||
|
||||
export const ModelMetadata = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" height="full" gap="3">
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<TriggerPhrases />
|
||||
</Box>
|
||||
<DataViewer label="metadata" data={metadata || {}} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,106 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
Input,
|
||||
Tag,
|
||||
TagCloseButton,
|
||||
TagLabel,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelMetadataQuery, useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
||||
|
||||
export const TriggerPhrases = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const [phrase, setPhrase] = useState('');
|
||||
|
||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
||||
|
||||
const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setPhrase(e.target.value);
|
||||
}, []);
|
||||
|
||||
const triggerPhrases = useMemo(() => {
|
||||
return metadata?.trigger_phrases || [];
|
||||
}, [metadata?.trigger_phrases]);
|
||||
|
||||
const errors = useMemo(() => {
|
||||
const errors = [];
|
||||
|
||||
if (phrase.length && triggerPhrases.includes(phrase)) {
|
||||
errors.push('Phrase is already in list');
|
||||
}
|
||||
|
||||
return errors;
|
||||
}, [phrase, triggerPhrases]);
|
||||
|
||||
const addTriggerPhrase = useCallback(async () => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
||||
return;
|
||||
}
|
||||
|
||||
await editModelMetadata({
|
||||
key: selectedModelKey,
|
||||
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
||||
}).unwrap();
|
||||
setPhrase('');
|
||||
}, [editModelMetadata, selectedModelKey, phrase, triggerPhrases]);
|
||||
|
||||
const removeTriggerPhrase = useCallback(
|
||||
async (phraseToRemove: string) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
||||
|
||||
await editModelMetadata({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||
},
|
||||
[editModelMetadata, selectedModelKey, triggerPhrases]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" gap="5">
|
||||
<ModelListHeader title={t('modelManager.triggerPhrases')} />
|
||||
<form>
|
||||
<FormControl w="full" isInvalid={Boolean(errors.length)}>
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex gap="3" alignItems="center" w="full">
|
||||
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
|
||||
<Button
|
||||
type="submit"
|
||||
onClick={addTriggerPhrase}
|
||||
isDisabled={Boolean(errors.length)}
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.add')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</form>
|
||||
|
||||
<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
|
||||
{triggerPhrases.map((phrase, index) => (
|
||||
<Tag size="md" key={index}>
|
||||
<TagLabel>{phrase}</TagLabel>
|
||||
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
|
||||
</Tag>
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -1,9 +1,58 @@
|
||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { ModelMetadata } from './Metadata/ModelMetadata';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelEdit } from './ModelEdit';
|
||||
import { ModelView } from './ModelView';
|
||||
|
||||
export const Model = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
if (isLoading) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{data.name}
|
||||
</Heading>
|
||||
|
||||
{data.source && (
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.source')}: {data?.source}
|
||||
</Text>
|
||||
)}
|
||||
<Box mt="4">
|
||||
<ModelAttrView label="Description" value={data.description} />
|
||||
</Box>
|
||||
</Flex>
|
||||
|
||||
<Tabs mt="4" h="100%">
|
||||
<TabList>
|
||||
<Tab>{t('modelManager.settings')}</Tab>
|
||||
<Tab>{t('modelManager.metadata')}</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels h="100%">
|
||||
<TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
|
||||
<TabPanel h="full">
|
||||
<ModelMetadata />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import type {
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
@@ -18,6 +17,7 @@ import type {
|
||||
VAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { DefaultSettings } from './DefaultSettings';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelConvert } from './ModelConvert';
|
||||
|
||||
@@ -26,7 +26,6 @@ export const ModelView = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const modelData = useMemo(() => {
|
||||
if (!data) {
|
||||
@@ -73,85 +72,56 @@ export const ModelView = () => {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
return (
|
||||
<Flex flexDir="column" h="full">
|
||||
<Flex w="full" justifyContent="space-between">
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{modelData.name}
|
||||
</Heading>
|
||||
|
||||
{modelData.source && (
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.source')}: {modelData.source}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<Flex flexDir="column" h="full" gap="2">
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<Flex gap="2" justifyContent="flex-end" w="full">
|
||||
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
|
||||
{t('modelManager.edit')}
|
||||
</Button>
|
||||
|
||||
{modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" p={2} gap={3}>
|
||||
<Flex>
|
||||
<ModelAttrView label="Description" value={modelData.description} />
|
||||
</Flex>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<Flex flexDir="column" gap={3}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('common.format')} value={modelData.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
||||
)}
|
||||
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex flexDir="column" gap={3}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('common.format')} value={modelData.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
</Flex>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
||||
)}
|
||||
|
||||
{metadata && (
|
||||
<>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelMetadata')}
|
||||
</Heading>
|
||||
<Flex h="full" w="full" p={2}>
|
||||
<DataViewer label="metadata" data={metadata} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<DefaultSettings />
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -19,19 +19,14 @@ export const ParamNegativePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="negativePrompt"
|
||||
@@ -45,10 +40,10 @@ export const ParamNegativePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { setPositivePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
||||
@@ -25,7 +25,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
@@ -42,12 +42,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
useHotkeys('alt+a', focus, []);
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@@ -61,12 +56,12 @@ export const ParamPositivePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
{baseModel === 'sdxl' && <SDXLConcatButton />}
|
||||
<ShowDynamicPromptsPreviewButton />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { RiSparklingFill } from 'react-icons/ri';
|
||||
|
||||
export const UseDefaultSettingsButton = () => {
|
||||
const model = useAppSelector((s) => s.generation.model);
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickDefaultSettings = useCallback(() => {
|
||||
dispatch(setDefaultSettings());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<RiSparklingFill />}
|
||||
tooltip={t('modelManager.useDefaultSettings')}
|
||||
aria-label={t('modelManager.useDefaultSettings')}
|
||||
isDisabled={!model}
|
||||
onClick={handleClickDefaultSettings}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
/>
|
||||
);
|
||||
};
|
||||
@@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types';
|
||||
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
||||
|
||||
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
|
||||
|
||||
export const setDefaultSettings = createAction('generation/setDefaultSettings');
|
||||
@@ -230,6 +230,12 @@ export const generationSlice = createSlice({
|
||||
state.height = optimalDimension;
|
||||
}
|
||||
}
|
||||
if (action.payload.sd?.scheduler) {
|
||||
state.scheduler = action.payload.sd.scheduler;
|
||||
}
|
||||
if (action.payload.sd?.vaePrecision) {
|
||||
state.vaePrecision = action.payload.sd.vaePrecision;
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
|
||||
|
||||
@@ -8,15 +8,15 @@ type Props = {
|
||||
onOpen: () => void;
|
||||
};
|
||||
|
||||
export const AddEmbeddingButton = memo((props: Props) => {
|
||||
export const AddPromptTriggerButton = memo((props: Props) => {
|
||||
const { onOpen, isOpen } = props;
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Tooltip label={t('embedding.addEmbedding')}>
|
||||
<Tooltip label={t('prompt.addPromptTrigger')}>
|
||||
<IconButton
|
||||
variant="promptOverlay"
|
||||
isDisabled={isOpen}
|
||||
aria-label={t('embedding.addEmbedding')}
|
||||
aria-label={t('prompt.addPromptTrigger')}
|
||||
icon={<PiCodeBold />}
|
||||
onClick={onOpen}
|
||||
/>
|
||||
@@ -24,4 +24,4 @@ export const AddEmbeddingButton = memo((props: Props) => {
|
||||
);
|
||||
});
|
||||
|
||||
AddEmbeddingButton.displayName = 'AddEmbeddingButton';
|
||||
AddPromptTriggerButton.displayName = 'AddPromptTriggerButton';
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
||||
import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
|
||||
import type { EmbeddingPopoverProps } from 'features/embedding/types';
|
||||
import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect';
|
||||
import type { PromptPopoverProps } from 'features/prompt/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
export const PromptPopover = memo((props: PromptPopoverProps) => {
|
||||
const { onSelect, isOpen, onClose, width, children } = props;
|
||||
|
||||
return (
|
||||
@@ -14,7 +14,7 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
closeOnBlur={true}
|
||||
returnFocusOnClose={true}
|
||||
returnFocusOnClose={false}
|
||||
isLazy
|
||||
>
|
||||
<PopoverAnchor>{children}</PopoverAnchor>
|
||||
@@ -27,11 +27,11 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
borderStyle="solid"
|
||||
>
|
||||
<PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
|
||||
<EmbeddingSelect onClose={onClose} onSelect={onSelect} />
|
||||
<PromptTriggerSelect onClose={onClose} onSelect={onSelect} />
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingPopover.displayName = 'EmbeddingPopover';
|
||||
PromptPopover.displayName = 'PromptPopover';
|
||||
@@ -0,0 +1,21 @@
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
|
||||
import { PromptTriggerSelect } from './PromptTriggerSelect';
|
||||
import type { PromptTriggerSelectProps } from './types';
|
||||
|
||||
const meta: Meta<typeof PromptTriggerSelect> = {
|
||||
title: 'Feature/Prompt/PromptTriggerSelect',
|
||||
tags: ['autodocs'],
|
||||
component: PromptTriggerSelect,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof PromptTriggerSelect>;
|
||||
|
||||
const Component = (props: PromptTriggerSelectProps) => {
|
||||
return <PromptTriggerSelect {...props}>Invoke</PromptTriggerSelect>;
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
||||
@@ -0,0 +1,86 @@
|
||||
import type { ChakraProps, ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelMetadataQuery, useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
|
||||
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const currentModelKey = useAppSelector((s) => s.generation.model?.key);
|
||||
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
const { data: metadata } = useGetModelMetadataQuery(currentModelKey ?? skipToken);
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
onSelect('');
|
||||
return;
|
||||
}
|
||||
|
||||
onSelect(v.value);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const embeddingOptions = useMemo(() => {
|
||||
if (!data) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
|
||||
|
||||
return [
|
||||
{
|
||||
label: t('prompt.compatibleEmbeddings'),
|
||||
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
|
||||
},
|
||||
];
|
||||
}, [data, currentBaseModel, t]);
|
||||
|
||||
const options = useMemo(() => {
|
||||
if (!metadata || !metadata.trigger_phrases) {
|
||||
return [...embeddingOptions];
|
||||
}
|
||||
|
||||
const metadataOptions = [
|
||||
{
|
||||
label: t('modelManager.triggerPhrases'),
|
||||
options: metadata.trigger_phrases.map((phrase) => ({ label: phrase, value: phrase })),
|
||||
},
|
||||
];
|
||||
return [...metadataOptions, ...embeddingOptions];
|
||||
}, [embeddingOptions, metadata, t]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoading ? t('common.loading') : t('prompt.addPromptTrigger')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={_onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-prompt-trigger"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
PromptTriggerSelect.displayName = 'PromptTriggerSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
||||
@@ -1,12 +1,12 @@
|
||||
import type { PropsWithChildren } from 'react';
|
||||
|
||||
export type EmbeddingSelectProps = {
|
||||
export type PromptTriggerSelectProps = {
|
||||
onSelect: (v: string) => void;
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
export type EmbeddingPopoverProps = PropsWithChildren &
|
||||
EmbeddingSelectProps & {
|
||||
export type PromptPopoverProps = PropsWithChildren &
|
||||
PromptTriggerSelectProps & {
|
||||
isOpen: boolean;
|
||||
width?: number | string;
|
||||
};
|
||||
@@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
|
||||
import { useCallback } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
|
||||
type UseInsertEmbeddingArg = {
|
||||
type UseInsertTriggerArg = {
|
||||
prompt: string;
|
||||
textareaRef: RefObject<HTMLTextAreaElement>;
|
||||
onChange: (v: string) => void;
|
||||
};
|
||||
|
||||
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
|
||||
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
|
||||
const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
|
||||
@@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
[_onChange]
|
||||
);
|
||||
|
||||
const insertEmbedding = useCallback(
|
||||
const insertTrigger = useCallback(
|
||||
(v: string) => {
|
||||
if (!textareaRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
// this is where we insert the trigger
|
||||
const caret = textareaRef.current.selectionStart;
|
||||
|
||||
if (isNil(caret)) {
|
||||
@@ -35,13 +35,9 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
newPrompt += `${v}`;
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
// we insert the cursor after the end of trigger
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
@@ -51,7 +47,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
_onChange(newPrompt);
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
// set the cursor position to just after the trigger
|
||||
textareaRef.current.selectionStart = finalCaretPos;
|
||||
textareaRef.current.selectionEnd = finalCaretPos;
|
||||
},
|
||||
@@ -62,17 +58,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
textareaRef.current?.focus();
|
||||
}, [textareaRef]);
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
const handleClosePopover = useCallback(() => {
|
||||
onClose();
|
||||
onFocus();
|
||||
}, [onFocus, onClose]);
|
||||
|
||||
const onSelectEmbedding = useCallback(
|
||||
const onSelect = useCallback(
|
||||
(v: string) => {
|
||||
insertEmbedding(v);
|
||||
handleClose();
|
||||
insertTrigger(v);
|
||||
handleClosePopover();
|
||||
},
|
||||
[handleClose, insertEmbedding]
|
||||
[handleClosePopover, insertTrigger]
|
||||
);
|
||||
|
||||
const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
|
||||
@@ -90,7 +86,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
isOpen,
|
||||
onClose,
|
||||
onOpen,
|
||||
onSelectEmbedding,
|
||||
onSelect,
|
||||
onKeyDown,
|
||||
onFocus,
|
||||
};
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { setNegativeStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@@ -20,7 +20,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
@@ -29,12 +29,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
useHotkeys('alt+a', onFocus, []);
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@@ -48,10 +43,10 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { setPositiveStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -19,19 +19,14 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@@ -45,10 +40,10 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
|
||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { filter } from 'lodash-es';
|
||||
@@ -71,7 +72,10 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
<TabPanel overflow="visible" px={4} pt={4}>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamMainModelSelect />
|
||||
<SyncModelsIconButton />
|
||||
<Flex>
|
||||
<UseDefaultSettingsButton />
|
||||
<SyncModelsIconButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
|
||||
@@ -41,6 +41,8 @@ const initialConfigState: AppConfig = {
|
||||
boundingBoxHeight: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxWidth: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxHeight: { ...baseDimensionConfig },
|
||||
scheduler: "euler",
|
||||
vaePrecision: "fp32",
|
||||
steps: {
|
||||
initial: 30,
|
||||
sliderMin: 1,
|
||||
|
||||
@@ -24,12 +24,21 @@ export type UpdateModelArg = {
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelMetadataArg = {
|
||||
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
|
||||
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
type UpdateModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
|
||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||
|
||||
type DeleteMainModelArg = {
|
||||
@@ -108,25 +117,25 @@ const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefi
|
||||
|
||||
const buildProvidesTags =
|
||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
};
|
||||
return tags;
|
||||
};
|
||||
|
||||
const buildTransformResponse =
|
||||
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the models router
|
||||
@@ -172,6 +181,16 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}/metadata`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||
query: ({ source, config, access_token }) => {
|
||||
return {
|
||||
@@ -351,6 +370,7 @@ export const {
|
||||
useGetModelMetadataQuery,
|
||||
useDeleteModelImportMutation,
|
||||
usePruneModelImportsMutation,
|
||||
useUpdateModelMetadataMutation,
|
||||
} = modelsApi;
|
||||
|
||||
const upsertModelConfigs = (
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user