feat(app): tweak model load events

- Pass in the `UtilInterface` to the `ModelsInterface` so we can call the simple `signal_progress` method instead of the complicated `emit_invocation_progress` method.
- Only emit load events when starting to load - not after.
- Add more detail to the messages, such as the submodel type.
This commit is contained in:
psychedelicious
2024-11-08 06:47:14 +10:00
parent c7878fddc6
commit 067747eca9

View File

@@ -160,6 +160,10 @@ class LoggerInterface(InvocationContextInterface):
class ImagesInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
    """Initialize the images interface.

    Args:
        services: The invocation services container, forwarded to the base class.
        data: The invocation context data, forwarded to the base class.
        util: The util interface; stored so methods can report progress via
            its `signal_progress` method (see the `save` hunk below).
    """
    super().__init__(services, data)
    # Keep a handle to the util interface for progress signalling.
    self._util = util
def save(
self,
image: Image,
@@ -186,6 +190,8 @@ class ImagesInterface(InvocationContextInterface):
The saved image DTO.
"""
self._util.signal_progress("Saving image")
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
@@ -336,6 +342,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
    """Initialize the models interface.

    Args:
        services: The invocation services container, forwarded to the base class.
        data: The invocation context data, forwarded to the base class.
        util: The util interface; stored so the load/download methods can
            report progress via its `signal_progress` method.
    """
    super().__init__(services, data)
    # Keep a handle to the util interface for progress signalling.
    self._util = util
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Check if a model exists.
@@ -372,15 +382,11 @@ class ModelsInterface(InvocationContextInterface):
submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
try:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, f"Loading model {model.name}..."
)
return self._services.model_manager.load.load_model(model, submodel_type)
finally:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, f"Finished loading model {model.name}."
)
message = f"Loading model {model.name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(model, submodel_type)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@@ -405,15 +411,11 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
try:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, f"Loading model {name}..."
)
return self._services.model_manager.load.load_model(configs[0], submodel_type)
finally:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, f"Finished loading model {name}."
)
message = f"Loading model {name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Get a model's config.
@@ -483,15 +485,8 @@ class ModelsInterface(InvocationContextInterface):
Returns:
Path to the downloaded model
"""
try:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, f"Downloading model {source}..."
)
return self._services.model_manager.install.download_and_cache_model(source=source)
finally:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, f"Finished downloading model {source}."
)
self._util.signal_progress(f"Downloading model {source}")
return self._services.model_manager.install.download_and_cache_model(source=source)
def load_local_model(
self,
@@ -514,15 +509,8 @@ class ModelsInterface(InvocationContextInterface):
A LoadedModelWithoutConfig object.
"""
try:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, "Loading model..."
)
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
finally:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, "Finished loading model."
)
self._util.signal_progress(f"Loading model {model_path.name}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def load_remote_model(
self,
@@ -548,15 +536,8 @@ class ModelsInterface(InvocationContextInterface):
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
try:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, f"Loading model {source}..."
)
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
finally:
self._services.events.emit_invocation_progress(
self._data.queue_item, self._data.invocation, f"Finished loading model {source}."
)
self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
class ConfigInterface(InvocationContextInterface):
@@ -749,12 +730,12 @@ def build_invocation_context(
"""
logger = LoggerInterface(services=services, data=data)
images = ImagesInterface(services=services, data=data)
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data, util=util)
images = ImagesInterface(services=services, data=data, util=util)
boards = BoardsInterface(services=services, data=data)
ctx = InvocationContext(