add the import model router

This commit is contained in:
Lincoln Stein
2023-07-03 19:32:54 -04:00
committed by psychedelicious
parent 0988725c1b
commit 96bf92ead4
8 changed files with 233 additions and 1489 deletions

View File

@@ -2,17 +2,17 @@
from typing import Literal, Optional, Union
from fastapi import Query
from fastapi import Query, Body
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
@@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelRequest(BaseModel):
name: str = Field(description="A model path, repo_id or URL to import")
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
# base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
@@ -86,7 +89,6 @@ async def list_models(
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@models_router.post(
"/",
operation_id="update_model",
@@ -109,27 +111,38 @@ async def update_model(
return model_response
@models_router.post(
"/",
"/import",
operation_id="import_model",
responses={200: {"status": "success"}},
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
},
status_code=201,
response_model=ImportModelResponse
)
async def import_model(
model_request: ImportModelRequest
) -> None:
""" Add Model """
items_to_import = set([model_request.name])
name: str = Query(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
if len(installed_models) > 0:
logger.info(f'Successfully imported {model_request.name}')
if info := installed_models.get(name):
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse(
name = name,
info = info,
status = "success",
)
else:
logger.error(f'Model {model_request.name} not imported')
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
logger.error(f'Model {name} not imported')
raise HTTPException(status_code=404, detail=f'Model {name} not found')
@models_router.delete(
"/{model_name}",

View File

@@ -135,6 +135,29 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
pass
@abstractmethod
def commit(self, conf_file: Path = None) -> None:
"""
@@ -361,3 +384,24 @@ class ModelManagerService(ModelManagerServiceBase):
def logger(self):
return self.mgr.logger
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)