diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 785a1af6fd..df938f89b1 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -41,6 +41,7 @@ from invokeai.backend.model_manager.starter_models import ( STARTER_BUNDLES, STARTER_MODELS, StarterModel, + StarterModelBundle, StarterModelWithoutDependencies, ) @@ -799,7 +800,7 @@ async def convert_model( class StarterModelResponse(BaseModel): starter_models: list[StarterModel] - starter_bundles: dict[str, list[StarterModel]] + starter_bundles: dict[str, StarterModelBundle] def get_is_installed( @@ -833,7 +834,7 @@ async def get_starter_models() -> StarterModelResponse: model.dependencies = missing_deps for bundle in starter_bundles.values(): - for model in bundle: + for model in bundle.models: model.is_installed = get_is_installed(model, installed_models) # Remove already-installed dependencies missing_deps: list[StarterModelWithoutDependencies] = [] diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index b96f79249d..75e3716938 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -23,7 +23,7 @@ class StarterModel(StarterModelWithoutDependencies): dependencies: Optional[list[StarterModelWithoutDependencies]] = None -class StarterModelBundles(BaseModel): +class StarterModelBundle(BaseModel): name: str models: list[StarterModel] @@ -778,10 +778,10 @@ flux_bundle: list[StarterModel] = [ flux_fill, ] -STARTER_BUNDLES: dict[str, list[StarterModel]] = { - BaseModelType.StableDiffusion1: sd1_bundle, - BaseModelType.StableDiffusionXL: sdxl_bundle, - BaseModelType.Flux: flux_bundle, +STARTER_BUNDLES: dict[str, StarterModelBundle] = { + BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle), + BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle), + BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle), } assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"