mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
164 lines
5.6 KiB
Python
164 lines
5.6 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein
|
|
"""FastAPI route for model configuration records."""
|
|
|
|
|
|
from hashlib import sha1
|
|
from random import randbytes
|
|
from typing import List, Optional
|
|
|
|
from fastapi import Body, Path, Query, Response
|
|
from fastapi.routing import APIRouter
|
|
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
|
from starlette.exceptions import HTTPException
|
|
|
|
from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
|
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
|
|
|
from ..dependencies import ApiDependencies
|
|
|
|
# Router shared by every endpoint in this module; all routes are served under /v1/model_records.
model_records_router = APIRouter(prefix="/v1/model_records", tags=["model_records"])
|
|
|
|
# Pydantic adapter for validating data against AnyModelConfig.
# NOTE(review): not referenced by the routes visible in this file -- presumably used by importers; confirm.
ModelConfigValidator = TypeAdapter(AnyModelConfig)
|
|
|
|
|
|
class ModelsList(BaseModel):
    """Return list of configs."""

    # Pydantic setting: enum-typed fields are populated with the enum's value
    # rather than the enum member itself.
    model_config = ConfigDict(use_enum_values=True)

    # The model configuration records carried by this response.
    models: list[AnyModelConfig]
|
|
|
|
|
|
# Pydantic adapter used by list_model_records to turn a plain {"models": [...]} dict into a ModelsList.
ModelsListValidator = TypeAdapter(ModelsList)
|
|
|
|
|
|
@model_records_router.get(
    "/",
    operation_id="list_model_configs",
    responses={200: {"model": ModelsList}},
)
async def list_model_records(
    base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
    model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
    """Get a list of models."""
    store = ApiDependencies.invoker.services.model_records

    if base_models:
        # One search per requested base model; the hits are concatenated in
        # the order the base models were given.
        dumped = [
            record.model_dump()
            for base in base_models
            for record in store.search_by_name(base_model=base, model_type=model_type)
        ]
    else:
        # No base-model filter supplied: search on model type alone.
        dumped = [record.model_dump() for record in store.search_by_name(model_type=model_type)]

    # Re-validate the dumped dicts so the response is a proper ModelsList.
    return ModelsListValidator.validate_python({"models": dumped})
|
|
|
|
|
|
@model_records_router.get(
    "/i/{key}",
    operation_id="get_model_record",
    responses={
        200: {"description": "Success"},
        400: {"description": "Bad request"},
        404: {"description": "The model could not be found"},
    },
)
async def get_model_record(
    key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
    """Get a model record"""
    store = ApiDependencies.invoker.services.model_records
    try:
        record = store.get_model(key)
    except UnknownModelException as err:
        # An unknown key is reported to the client as 404 carrying the store's message.
        raise HTTPException(status_code=404, detail=str(err))
    return record
|
|
|
|
|
|
@model_records_router.patch(
    "/i/{key}",
    operation_id="update_model_record",
    responses={
        200: {"description": "The model was updated successfully"},
        400: {"description": "Bad request"},
        404: {"description": "The model could not be found"},
        409: {"description": "There is already a model corresponding to the new name"},
    },
    status_code=200,
    response_model=AnyModelConfig,
)
async def update_model_record(
    key: str = Path(description="Unique key of model"),
    info: AnyModelConfig = Body(description="Model configuration"),
) -> AnyModelConfig:
    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
    services = ApiDependencies.invoker.services
    log = services.logger
    store = services.model_records

    try:
        updated = store.update_model(key, config=info)
    except UnknownModelException as err:
        # No record exists under this key.
        raise HTTPException(status_code=404, detail=str(err))
    except ValueError as err:
        # A ValueError from the store is logged and surfaced as a 409 conflict.
        message = str(err)
        log.error(message)
        raise HTTPException(status_code=409, detail=message)
    return updated
|
|
|
|
|
|
@model_records_router.delete(
    "/i/{key}",
    operation_id="del_model_record",
    responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
    status_code=204,
    response_model=None,
)
async def del_model_record(
    key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
    """Delete Model"""
    services = ApiDependencies.invoker.services
    log = services.logger

    try:
        services.model_records.del_model(key)
    except UnknownModelException as err:
        # Nothing is stored under this key: log it and answer 404.
        log.error(str(err))
        raise HTTPException(status_code=404, detail=str(err))

    log.info(f"Deleted model: {key}")
    # Successful deletion is answered with an empty 204 response.
    return Response(status_code=204)
|
|
|
|
|
|
@model_records_router.post(
    "/i/",
    operation_id="add_model_record",
    responses={
        201: {"description": "The model added successfully"},
        404: {"description": "The model could not be found"},
        409: {"description": "There is already a model corresponding to this path or repo_id"},
        415: {"description": "Unrecognized file/folder format"},
    },
    status_code=201,
    response_model=AnyModelConfig,
)
async def add_model_record(
    config: AnyModelConfig = Body(description="Model configuration"),
) -> AnyModelConfig:
    """
    Add a model using the configuration information appropriate for its type.
    """
    # Error mapping performed by this handler:
    #   DuplicateModelException -> 409, InvalidModelException -> 415.
    # NOTE(review): a 404 response is declared in the decorator, but nothing in
    # this handler raises one; an UnknownModelException from the final
    # get_model() call would propagate unhandled.
    logger = ApiDependencies.invoker.services.logger
    record_store = ApiDependencies.invoker.services.model_records

    # "<NOKEY>" is treated as "no key supplied": assign a random SHA-1 hex
    # digest so the record has a key to be stored under.
    if config.key == "<NOKEY>":
        config.key = sha1(randbytes(100)).hexdigest()
    logger.info(f"Created model {config.key} for {config.name}")

    try:
        record_store.add_model(config.key, config)
    except DuplicateModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e)) from e
    except InvalidModelException as e:
        logger.error(str(e))
        # Pass the reason through to the client, as the other handlers do,
        # instead of returning a bare 415 with only the generic status text.
        raise HTTPException(status_code=415, detail=str(e)) from e

    # now fetch it out
    return record_store.get_model(config.key)
|