mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(app): tweak model load events
- Pass in the `UtilInterface` to the `ModelsInterface` so we can call the simple `signal_progress` method instead of the complicated `emit_invocation_progress` method. - Only emit load events when starting to load - not after. - Add more detail to the messages, like submodel type
This commit is contained in:
@@ -160,6 +160,10 @@ class LoggerInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class ImagesInterface(InvocationContextInterface):
|
||||
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
|
||||
super().__init__(services, data)
|
||||
self._util = util
|
||||
|
||||
def save(
|
||||
self,
|
||||
image: Image,
|
||||
@@ -186,6 +190,8 @@ class ImagesInterface(InvocationContextInterface):
|
||||
The saved image DTO.
|
||||
"""
|
||||
|
||||
self._util.signal_progress("Saving image")
|
||||
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
metadata_ = None
|
||||
if metadata:
|
||||
@@ -336,6 +342,10 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
"""Common API for loading, downloading and managing models."""
|
||||
|
||||
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
|
||||
super().__init__(services, data)
|
||||
self._util = util
|
||||
|
||||
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
|
||||
"""Check if a model exists.
|
||||
|
||||
@@ -372,15 +382,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
|
||||
try:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, f"Loading model {model.name}..."
|
||||
)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
finally:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, f"Finished loading model {model.name}."
|
||||
)
|
||||
message = f"Loading model {model.name}"
|
||||
if submodel_type:
|
||||
message += f" ({submodel_type.value})"
|
||||
self._util.signal_progress(message)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
@@ -405,15 +411,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
try:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, f"Loading model {name}..."
|
||||
)
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
finally:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, f"Finished loading model {name}."
|
||||
)
|
||||
message = f"Loading model {name}"
|
||||
if submodel_type:
|
||||
message += f" ({submodel_type.value})"
|
||||
self._util.signal_progress(message)
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Get a model's config.
|
||||
@@ -483,15 +485,8 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
Path to the downloaded model
|
||||
"""
|
||||
try:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, f"Downloading model {source}..."
|
||||
)
|
||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||
finally:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, f"Finished downloading model {source}."
|
||||
)
|
||||
self._util.signal_progress(f"Downloading model {source}")
|
||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||
|
||||
def load_local_model(
|
||||
self,
|
||||
@@ -514,15 +509,8 @@ class ModelsInterface(InvocationContextInterface):
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
|
||||
try:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, "Loading model..."
|
||||
)
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
finally:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, "Finished loading model."
|
||||
)
|
||||
self._util.signal_progress(f"Loading model {model_path.name}")
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
def load_remote_model(
|
||||
self,
|
||||
@@ -548,15 +536,8 @@ class ModelsInterface(InvocationContextInterface):
|
||||
"""
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
|
||||
try:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, f"Loading model {source}..."
|
||||
)
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
finally:
|
||||
self._services.events.emit_invocation_progress(
|
||||
self._data.queue_item, self._data.invocation, f"Finished loading model {source}."
|
||||
)
|
||||
self._util.signal_progress(f"Loading model {source}")
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
@@ -749,12 +730,12 @@ def build_invocation_context(
|
||||
"""
|
||||
|
||||
logger = LoggerInterface(services=services, data=data)
|
||||
images = ImagesInterface(services=services, data=data)
|
||||
tensors = TensorsInterface(services=services, data=data)
|
||||
models = ModelsInterface(services=services, data=data)
|
||||
config = ConfigInterface(services=services, data=data)
|
||||
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
|
||||
conditioning = ConditioningInterface(services=services, data=data)
|
||||
models = ModelsInterface(services=services, data=data, util=util)
|
||||
images = ImagesInterface(services=services, data=data, util=util)
|
||||
boards = BoardsInterface(services=services, data=data)
|
||||
|
||||
ctx = InvocationContext(
|
||||
|
||||
Reference in New Issue
Block a user