diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 9ba6a42554..e59b88dc7c 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -88,7 +88,9 @@ class ApiDependencies: latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) - model_install_service = ModelInstallService(app_config=config, record_store=model_record_service, event_bus=events) + model_install_service = ModelInstallService( + app_config=config, record_store=model_record_service, event_bus=events + ) names = SimpleNameService() performance_statistics = InvocationStatsService() processor = DefaultInvocationProcessor() diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 505d998bc2..90c58db3a1 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -51,7 +51,9 @@ async def list_model_records( found_models: list[AnyModelConfig] = [] if base_models: for base_model in base_models: - found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)) + found_models.extend( + record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name) + ) else: found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name)) return ModelsList(models=found_models) @@ -184,25 +186,25 @@ async def add_model_record( status_code=201, ) async def import_model( - source: ModelSource = Body( - description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!" - ), - config: Optional[Dict[str, Any]] = Body( - description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", - default=None, - ), - variant: Optional[str] = Body( - description="When fetching a repo_id, force variant type to fetch such as 'fp16'", - default=None, - ), - subfolder: Optional[str] = Body( - description="When fetching a repo_id, specify subfolder to fetch model from", - default=None, - ), - access_token: Optional[str] = Body( - description="When fetching a repo_id or URL, access token for web access", - default=None, - ), + source: ModelSource = Body( + description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!" + ), + config: Optional[Dict[str, Any]] = Body( + description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", + default=None, + ), + variant: Optional[str] = Body( + description="When fetching a repo_id, force variant type to fetch such as 'fp16'", + default=None, + ), + subfolder: Optional[str] = Body( + description="When fetching a repo_id, specify subfolder to fetch model from", + default=None, + ), + access_token: Optional[str] = Body( + description="When fetching a repo_id or URL, access token for web access", + default=None, + ), ) -> ModelInstallJob: """Add a model using its local path, repo_id, or remote URL. @@ -250,14 +252,16 @@ async def import_model( raise HTTPException(status_code=409, detail=str(e)) return result + @model_records_router.get( "/import", operation_id="list_model_install_jobs", ) async def list_model_install_jobs( - source: Optional[str] = Query(description="Filter list by install source, partial string match.", - default=None, - ) + source: Optional[str] = Query( + description="Filter list by install source, partial string match.", + default=None, + ), ) -> List[ModelInstallJob]: """ Return list of model install jobs. @@ -268,6 +272,7 @@ async def list_model_install_jobs( jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source) return jobs + @model_records_router.patch( "/import", operation_id="prune_model_install_jobs", @@ -276,14 +281,14 @@ async def list_model_install_jobs( 400: {"description": "Bad request"}, }, ) -async def prune_model_install_jobs( -) -> Response: +async def prune_model_install_jobs() -> Response: """ Prune all completed and errored jobs from the install job list. """ ApiDependencies.invoker.services.model_install.prune_jobs() return Response(status_code=204) + @model_records_router.patch( "/sync", operation_id="sync_models_to_config", @@ -292,8 +297,7 @@ async def prune_model_install_jobs( 400: {"description": "Bad request"}, }, ) -async def sync_models_to_config( -) -> Response: +async def sync_models_to_config() -> Response: """ Traverse the models and autoimport directories. Model files without a corresponding record in the database are added. Orphan records without a models file are deleted. diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index ff3957507d..c63297fa55 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -37,9 +37,5 @@ class SocketIO: if "queue_id" in data: await self.__sio.leave_room(sid, data["queue_id"]) - async def _handle_model_event(self, event: Event) -> None: - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"] - ) + await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) diff --git a/invokeai/app/services/config/__init__.py b/invokeai/app/services/config/__init__.py index e0b4168c6f..4cc6ecc298 100644 --- a/invokeai/app/services/config/__init__.py +++ b/invokeai/app/services/config/__init__.py @@ -2,4 +2,4 @@ from .config_default import InvokeAIAppConfig, get_invokeai_config -__all__ = ['InvokeAIAppConfig', 'get_invokeai_config'] +__all__ = ["InvokeAIAppConfig", "get_invokeai_config"] diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index e46c30e755..93b84afaf1 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -331,9 +331,7 @@ class EventServiceBase: """ self.__emit_model_event( event_name="model_install_started", - payload={ - "source": source - }, + payload={"source": source}, ) def emit_model_install_completed(self, source: str, key: str) -> None: @@ -351,11 +349,12 @@ class EventServiceBase: }, ) - def emit_model_install_progress(self, - source: str, - current_bytes: int, - total_bytes: int, - ) -> None: + def emit_model_install_progress( + self, + source: str, + current_bytes: int, + total_bytes: int, + ) -> None: """ Emitted while the install job is in progress. (Downloaded models only) @@ -373,12 +372,12 @@ class EventServiceBase: }, ) - - def emit_model_install_error(self, - source: str, - error_type: str, - error: str, - ) -> None: + def emit_model_install_error( + self, + source: str, + error_type: str, + error: str, + ) -> None: """ Emitted when an install job encounters an exception. diff --git a/invokeai/app/services/model_install/__init__.py b/invokeai/app/services/model_install/__init__.py index e45a15f503..e86e18863d 100644 --- a/invokeai/app/services/model_install/__init__.py +++ b/invokeai/app/services/model_install/__init__.py @@ -9,10 +9,11 @@ from .model_install_base import ( ) from .model_install_default import ModelInstallService -__all__ = ['ModelInstallServiceBase', - 'ModelInstallService', - 'InstallStatus', - 'ModelInstallJob', - 'UnknownInstallJobException', - 'ModelSource', - ] +__all__ = [ + "ModelInstallServiceBase", + "ModelInstallService", + "InstallStatus", + "ModelInstallJob", + "UnknownInstallJobException", + "ModelSource", +] diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 5212443b68..0a4c17559c 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -17,10 +17,10 @@ from invokeai.backend.model_manager import AnyModelConfig class InstallStatus(str, Enum): """State of an install job running in the background.""" - WAITING = "waiting" # waiting to be dequeued - RUNNING = "running" # being processed + WAITING = "waiting" # waiting to be dequeued + RUNNING = "running" # being processed COMPLETED = "completed" # finished running - ERROR = "error" # terminated with an error message + ERROR = "error" # terminated with an error message class UnknownInstallJobException(Exception): @@ -32,10 +32,17 @@ ModelSource = Union[str, Path, AnyHttpUrl] class ModelInstallJob(BaseModel): """Object that tracks the current status of an install request.""" + status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") - config_in: Dict[str, Any] = Field(default_factory=dict, description="Configuration information (e.g. 'description') to apply to model.") - config_out: Optional[AnyModelConfig] = Field(default=None, description="After successful installation, this will hold the configuration object.") - inplace: bool = Field(default=False, description="Leave model in its current location; otherwise install under models directory") + config_in: Dict[str, Any] = Field( + default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." + ) + config_out: Optional[AnyModelConfig] = Field( + default=None, description="After successful installation, this will hold the configuration object." + ) + inplace: bool = Field( + default=False, description="Leave model in its current location; otherwise install under models directory" + ) source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR") @@ -53,10 +60,10 @@ class ModelInstallServiceBase(ABC): @abstractmethod def __init__( - self, - app_config: InvokeAIAppConfig, - record_store: ModelRecordServiceBase, - event_bus: Optional["EventServiceBase"] = None, + self, + app_config: InvokeAIAppConfig, + record_store: ModelRecordServiceBase, + event_bus: Optional["EventServiceBase"] = None, ): """ Create ModelInstallService object. @@ -86,9 +93,9 @@ class ModelInstallServiceBase(ABC): @abstractmethod def register_path( - self, - model_path: Union[Path, str], - config: Optional[Dict[str, Any]] = None, + self, + model_path: Union[Path, str], + config: Optional[Dict[str, Any]] = None, ) -> str: """ Probe and register the model at model_path. @@ -114,9 +121,9 @@ class ModelInstallServiceBase(ABC): @abstractmethod def install_path( - self, - model_path: Union[Path, str], - config: Optional[Dict[str, Any]] = None, + self, + model_path: Union[Path, str], + config: Optional[Dict[str, Any]] = None, ) -> str: """ Probe, register and install the model in the models directory. @@ -131,13 +138,13 @@ class ModelInstallServiceBase(ABC): @abstractmethod def import_model( - self, - source: Union[str, Path, AnyHttpUrl], - inplace: bool = False, - variant: Optional[str] = None, - subfolder: Optional[str] = None, - config: Optional[Dict[str, Any]] = None, - access_token: Optional[str] = None, + self, + source: Union[str, Path, AnyHttpUrl], + inplace: bool = False, + variant: Optional[str] = None, + subfolder: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, ) -> ModelInstallJob: """Install the indicated model. @@ -189,7 +196,7 @@ class ModelInstallServiceBase(ABC): """Return the ModelInstallJob corresponding to the provided source.""" @abstractmethod - def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102 + def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102 """ List active and complete install jobs. diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index c67b8feb59..17a0fca7bf 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -46,11 +46,12 @@ class ModelInstallService(ModelInstallServiceBase): _cached_model_paths: Set[Path] _models_installed: Set[str] - def __init__(self, - app_config: InvokeAIAppConfig, - record_store: ModelRecordServiceBase, - event_bus: Optional[EventServiceBase] = None - ): + def __init__( + self, + app_config: InvokeAIAppConfig, + record_store: ModelRecordServiceBase, + event_bus: Optional[EventServiceBase] = None, + ): """ Initialize the installer object. @@ -73,11 +74,11 @@ class ModelInstallService(ModelInstallServiceBase): return self._app_config @property - def record_store(self) -> ModelRecordServiceBase: # noqa D102 + def record_store(self) -> ModelRecordServiceBase: # noqa D102 return self._record_store @property - def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 + def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 return self._event_bus def _start_installer_thread(self) -> None: @@ -129,25 +130,25 @@ class ModelInstallService(ModelInstallServiceBase): self._event_bus.emit_model_install_error(str(job.source), error_type, error) def register_path( - self, - model_path: Union[Path, str], - config: Optional[Dict[str, Any]] = None, - ) -> str: # noqa D102 + self, + model_path: Union[Path, str], + config: Optional[Dict[str, Any]] = None, + ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if config.get('source') is None: - config['source'] = model_path.resolve().as_posix() + if config.get("source") is None: + config["source"] = model_path.resolve().as_posix() return self._register(model_path, config) def install_path( - self, - model_path: Union[Path, str], - config: Optional[Dict[str, Any]] = None, - ) -> str: # noqa D102 + self, + model_path: Union[Path, str], + config: Optional[Dict[str, Any]] = None, + ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if config.get('source') is None: - config['source'] = model_path.resolve().as_posix() + if config.get("source") is None: + config["source"] = model_path.resolve().as_posix() info: AnyModelConfig = self._probe_model(Path(model_path), config) @@ -164,14 +165,14 @@ class ModelInstallService(ModelInstallServiceBase): ) def import_model( - self, - source: ModelSource, - inplace: bool = False, - variant: Optional[str] = None, - subfolder: Optional[str] = None, - config: Optional[Dict[str, Any]] = None, - access_token: Optional[str] = None, - ) -> ModelInstallJob: # noqa D102 + self, + source: ModelSource, + inplace: bool = False, + variant: Optional[str] = None, + subfolder: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: # noqa D102 # Clean up a common source of error. Doesn't work with Paths. if isinstance(source, str): source = source.strip() @@ -181,11 +182,12 @@ class ModelInstallService(ModelInstallServiceBase): # Installing a local path if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk - job = ModelInstallJob(config_in=config, - source=source, - inplace=inplace, - local_path=Path(source), - ) + job = ModelInstallJob( + config_in=config, + source=source, + inplace=inplace, + local_path=Path(source), + ) self._install_jobs[source] = job self._install_queue.put(job) return job @@ -193,7 +195,7 @@ class ModelInstallService(ModelInstallServiceBase): else: # here is where we'd download a URL or repo_id. Implementation pending download queue. raise UnknownModelException("File or directory not found") - def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102 + def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102 jobs = self._install_jobs if not source: return list(jobs.values()) @@ -205,17 +207,19 @@ class ModelInstallService(ModelInstallServiceBase): try: return self._install_jobs[source] except KeyError: - raise UnknownInstallJobException(f'{source}: unknown install job') + raise UnknownInstallJobException(f"{source}: unknown install job") - def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102 + def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102 self._install_queue.join() return self._install_jobs def prune_jobs(self) -> None: """Prune all completed and errored jobs.""" - finished_jobs = [source for source in self._install_jobs - if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR] - ] + finished_jobs = [ + source + for source in self._install_jobs + if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR] + ] for source in finished_jobs: del self._install_jobs[source] @@ -228,7 +232,7 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.info(f"{len(installed)} new models registered") self._logger.info("Model installer (re)initialized") - def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 + def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()} callback = self._scan_install if install else self._scan_register search = ModelSearch(on_model_found=callback) @@ -295,7 +299,6 @@ class ModelInstallService(ModelInstallServiceBase): self.record_store.update_model(key, model) return model - def _scan_register(self, model: Path) -> bool: if model in self._cached_model_paths: return True @@ -308,7 +311,6 @@ class ModelInstallService(ModelInstallServiceBase): pass return True - def _scan_install(self, model: Path) -> bool: if model in self._cached_model_paths: return True @@ -320,7 +322,7 @@ class ModelInstallService(ModelInstallServiceBase): pass return True - def unregister(self, key: str) -> None: # noqa D102 + def unregister(self, key: str) -> None: # noqa D102 self.record_store.del_model(key) def delete(self, key: str) -> None: # noqa D102 @@ -333,7 +335,7 @@ class ModelInstallService(ModelInstallServiceBase): else: self.unregister(key) - def unconditionally_delete(self, key: str) -> None: # noqa D102 + def unconditionally_delete(self, key: str) -> None: # noqa D102 model = self.record_store.get_model(key) path = self.app_config.models_path / model.path if path.is_dir(): @@ -378,11 +380,9 @@ class ModelInstallService(ModelInstallServiceBase): def _create_key(self) -> str: return sha256(randbytes(100)).hexdigest()[0:32] - def _register(self, - model_path: Path, - config: Optional[Dict[str, Any]] = None, - info: Optional[AnyModelConfig] = None) -> str: - + def _register( + self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None + ) -> str: info = info or ModelProbe.probe(model_path, config) key = self._create_key() @@ -393,7 +393,7 @@ class ModelInstallService(ModelInstallServiceBase): info.path = model_path.as_posix() # add 'main' specific fields - if hasattr(info, 'config'): + if hasattr(info, "config"): # make config relative to our root legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve() info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix() diff --git a/invokeai/app/services/model_records/__init__.py b/invokeai/app/services/model_records/__init__.py index a7ebacee67..160608fd5d 100644 --- a/invokeai/app/services/model_records/__init__.py +++ b/invokeai/app/services/model_records/__init__.py @@ -8,9 +8,9 @@ from .model_records_base import ( # noqa F401 from .model_records_sql import ModelRecordServiceSQL # noqa F401 __all__ = [ - 'ModelRecordServiceBase', - 'ModelRecordServiceSQL', - 'DuplicateModelException', - 'InvalidModelException', - 'UnknownModelException', + "ModelRecordServiceBase", + "ModelRecordServiceSQL", + "DuplicateModelException", + "InvalidModelException", + "UnknownModelException", ] diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 1601b45f57..775d3404cb 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -123,8 +123,8 @@ class ModelProbe(object): base_type=base_type, variant_type=variant_type, prediction_type=prediction_type, - name = name, - description = description, + name=name, + description=description, upcast_attention=( base_type == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction @@ -150,7 +150,7 @@ class ModelProbe(object): @classmethod def get_model_name(cls, model_path: Path) -> str: - if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}: + if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: return model_path.stem else: return model_path.name diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 5988dfb686..bd2828312a 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -14,15 +14,16 @@ from .config import ( from .probe import ModelProbe from .search import ModelSearch -__all__ = ['ModelProbe', 'ModelSearch', - 'InvalidModelConfigException', - 'ModelConfigFactory', - 'BaseModelType', - 'ModelType', - 'SubModelType', - 'ModelVariantType', - 'ModelFormat', - 'SchedulerPredictionType', - 'AnyModelConfig', - ] - +__all__ = [ + "ModelProbe", + "ModelSearch", + "InvalidModelConfigException", + "ModelConfigFactory", + "BaseModelType", + "ModelType", + "SubModelType", + "ModelVariantType", + "ModelFormat", + "SchedulerPredictionType", + "AnyModelConfig", +] diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 0caa191f3e..22c67742d8 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -49,6 +49,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched }, } + class ProbeBase(object): """Base class for probes.""" @@ -71,6 +72,7 @@ class ProbeBase(object): """Get model scheduler prediction type.""" return None + class ModelProbe(object): PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { "diffusers": {}, @@ -100,9 +102,9 @@ class ModelProbe(object): @classmethod def heuristic_probe( - cls, - model_path: Path, - fields: Optional[Dict[str, Any]] = None, + cls, + model_path: Path, + fields: Optional[Dict[str, Any]] = None, ) -> AnyModelConfig: return cls.probe(model_path, fields) @@ -138,29 +140,38 @@ class ModelProbe(object): hash = FastModelHash.hash(model_path) probe = probe_class(model_path) - fields['path'] = model_path.as_posix() - fields['type'] = fields.get('type') or model_type - fields['base'] = fields.get('base') or probe.get_base_type() - fields['variant'] = fields.get('variant') or probe.get_variant_type() - fields['prediction_type'] = fields.get('prediction_type') or probe.get_scheduler_prediction_type() - fields['name'] = fields.get('name') or cls.get_model_name(model_path) - fields['description'] = fields.get('description') or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" - fields['format'] = fields.get('format') or probe.get_format() - fields['original_hash'] = fields.get('original_hash') or hash - fields['current_hash'] = fields.get('current_hash') or hash + fields["path"] = model_path.as_posix() + fields["type"] = fields.get("type") or model_type + fields["base"] = fields.get("base") or probe.get_base_type() + fields["variant"] = fields.get("variant") or probe.get_variant_type() + fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() + fields["name"] = fields.get("name") or cls.get_model_name(model_path) + fields["description"] = ( + fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" + ) + fields["format"] = fields.get("format") or probe.get_format() + fields["original_hash"] = fields.get("original_hash") or hash + fields["current_hash"] = fields.get("current_hash") or hash # additional fields needed for main and controlnet models - if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint: - fields['config'] = cls._get_checkpoint_config_path(model_path, - model_type=fields['type'], - base_type=fields['base'], - variant_type=fields['variant'], - prediction_type=fields['prediction_type']).as_posix() + if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: + fields["config"] = cls._get_checkpoint_config_path( + model_path, + model_type=fields["type"], + base_type=fields["base"], + variant_type=fields["variant"], + prediction_type=fields["prediction_type"], + ).as_posix() # additional fields needed for main non-checkpoint models - elif fields['type'] == ModelType.Main and fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]: - fields['upcast_attention'] = fields.get('upcast_attention') or ( - fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction + elif fields["type"] == ModelType.Main and fields["format"] in [ + ModelFormat.Onnx, + ModelFormat.Olive, + ModelFormat.Diffusers, + ]: + fields["upcast_attention"] = fields.get("upcast_attention") or ( + fields["base"] == BaseModelType.StableDiffusion2 + and fields["prediction_type"] == SchedulerPredictionType.VPrediction ) model_info = ModelConfigFactory.make_config(fields) @@ -168,7 +179,7 @@ class ModelProbe(object): @classmethod def get_model_name(cls, model_path: Path) -> str: - if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}: + if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: return model_path.stem else: return model_path.name @@ -247,13 +258,14 @@ class ModelProbe(object): ) @classmethod - def _get_checkpoint_config_path(cls, - model_path: Path, - model_type: ModelType, - base_type: BaseModelType, - variant_type: ModelVariantType, - prediction_type: SchedulerPredictionType) -> Path: - + def _get_checkpoint_config_path( + cls, + model_path: Path, + model_type: ModelType, + base_type: BaseModelType, + variant_type: ModelVariantType, + prediction_type: SchedulerPredictionType, + ) -> Path: # look for a YAML file adjacent to the model file first possible_conf = model_path.with_suffix(".yaml") if possible_conf.exists(): @@ -264,9 +276,13 @@ class ModelProbe(object): if isinstance(config_file, dict): # need another tier for sd-2.x models config_file = config_file[prediction_type] elif model_type == ModelType.ControlNet: - config_file = "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml" + config_file = ( + "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml" + ) else: - raise InvalidModelConfigException(f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}") + raise InvalidModelConfigException( + f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}" + ) assert isinstance(config_file, str) return Path(config_file) @@ -297,6 +313,7 @@ class ModelProbe(object): # Checkpoint probing # ##################################################3 + class CheckpointProbeBase(ProbeBase): def __init__(self, model_path: Path): super().__init__(model_path) @@ -446,7 +463,6 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase): # classes for probing folders ####################################################### class FolderProbeBase(ProbeBase): - def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal @@ -537,7 +553,9 @@ class TextualInversionFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: path = self.model_path / "learned_embeds.bin" if not path.exists(): - raise InvalidModelConfigException(f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file") + raise InvalidModelConfigException( + f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file" + ) return TextualInversionCheckpointProbe(path).get_base_type() @@ -608,7 +626,9 @@ class IPAdapterFolderProbe(FolderProbeBase): elif cross_attention_dim == 2048: return BaseModelType.StableDiffusionXL else: - raise InvalidModelConfigException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.") + raise InvalidModelConfigException( + f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." + ) class CLIPVisionFolderProbe(FolderProbeBase): diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 06d138fc9f..7492e471d3 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -165,14 +165,14 @@ class ModelSearch(ModelSearchBase): self.scanned_dirs.add(path) continue if any( - (path / x).exists() - for x in [ - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "image_encoder.txt", - ] + (path / x).exists() + for x in [ + "config.json", + "model_index.json", + "learned_embeds.bin", + "pytorch_lora_weights.bin", + "image_encoder.txt", + ] ): self.scanned_dirs.add(path) try: diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 800a10614c..87ae1480f5 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -14,4 +14,4 @@ from .devices import ( # noqa: F401 from .logging import InvokeAILogger from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401 -__all__ = ['Chdir', 'InvokeAILogger', 'choose_precision', 'choose_torch_device'] +__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"] diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 67278a2a0b..75069b37eb 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -44,12 +44,12 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: @pytest.fixture -def installer(app_config: InvokeAIAppConfig, - store: ModelRecordServiceBase) -> ModelInstallServiceBase: - return ModelInstallService(app_config=app_config, - record_store=store, - event_bus=DummyEventService(), - ) +def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase: + return ModelInstallService( + app_config=app_config, + record_store=store, + event_bus=DummyEventService(), + ) class DummyEvent(BaseModel): @@ -70,10 +70,8 @@ class DummyEventService(EventServiceBase): def dispatch(self, event_name: str, payload: Any) -> None: """Dispatch an event by appending it to self.events.""" - self.events.append( - DummyEvent(event_name=payload['event'], - payload=payload['data']) - ) + self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) + def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None: store = installer.record_store @@ -83,6 +81,7 @@ def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> No assert key is not None assert len(key) == 32 + def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None: store = installer.record_store key = installer.register_path(test_file) @@ -91,31 +90,30 @@ def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) assert model_record.name == "test_embedding" assert model_record.type == ModelType.TextualInversion assert Path(model_record.path) == test_file - assert model_record.base == BaseModelType('sd-1') + assert model_record.base == BaseModelType("sd-1") assert model_record.description is not None assert model_record.source is not None assert Path(model_record.source) == test_file + def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None: key = None with pytest.raises(ValidationError): key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")}) assert key is None + def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None: store = installer.record_store - key = installer.register_path(test_file, - { - "name": "banana_sushi", - "source": "fake/repo_id", - "current_hash": "New Hash" - } - ) + key = installer.register_path( + test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"} + ) model_record = store.get_model(key) assert model_record.name == "banana_sushi" assert model_record.source == "fake/repo_id" assert model_record.current_hash == "New Hash" + def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None: store = installer.record_store key = installer.install_path(test_file) @@ -123,6 +121,7 @@ def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config assert model_record.path == "sd-1/embedding/test_embedding.safetensors" assert model_record.source == test_file.as_posix() + def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None: """Note: may want to break this down into several smaller unit tests.""" source = test_file @@ -142,7 +141,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path, # test that the expected events were issued bus = installer.event_bus - assert bus is not None # sigh - ruff is a stickler for type checking + assert bus is not None # sigh - ruff is a stickler for type checking assert isinstance(bus, DummyEventService) assert len(bus.events) == 2 event_names = [x.event_name for x in bus.events] @@ -167,6 +166,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path, with pytest.raises(UnknownInstallJobException): assert installer.get_job(source) + def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig): store = installer.record_store key = installer.install_path(test_file) @@ -174,11 +174,14 @@ def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app assert Path(app_config.models_dir / model_record.path).exists() assert test_file.exists() # original should still be there after installation installer.delete(key) - assert not Path(app_config.models_dir / model_record.path).exists() # after deletion, installed copy should not exist + assert not Path( + app_config.models_dir / model_record.path + ).exists() # after deletion, installed copy should not exist assert test_file.exists() # but original should still be there with pytest.raises(UnknownModelException): store.get_model(key) + def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig): store = installer.record_store key = installer.register_path(test_file)