feat(ui,api): add starter bundles to MM

This commit is contained in:
Mary Hipp
2024-10-10 20:05:03 -04:00
committed by Mary Hipp Rogers
parent fe87c198eb
commit 5bd87ca89b
10 changed files with 774 additions and 508 deletions

View File

@@ -38,7 +38,12 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Cac
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.starter_models import STARTER_MODELS, StarterModel, StarterModelWithoutDependencies
from invokeai.backend.model_manager.starter_models import (
STARTER_BUNDLES,
STARTER_MODELS,
StarterModel,
StarterModelWithoutDependencies,
)
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
@@ -792,11 +797,17 @@ async def convert_model(
return new_config
@model_manager_router.get("/starter_models", operation_id="get_starter_models", response_model=list[StarterModel])
async def get_starter_models() -> list[StarterModel]:
class StarterModelResponse(BaseModel):
starter_models: list[StarterModel]
starter_bundles: dict[str, list[StarterModel]]
@model_manager_router.get("/starter_models", operation_id="get_starter_models", response_model=StarterModelResponse)
async def get_starter_models() -> StarterModelResponse:
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
installed_model_sources = {m.source for m in installed_models}
starter_models = deepcopy(STARTER_MODELS)
starter_bundles = deepcopy(STARTER_BUNDLES)
for model in starter_models:
if model.source in installed_model_sources:
model.is_installed = True
@@ -807,7 +818,18 @@ async def get_starter_models() -> list[StarterModel]:
missing_deps.append(dep)
model.dependencies = missing_deps
return starter_models
for bundle in starter_bundles.values():
for model in bundle:
if model.source in installed_model_sources:
model.is_installed = True
# Remove already-installed dependencies
missing_deps: list[StarterModelWithoutDependencies] = []
for dep in model.dependencies or []:
if dep.source not in installed_model_sources:
missing_deps.append(dep)
model.dependencies = missing_deps
return StarterModelResponse(starter_models=starter_models, starter_bundles=starter_bundles)
@model_manager_router.get(

File diff suppressed because it is too large Load Diff

View File

@@ -728,6 +728,7 @@
"huggingFaceHelper": "If multiple models are found in this repo, you will be prompted to select one to install.",
"hfToken": "HuggingFace Token",
"imageEncoderModelId": "Image Encoder Model ID",
"includesNModels": "Includes {{n}} models and their dependencies",
"installQueue": "Install Queue",
"inplaceInstall": "In-place install",
"inplaceInstallDesc": "Install models without copying the files. When using the model, it will be loaded from its this location. If disabled, the model file(s) will be copied into the Invoke-managed models directory during installation.",
@@ -781,6 +782,7 @@
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"spandrelImageToImage": "Image to Image (Spandrel)",
"starterBundles": "Starter Bundles",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"syncModels": "Sync Models",
@@ -1990,7 +1992,7 @@
}
},
"newUserExperience": {
"toGetStarted": "To get started, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"toGetStarted": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"gettingStartedSeries": "Want more guidance? Check out our <LinkComponent>Getting Started Series</LinkComponent> for tips on unlocking the full potential of the Invoke Studio."
},
"whatsNew": {

View File

@@ -0,0 +1,40 @@
import { EMPTY_ARRAY } from "app/store/constants";
import { useCallback,useMemo } from "react";
import { modelConfigsAdapterSelectors,useGetModelConfigsQuery } from "services/api/endpoints/models";
import type { StarterModel } from "services/api/types";
export const useBuildModelsToInstall = () => {
const { data: modelListRes } = useGetModelConfigsQuery();
const modelList = useMemo(() => {
if (!modelListRes) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(modelListRes);
}, [modelListRes]);
const buildModelToInstall = useCallback(
(starterModel: StarterModel) => {
if (
modelList.some(
(mc) => starterModel.base === mc.base && starterModel.name === mc.name && starterModel.type === mc.type
)
) {
return undefined;
}
const source = starterModel.source;
const config = {
name: starterModel.name,
description: starterModel.description,
type: starterModel.type,
base: starterModel.base,
format: starterModel.format,
};
return { config, source };
},
[modelList]
);
return buildModelToInstall
}

View File

@@ -1,64 +1,58 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useBuildModelsToInstall } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type Props = {
result: GetStarterModelsResponse[number];
modelList: AnyModelConfig[];
model: GetStarterModelsResponse['starter_models'][number];
};
export const StarterModelsResultItem = memo(({ result, modelList }: Props) => {
export const StarterModelsResultItem = memo(({ model }: Props) => {
const { t } = useTranslation();
const buildModelToInstall = useBuildModelsToInstall();
const allSources = useMemo(() => {
const _allSources = [
{
source: result.source,
config: {
name: result.name,
description: result.description,
type: result.type,
base: result.base,
format: result.format,
},
},
];
if (result.dependencies) {
for (const d of result.dependencies) {
_allSources.push({
source: d.source,
config: { name: d.name, description: d.description, type: d.type, base: d.base, format: d.format },
});
const _allSources = [];
const result = buildModelToInstall(model);
if (result) {
_allSources.push(result);
}
if (model.dependencies) {
for (const d of model.dependencies) {
const result = buildModelToInstall(d);
if (result) {
_allSources.push(result);
}
}
}
return _allSources;
}, [result]);
}, [model, buildModelToInstall]);
const [installModel] = useInstallModel();
const onClick = useCallback(() => {
for (const { config, source } of allSources) {
if (modelList.some((mc) => config.base === mc.base && config.name === mc.name && config.type === mc.type)) {
continue;
}
installModel({ config, source });
for (const model of allSources) {
installModel(model);
}
}, [modelList, allSources, installModel]);
}, [allSources, installModel]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex fontSize="sm" flexDir="column">
<Flex gap={3}>
<Badge h="min-content">{result.type.replaceAll('_', ' ')}</Badge>
<ModelBaseBadge base={result.base} />
<Text fontWeight="semibold">{result.name}</Text>
<Badge h="min-content">{model.type.replaceAll('_', ' ')}</Badge>
<ModelBaseBadge base={model.base} />
<Text fontWeight="semibold">{model.name}</Text>
</Flex>
<Text variant="subtext">{result.description}</Text>
<Text variant="subtext">{model.description}</Text>
</Flex>
<Box>
{result.is_installed ? (
{model.is_installed ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />

View File

@@ -0,0 +1,79 @@
import { Button, Flex, Text, Tooltip } from '@invoke-ai/ui-library';
import { useBuildModelsToInstall } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { type GetStarterModelsResponse, useInstallModelMutation } from 'services/api/endpoints/models';
import { isMainModelBase } from '../../../../nodes/types/common';
export const StarterBundle = ({
bundleName,
bundle,
}: {
bundleName: string;
bundle: GetStarterModelsResponse['starter_bundles'][number];
}) => {
const [installModel] = useInstallModelMutation();
const buildModelToInstall = useBuildModelsToInstall();
const { t } = useTranslation();
const modelsToInstall = useMemo(() => {
const _modelsToInstall = [];
const _modelsToSkip = [];
for (let index = 0; index < bundle.length; index++) {
const starterModel = bundle[index];
if (!starterModel) {
continue;
}
const result = buildModelToInstall(starterModel);
if (result) {
_modelsToInstall.push(result);
} else {
_modelsToSkip.push(result);
}
if (starterModel.dependencies) {
for (const d of starterModel.dependencies) {
const result = buildModelToInstall(d);
if (result) {
_modelsToInstall.push(result);
} else {
_modelsToSkip.push(result);
}
}
}
}
return { install: _modelsToInstall, skip: _modelsToSkip };
}, [bundle, buildModelToInstall]);
const handleClickBundle = useCallback(async () => {
for (let index = 0; index < modelsToInstall.install.length; index++) {
const model = modelsToInstall.install[index];
if (model) {
await installModel(model).unwrap();
}
}
toast({
status: 'info',
title: 'Bundle Installing',
description: `Installing ${modelsToInstall.install.length}, skipping ${modelsToInstall.skip.length} duplicates`,
});
}, [modelsToInstall, installModel]);
return (
<Tooltip
label={
<Flex flexDir="column">
<Text>{t('modelManager.includesNModels', { n: bundle.length })}</Text>
</Flex>
}
>
<Button flexDir="column" size="sm" onClick={handleClickBundle}>
{isMainModelBase(bundleName) && MODEL_TYPE_SHORT_MAP[bundleName]}
</Button>
</Tooltip>
);
};

View File

@@ -1,31 +1,19 @@
import { Flex } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
import { memo, useMemo } from 'react';
import {
modelConfigsAdapterSelectors,
useGetModelConfigsQuery,
useGetStarterModelsQuery,
} from 'services/api/endpoints/models';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
import { StarterModelsResults } from './StarterModelsResults';
export const StarterModelsForm = memo(() => {
const { isLoading, data } = useGetStarterModelsQuery();
const { data: modelListRes } = useGetModelConfigsQuery();
const modelList = useMemo(() => {
if (!modelListRes) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(modelListRes);
}, [modelListRes]);
const { t } = useTranslation();
return (
<Flex flexDir="column" height="100%" gap={3}>
{isLoading && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
{data && <StarterModelsResults results={data} modelList={modelList} />}
{isLoading && <FetchingModelsLoader loadingMessage={t('common.loading')} />}
{data && <StarterModelsResults results={data} />}
</Flex>
);
});

View File

@@ -1,25 +1,24 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
import { Flex, IconButton, Input, InputGroup, InputRightElement, Text } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { StarterBundle } from './StarterBundle';
import { StarterModelsResultItem } from './StartModelsResultItem';
type StarterModelsResultsProps = {
results: NonNullable<GetStarterModelsResponse>;
modelList: AnyModelConfig[];
};
export const StarterModelsResults = memo(({ results, modelList }: StarterModelsResultsProps) => {
export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const filteredResults = useMemo(() => {
return results.filter((result) => {
return results.starter_models.filter((result) => {
const trimmedSearchTerm = searchTerm.trim().toLowerCase();
const matchStrings = [
result.name.toLowerCase(),
@@ -46,7 +45,21 @@ export const StarterModelsResults = memo(({ results, modelList }: StarterModelsR
return (
<Flex flexDir="column" gap={3} height="100%">
<Flex justifyContent="flex-end" alignItems="center">
<Flex justifyContent="space-between" alignItems="center">
{!!Object.keys(results.starter_bundles).length && (
<Flex gap={2} alignItems="center">
<Text fontWeight="semibold">{t('modelManager.starterBundles')}:</Text>
{Object.keys(results.starter_bundles).map((bundleName) => (
<>
{results.starter_bundles[bundleName] ? (
<StarterBundle bundleName={bundleName} bundle={results.starter_bundles[bundleName]} />
) : (
<></>
)}
</>
))}
</Flex>
)}
<InputGroup w={64} size="xs">
<Input
placeholder={t('modelManager.search')}
@@ -74,7 +87,7 @@ export const StarterModelsResults = memo(({ results, modelList }: StarterModelsR
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (
<StarterModelsResultItem key={result.source} result={result} modelList={modelList} />
<StarterModelsResultItem key={result.source} model={result} />
))}
</Flex>
</ScrollableContent>

View File

@@ -15166,6 +15166,15 @@ export type components = {
/** Dependencies */
dependencies?: components["schemas"]["StarterModelWithoutDependencies"][] | null;
};
/** StarterModelResponse */
StarterModelResponse: {
/** Starter Models */
starter_models: components["schemas"]["StarterModel"][];
/** Starter Bundles */
starter_bundles: {
[key: string]: components["schemas"]["StarterModel"][];
};
};
/** StarterModelWithoutDependencies */
StarterModelWithoutDependencies: {
/** Description */
@@ -17972,7 +17981,7 @@ export interface operations {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["StarterModel"][];
"application/json": components["schemas"]["StarterModelResponse"];
};
};
};

View File

@@ -243,3 +243,4 @@ export type PostUploadAction =
| ReplaceLayerWithImagePostUploadAction;
export type BoardRecordOrderBy = S['BoardRecordOrderBy'];
export type StarterModel = S['StarterModel'];