mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui,api): add starter bundles to MM
This commit is contained in:
committed by
Mary Hipp Rogers
parent
fe87c198eb
commit
5bd87ca89b
@@ -38,7 +38,12 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Cac
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.starter_models import STARTER_MODELS, StarterModel, StarterModelWithoutDependencies
|
||||
from invokeai.backend.model_manager.starter_models import (
|
||||
STARTER_BUNDLES,
|
||||
STARTER_MODELS,
|
||||
StarterModel,
|
||||
StarterModelWithoutDependencies,
|
||||
)
|
||||
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
@@ -792,11 +797,17 @@ async def convert_model(
|
||||
return new_config
|
||||
|
||||
|
||||
@model_manager_router.get("/starter_models", operation_id="get_starter_models", response_model=list[StarterModel])
|
||||
async def get_starter_models() -> list[StarterModel]:
|
||||
class StarterModelResponse(BaseModel):
|
||||
starter_models: list[StarterModel]
|
||||
starter_bundles: dict[str, list[StarterModel]]
|
||||
|
||||
|
||||
@model_manager_router.get("/starter_models", operation_id="get_starter_models", response_model=StarterModelResponse)
|
||||
async def get_starter_models() -> StarterModelResponse:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
|
||||
installed_model_sources = {m.source for m in installed_models}
|
||||
starter_models = deepcopy(STARTER_MODELS)
|
||||
starter_bundles = deepcopy(STARTER_BUNDLES)
|
||||
for model in starter_models:
|
||||
if model.source in installed_model_sources:
|
||||
model.is_installed = True
|
||||
@@ -807,7 +818,18 @@ async def get_starter_models() -> list[StarterModel]:
|
||||
missing_deps.append(dep)
|
||||
model.dependencies = missing_deps
|
||||
|
||||
return starter_models
|
||||
for bundle in starter_bundles.values():
|
||||
for model in bundle:
|
||||
if model.source in installed_model_sources:
|
||||
model.is_installed = True
|
||||
# Remove already-installed dependencies
|
||||
missing_deps: list[StarterModelWithoutDependencies] = []
|
||||
for dep in model.dependencies or []:
|
||||
if dep.source not in installed_model_sources:
|
||||
missing_deps.append(dep)
|
||||
model.dependencies = missing_deps
|
||||
|
||||
return StarterModelResponse(starter_models=starter_models, starter_bundles=starter_bundles)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -728,6 +728,7 @@
|
||||
"huggingFaceHelper": "If multiple models are found in this repo, you will be prompted to select one to install.",
|
||||
"hfToken": "HuggingFace Token",
|
||||
"imageEncoderModelId": "Image Encoder Model ID",
|
||||
"includesNModels": "Includes {{n}} models and their dependencies",
|
||||
"installQueue": "Install Queue",
|
||||
"inplaceInstall": "In-place install",
|
||||
"inplaceInstallDesc": "Install models without copying the files. When using the model, it will be loaded from its this location. If disabled, the model file(s) will be copied into the Invoke-managed models directory during installation.",
|
||||
@@ -781,6 +782,7 @@
|
||||
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
||||
"source": "Source",
|
||||
"spandrelImageToImage": "Image to Image (Spandrel)",
|
||||
"starterBundles": "Starter Bundles",
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"syncModels": "Sync Models",
|
||||
@@ -1990,7 +1992,7 @@
|
||||
}
|
||||
},
|
||||
"newUserExperience": {
|
||||
"toGetStarted": "To get started, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
|
||||
"toGetStarted": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
|
||||
"gettingStartedSeries": "Want more guidance? Check out our <LinkComponent>Getting Started Series</LinkComponent> for tips on unlocking the full potential of the Invoke Studio."
|
||||
},
|
||||
"whatsNew": {
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
import { EMPTY_ARRAY } from "app/store/constants";
|
||||
import { useCallback,useMemo } from "react";
|
||||
import { modelConfigsAdapterSelectors,useGetModelConfigsQuery } from "services/api/endpoints/models";
|
||||
import type { StarterModel } from "services/api/types";
|
||||
|
||||
export const useBuildModelsToInstall = () => {
|
||||
const { data: modelListRes } = useGetModelConfigsQuery();
|
||||
const modelList = useMemo(() => {
|
||||
if (!modelListRes) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(modelListRes);
|
||||
}, [modelListRes]);
|
||||
|
||||
const buildModelToInstall = useCallback(
|
||||
(starterModel: StarterModel) => {
|
||||
if (
|
||||
modelList.some(
|
||||
(mc) => starterModel.base === mc.base && starterModel.name === mc.name && starterModel.type === mc.type
|
||||
)
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const source = starterModel.source;
|
||||
const config = {
|
||||
name: starterModel.name,
|
||||
description: starterModel.description,
|
||||
type: starterModel.type,
|
||||
base: starterModel.base,
|
||||
format: starterModel.format,
|
||||
};
|
||||
return { config, source };
|
||||
},
|
||||
[modelList]
|
||||
);
|
||||
|
||||
return buildModelToInstall
|
||||
}
|
||||
@@ -1,64 +1,58 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useBuildModelsToInstall } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
result: GetStarterModelsResponse[number];
|
||||
modelList: AnyModelConfig[];
|
||||
model: GetStarterModelsResponse['starter_models'][number];
|
||||
};
|
||||
export const StarterModelsResultItem = memo(({ result, modelList }: Props) => {
|
||||
export const StarterModelsResultItem = memo(({ model }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const buildModelToInstall = useBuildModelsToInstall();
|
||||
|
||||
const allSources = useMemo(() => {
|
||||
const _allSources = [
|
||||
{
|
||||
source: result.source,
|
||||
config: {
|
||||
name: result.name,
|
||||
description: result.description,
|
||||
type: result.type,
|
||||
base: result.base,
|
||||
format: result.format,
|
||||
},
|
||||
},
|
||||
];
|
||||
if (result.dependencies) {
|
||||
for (const d of result.dependencies) {
|
||||
_allSources.push({
|
||||
source: d.source,
|
||||
config: { name: d.name, description: d.description, type: d.type, base: d.base, format: d.format },
|
||||
});
|
||||
const _allSources = [];
|
||||
|
||||
const result = buildModelToInstall(model);
|
||||
if (result) {
|
||||
_allSources.push(result);
|
||||
}
|
||||
|
||||
if (model.dependencies) {
|
||||
for (const d of model.dependencies) {
|
||||
const result = buildModelToInstall(d);
|
||||
if (result) {
|
||||
_allSources.push(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
return _allSources;
|
||||
}, [result]);
|
||||
}, [model, buildModelToInstall]);
|
||||
|
||||
const [installModel] = useInstallModel();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
for (const { config, source } of allSources) {
|
||||
if (modelList.some((mc) => config.base === mc.base && config.name === mc.name && config.type === mc.type)) {
|
||||
continue;
|
||||
}
|
||||
installModel({ config, source });
|
||||
for (const model of allSources) {
|
||||
installModel(model);
|
||||
}
|
||||
}, [modelList, allSources, installModel]);
|
||||
}, [allSources, installModel]);
|
||||
|
||||
return (
|
||||
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
|
||||
<Flex fontSize="sm" flexDir="column">
|
||||
<Flex gap={3}>
|
||||
<Badge h="min-content">{result.type.replaceAll('_', ' ')}</Badge>
|
||||
<ModelBaseBadge base={result.base} />
|
||||
<Text fontWeight="semibold">{result.name}</Text>
|
||||
<Badge h="min-content">{model.type.replaceAll('_', ' ')}</Badge>
|
||||
<ModelBaseBadge base={model.base} />
|
||||
<Text fontWeight="semibold">{model.name}</Text>
|
||||
</Flex>
|
||||
<Text variant="subtext">{result.description}</Text>
|
||||
<Text variant="subtext">{model.description}</Text>
|
||||
</Flex>
|
||||
<Box>
|
||||
{result.is_installed ? (
|
||||
{model.is_installed ? (
|
||||
<Badge>{t('common.installed')}</Badge>
|
||||
) : (
|
||||
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import { Button, Flex, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useBuildModelsToInstall } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { type GetStarterModelsResponse, useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
import { isMainModelBase } from '../../../../nodes/types/common';
|
||||
|
||||
export const StarterBundle = ({
|
||||
bundleName,
|
||||
bundle,
|
||||
}: {
|
||||
bundleName: string;
|
||||
bundle: GetStarterModelsResponse['starter_bundles'][number];
|
||||
}) => {
|
||||
const [installModel] = useInstallModelMutation();
|
||||
const buildModelToInstall = useBuildModelsToInstall();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const modelsToInstall = useMemo(() => {
|
||||
const _modelsToInstall = [];
|
||||
const _modelsToSkip = [];
|
||||
for (let index = 0; index < bundle.length; index++) {
|
||||
const starterModel = bundle[index];
|
||||
if (!starterModel) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const result = buildModelToInstall(starterModel);
|
||||
if (result) {
|
||||
_modelsToInstall.push(result);
|
||||
} else {
|
||||
_modelsToSkip.push(result);
|
||||
}
|
||||
|
||||
if (starterModel.dependencies) {
|
||||
for (const d of starterModel.dependencies) {
|
||||
const result = buildModelToInstall(d);
|
||||
if (result) {
|
||||
_modelsToInstall.push(result);
|
||||
} else {
|
||||
_modelsToSkip.push(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { install: _modelsToInstall, skip: _modelsToSkip };
|
||||
}, [bundle, buildModelToInstall]);
|
||||
|
||||
const handleClickBundle = useCallback(async () => {
|
||||
for (let index = 0; index < modelsToInstall.install.length; index++) {
|
||||
const model = modelsToInstall.install[index];
|
||||
if (model) {
|
||||
await installModel(model).unwrap();
|
||||
}
|
||||
}
|
||||
toast({
|
||||
status: 'info',
|
||||
title: 'Bundle Installing',
|
||||
description: `Installing ${modelsToInstall.install.length}, skipping ${modelsToInstall.skip.length} duplicates`,
|
||||
});
|
||||
}, [modelsToInstall, installModel]);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={
|
||||
<Flex flexDir="column">
|
||||
<Text>{t('modelManager.includesNModels', { n: bundle.length })}</Text>
|
||||
</Flex>
|
||||
}
|
||||
>
|
||||
<Button flexDir="column" size="sm" onClick={handleClickBundle}>
|
||||
{isMainModelBase(bundleName) && MODEL_TYPE_SHORT_MAP[bundleName]}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
@@ -1,31 +1,19 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
|
||||
import { memo, useMemo } from 'react';
|
||||
import {
|
||||
modelConfigsAdapterSelectors,
|
||||
useGetModelConfigsQuery,
|
||||
useGetStarterModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { StarterModelsResults } from './StarterModelsResults';
|
||||
|
||||
export const StarterModelsForm = memo(() => {
|
||||
const { isLoading, data } = useGetStarterModelsQuery();
|
||||
const { data: modelListRes } = useGetModelConfigsQuery();
|
||||
|
||||
const modelList = useMemo(() => {
|
||||
if (!modelListRes) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(modelListRes);
|
||||
}, [modelListRes]);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" height="100%" gap={3}>
|
||||
{isLoading && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
|
||||
{data && <StarterModelsResults results={data} modelList={modelList} />}
|
||||
{isLoading && <FetchingModelsLoader loadingMessage={t('common.loading')} />}
|
||||
{data && <StarterModelsResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
|
||||
import { Flex, IconButton, Input, InputGroup, InputRightElement, Text } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { StarterBundle } from './StarterBundle';
|
||||
import { StarterModelsResultItem } from './StartModelsResultItem';
|
||||
|
||||
type StarterModelsResultsProps = {
|
||||
results: NonNullable<GetStarterModelsResponse>;
|
||||
modelList: AnyModelConfig[];
|
||||
};
|
||||
|
||||
export const StarterModelsResults = memo(({ results, modelList }: StarterModelsResultsProps) => {
|
||||
export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
|
||||
const filteredResults = useMemo(() => {
|
||||
return results.filter((result) => {
|
||||
return results.starter_models.filter((result) => {
|
||||
const trimmedSearchTerm = searchTerm.trim().toLowerCase();
|
||||
const matchStrings = [
|
||||
result.name.toLowerCase(),
|
||||
@@ -46,7 +45,21 @@ export const StarterModelsResults = memo(({ results, modelList }: StarterModelsR
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={3} height="100%">
|
||||
<Flex justifyContent="flex-end" alignItems="center">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
{!!Object.keys(results.starter_bundles).length && (
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Text fontWeight="semibold">{t('modelManager.starterBundles')}:</Text>
|
||||
{Object.keys(results.starter_bundles).map((bundleName) => (
|
||||
<>
|
||||
{results.starter_bundles[bundleName] ? (
|
||||
<StarterBundle bundleName={bundleName} bundle={results.starter_bundles[bundleName]} />
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</>
|
||||
))}
|
||||
</Flex>
|
||||
)}
|
||||
<InputGroup w={64} size="xs">
|
||||
<Input
|
||||
placeholder={t('modelManager.search')}
|
||||
@@ -74,7 +87,7 @@ export const StarterModelsResults = memo(({ results, modelList }: StarterModelsR
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={3}>
|
||||
{filteredResults.map((result) => (
|
||||
<StarterModelsResultItem key={result.source} result={result} modelList={modelList} />
|
||||
<StarterModelsResultItem key={result.source} model={result} />
|
||||
))}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
|
||||
@@ -15166,6 +15166,15 @@ export type components = {
|
||||
/** Dependencies */
|
||||
dependencies?: components["schemas"]["StarterModelWithoutDependencies"][] | null;
|
||||
};
|
||||
/** StarterModelResponse */
|
||||
StarterModelResponse: {
|
||||
/** Starter Models */
|
||||
starter_models: components["schemas"]["StarterModel"][];
|
||||
/** Starter Bundles */
|
||||
starter_bundles: {
|
||||
[key: string]: components["schemas"]["StarterModel"][];
|
||||
};
|
||||
};
|
||||
/** StarterModelWithoutDependencies */
|
||||
StarterModelWithoutDependencies: {
|
||||
/** Description */
|
||||
@@ -17972,7 +17981,7 @@ export interface operations {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["StarterModel"][];
|
||||
"application/json": components["schemas"]["StarterModelResponse"];
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -243,3 +243,4 @@ export type PostUploadAction =
|
||||
| ReplaceLayerWithImagePostUploadAction;
|
||||
|
||||
export type BoardRecordOrderBy = S['BoardRecordOrderBy'];
|
||||
export type StarterModel = S['StarterModel'];
|
||||
|
||||
Reference in New Issue
Block a user