mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 13:08:12 -05:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21b9e96a45 | ||
|
|
b6ad33ac1a | ||
|
|
69ec14c7bb | ||
|
|
a6c91979af | ||
|
|
e655399324 | ||
|
|
f75de8a35c | ||
|
|
d4be945dde | ||
|
|
ab33acad5c | ||
|
|
8f3d7b2946 | ||
|
|
54a30f66cb | ||
|
|
a105da6304 | ||
|
|
4049217728 | ||
|
|
59b4a23479 | ||
|
|
13f410478a | ||
|
|
25ff0bf80f | ||
|
|
f83edcf990 | ||
|
|
a6dd50aeaf | ||
|
|
1badf0f32f | ||
|
|
3c9c58e0fa | ||
|
|
9a1b35fa37 | ||
|
|
5be69f191d | ||
|
|
3d6d89feb4 | ||
|
|
0ac1c0f339 | ||
|
|
c308654442 | ||
|
|
b0ffe36d21 | ||
|
|
6b3fdb8a93 | ||
|
|
7639e05dd2 | ||
|
|
6d261a5a13 | ||
|
|
31e9cf1f06 | ||
|
|
c5d1bd1360 | ||
|
|
3409711ed3 | ||
|
|
3681e34d5a | ||
|
|
2526ef52c5 | ||
|
|
43bcedee10 | ||
|
|
98cc9b963c |
@@ -614,8 +614,8 @@ async def convert_model(
|
||||
The return value is the model configuration for the converted model.
|
||||
"""
|
||||
model_manager = ApiDependencies.invoker.services.model_manager
|
||||
loader = model_manager.load
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
loader = ApiDependencies.invoker.services.model_manager.load
|
||||
store = ApiDependencies.invoker.services.model_manager.store
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
|
||||
@@ -630,7 +630,13 @@ async def convert_model(
|
||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||
|
||||
# loading the model will convert it into a cached diffusers file
|
||||
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
|
||||
try:
|
||||
cc_size = loader.convert_cache.max_size
|
||||
if cc_size == 0: # temporary set the convert cache to a positive number so that cached model is written
|
||||
loader._convert_cache.max_size = 1.0
|
||||
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
|
||||
finally:
|
||||
loader._convert_cache.max_size = cc_size
|
||||
|
||||
# Get the path of the converted model from the loader
|
||||
cache_path = loader.convert_cache.cache_path(key)
|
||||
|
||||
@@ -9,8 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.lora_model_patcher import LoraModelPatcher
|
||||
from invokeai.backend.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@@ -81,8 +80,7 @@ class CompelInvocation(BaseInvocation):
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
# ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
|
||||
):
|
||||
@@ -183,8 +181,7 @@ class SDXLPromptInvocationBase:
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
# ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
):
|
||||
@@ -262,15 +259,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
|
||||
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
|
||||
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
|
||||
@@ -3,6 +3,7 @@ Invoke-managed custom node loader. See README.md for more information.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
@@ -41,11 +42,15 @@ for d in Path(__file__).parent.iterdir():
|
||||
|
||||
logger.info(f"Loading node pack {module_name}")
|
||||
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
try:
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
loaded_count += 1
|
||||
loaded_count += 1
|
||||
except Exception:
|
||||
full_error = traceback.format_exc()
|
||||
logger.error(f"Failed to load node pack {module_name}:\n{full_error}")
|
||||
|
||||
del init, module_name
|
||||
|
||||
|
||||
@@ -52,8 +52,7 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora_model_patcher import LoraModelPatcher
|
||||
from invokeai.backend.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
@@ -740,8 +739,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
|
||||
unet_info as unet,
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
# ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
LoraModelPatcher.apply_lora_to_unet(unet, _lora_loader()),
|
||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
@@ -373,13 +373,16 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion". If if the incoming config has that as the value, we won't set it.
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
# Else we do not attempt to migrate this setting
|
||||
if v != "configs/stable-diffusion":
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
|
||||
@@ -348,8 +348,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config: dict[str, Any] = {}
|
||||
config["name"] = model_name
|
||||
config["description"] = stanza.get("description")
|
||||
config["config_path"] = stanza.get("config")
|
||||
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
try:
|
||||
id = self.register_path(model_path=model_path, config=config)
|
||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||
@@ -368,11 +373,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def delete(self, key: str) -> None: # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self.record_store.get_model(key)
|
||||
models_dir = self.app_config.models_path
|
||||
model_path = models_dir / Path(model.path) # handle legacy relative model paths
|
||||
if model_path.is_relative_to(models_dir):
|
||||
model_path = self.app_config.models_path / model.path
|
||||
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
# If the models is in the Invoke-managed models dir, we delete it
|
||||
self.unconditionally_delete(key)
|
||||
else:
|
||||
# Else we only unregister it, leaving the file in place
|
||||
self.unregister(key)
|
||||
|
||||
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||
@@ -500,9 +507,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _scan_for_missing_models(self) -> list[AnyModelConfig]:
|
||||
"""Scan the models directory for missing models and return a list of them."""
|
||||
missing_models: list[AnyModelConfig] = []
|
||||
for x in self.record_store.all_models():
|
||||
if not Path(x.path).resolve().exists():
|
||||
missing_models.append(x)
|
||||
for model_config in self.record_store.all_models():
|
||||
if not (self.app_config.models_path / model_config.path).resolve().exists():
|
||||
missing_models.append(model_config)
|
||||
return missing_models
|
||||
|
||||
def _register_orphaned_models(self) -> None:
|
||||
@@ -512,7 +519,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
only situations in which we may have orphaned models in the models directory.
|
||||
"""
|
||||
|
||||
installed_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
|
||||
installed_model_paths = {
|
||||
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
|
||||
}
|
||||
|
||||
# The bool returned by this callback determines if the model is added to the list of models found by the search
|
||||
def on_model_found(model_path: Path) -> bool:
|
||||
@@ -548,10 +557,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
model = self.record_store.get_model(key)
|
||||
old_path = Path(model.path).resolve()
|
||||
models_dir = self.app_config.models_path.resolve()
|
||||
models_dir = self.app_config.models_path
|
||||
old_path = self.app_config.models_path / model.path
|
||||
|
||||
if not old_path.is_relative_to(models_dir):
|
||||
# The model is not in the models directory - we don't need to move it.
|
||||
return model
|
||||
|
||||
new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
|
||||
@@ -561,7 +571,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.path = new_path.as_posix()
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
return model
|
||||
|
||||
@@ -600,12 +610,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
model_path = model_path.resolve()
|
||||
|
||||
# Models in the Invoke-managed models dir should use relative paths.
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
model_path = model_path.relative_to(self.app_config.models_path)
|
||||
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# Checkpoints have a config file needed for conversion - resolve this to an absolute path
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
legacy_conf = (self.app_config.legacy_conf_path / info.config_path).resolve()
|
||||
info.config_path = legacy_conf.as_posix()
|
||||
# Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
|
||||
# invoke-managed legacy config dir, we use a relative path.
|
||||
legacy_config_path = self.app_config.legacy_conf_path / info.config_path
|
||||
if legacy_config_path.is_relative_to(self.app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self.app_config.legacy_conf_path)
|
||||
info.config_path = legacy_config_path.as_posix()
|
||||
self.record_store.add_model(info)
|
||||
return info.key
|
||||
|
||||
|
||||
@@ -70,8 +70,18 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
if event_name == "session_canceled" or event_name == "queue_cleared":
|
||||
# These both mean we should cancel the current session.
|
||||
if (
|
||||
event_name == "session_canceled"
|
||||
and self._queue_item
|
||||
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
elif (
|
||||
event_name == "queue_cleared"
|
||||
and self._queue_item
|
||||
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
@@ -111,141 +121,146 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.clear()
|
||||
# Middle processor try block; any unhandled exception is a non-fatal processor error
|
||||
try:
|
||||
# If we are paused, wait for resume event
|
||||
resume_event.wait()
|
||||
|
||||
# Get the next session to process
|
||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||
if self._queue_item is not None and resume_event.is_set():
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
if self._queue_item is None:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
)
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
|
||||
# The session is complete, immediately poll for next session
|
||||
self._queue_item = None
|
||||
poll_now_event.set()
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
else:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
|
||||
@@ -10,6 +10,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -37,6 +39,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_5())
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.register_migration(build_migration_7())
|
||||
migrator.register_migration(build_migration_8(app_config=config))
|
||||
migrator.register_migration(build_migration_9())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -11,7 +11,7 @@ class Migration7Callback:
|
||||
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
|
||||
|
||||
tables = ["model_records", "model_metadata", "model_tags", "tags"]
|
||||
tables = ["model_config", "model_metadata", "model_tags", "tags"]
|
||||
|
||||
for table in tables:
|
||||
cursor.execute(f"DROP TABLE IF EXISTS {table};")
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration8Callback:
|
||||
def __init__(self, app_config: InvokeAIAppConfig) -> None:
|
||||
self._app_config = app_config
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._drop_model_config_table(cursor)
|
||||
self._migrate_abs_models_to_rel(cursor)
|
||||
|
||||
def _drop_model_config_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the old model_config table. This was missed in a previous migration."""
|
||||
|
||||
cursor.execute("DROP TABLE IF EXISTS model_config;")
|
||||
|
||||
def _migrate_abs_models_to_rel(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Check all model paths & legacy config paths to determine if they are inside Invoke-managed directories. If
|
||||
they are, update the paths to be relative to the managed directories.
|
||||
|
||||
This migration is a no-op for normal users (their paths will already be relative), but is necessary for users
|
||||
who have been testing the RCs with their live databases. The paths were made absolute in the initial RC, but this
|
||||
change was reverted. To smooth over the revert for our tests, we can migrate the paths back to relative.
|
||||
"""
|
||||
|
||||
models_path = self._app_config.models_path
|
||||
legacy_conf_path = self._app_config.legacy_conf_path
|
||||
legacy_conf_dir = self._app_config.legacy_conf_dir
|
||||
|
||||
stmt = """---sql
|
||||
SELECT
|
||||
id,
|
||||
path,
|
||||
json_extract(config, '$.config_path') as config_path
|
||||
FROM models;
|
||||
"""
|
||||
|
||||
all_models = cursor.execute(stmt).fetchall()
|
||||
|
||||
for model_id, model_path, model_config_path in all_models:
|
||||
# If the model path is inside the models directory, update it to be relative to the models directory.
|
||||
if Path(model_path).is_relative_to(models_path):
|
||||
new_path = Path(model_path).relative_to(models_path)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
SET config = json_set(config, '$.path', ?)
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(str(new_path), model_id),
|
||||
)
|
||||
# If the model has a legacy config path and it is inside the legacy conf directory, update it to be
|
||||
# relative to the legacy conf directory. This also fixes up cases in which the config path was
|
||||
# incorrectly relativized to the root directory. It will now be relativized to the legacy conf directory.
|
||||
if model_config_path:
|
||||
if Path(model_config_path).is_relative_to(legacy_conf_path):
|
||||
new_config_path = Path(model_config_path).relative_to(legacy_conf_path)
|
||||
elif Path(model_config_path).is_relative_to(legacy_conf_dir):
|
||||
new_config_path = Path(*Path(model_config_path).parts[1:])
|
||||
else:
|
||||
new_config_path = None
|
||||
if new_config_path:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
SET config = json_set(config, '$.config_path', ?)
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(str(new_config_path), model_id),
|
||||
)
|
||||
|
||||
|
||||
def build_migration_8(app_config: InvokeAIAppConfig) -> Migration:
|
||||
"""
|
||||
Build the migration from database version 7 to 8.
|
||||
|
||||
This migration does the following:
|
||||
- Removes the `model_config` table.
|
||||
- Migrates absolute model & legacy config paths to be relative to the models directory.
|
||||
"""
|
||||
migration_8 = Migration(
|
||||
from_version=7,
|
||||
to_version=8,
|
||||
callback=Migration8Callback(app_config),
|
||||
)
|
||||
|
||||
return migration_8
|
||||
@@ -0,0 +1,29 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration9Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._empty_session_queue(cursor)
|
||||
|
||||
def _empty_session_queue(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Empties the session queue. This is done to prevent any lingering session queue items from causing pydantic errors due to changed schemas."""
|
||||
|
||||
cursor.execute("DELETE FROM session_queue;")
|
||||
|
||||
|
||||
def build_migration_9() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 8 to 9.
|
||||
|
||||
This migration does the following:
|
||||
- Empties the session queue. This is done to prevent any lingering session queue items from causing pydantic errors due to changed schemas.
|
||||
"""
|
||||
migration_9 = Migration(
|
||||
from_version=8,
|
||||
to_version=9,
|
||||
callback=Migration9Callback(),
|
||||
)
|
||||
|
||||
return migration_9
|
||||
@@ -1,4 +1,6 @@
|
||||
import sqlite3
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -32,6 +34,7 @@ class SqliteMigrator:
|
||||
self._db = db
|
||||
self._logger = db.logger
|
||||
self._migration_set = MigrationSet()
|
||||
self._backup_path: Optional[Path] = None
|
||||
|
||||
def register_migration(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
@@ -55,6 +58,18 @@ class SqliteMigrator:
|
||||
return False
|
||||
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db.db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
|
||||
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
|
||||
while next_migration is not None:
|
||||
self._run_migration(next_migration)
|
||||
|
||||
@@ -9,6 +9,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||
|
||||
from ..raw_model import RawModel
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
@@ -91,7 +92,7 @@ class MLPProjModel(torch.nn.Module):
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IPAdapter(torch.nn.Module):
|
||||
class IPAdapter(RawModel):
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -11,6 +11,8 @@ from typing_extensions import Self
|
||||
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
from .raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
@@ -366,7 +368,7 @@ class IA3Layer(LoRALayerBase):
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
|
||||
class LoRAModelRaw(torch.nn.Module):
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, Tuple, Union
|
||||
|
||||
from diffusers.loaders.lora import LoraLoaderMixin
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.utils.peft_utils import recurse_remove_peft_layers
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from invokeai.backend.lora_model_raw import LoRAModelRaw
|
||||
|
||||
|
||||
class LoraModelPatcher:
|
||||
@classmethod
|
||||
def unload_lora_from_model(cls, m: Union[UNet2DConditionModel, CLIPTextModel]):
|
||||
"""Unload all LoRA models from a UNet or Text Encoder.
|
||||
This implementation is base on LoraLoaderMixin.unload_lora_weights().
|
||||
"""
|
||||
recurse_remove_peft_layers(m)
|
||||
if hasattr(m, "peft_config"):
|
||||
del m.peft_config # type: ignore
|
||||
if hasattr(m, "_hf_peft_config_loaded"):
|
||||
m._hf_peft_config_loaded = None # type: ignore
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_to_unet(cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]]):
|
||||
try:
|
||||
# TODO(ryand): Test speed of low_cpu_mem_usage=True.
|
||||
for lora, lora_weight in loras:
|
||||
LoraLoaderMixin.load_lora_into_unet(
|
||||
state_dict=lora.state_dict,
|
||||
network_alphas=lora.network_alphas,
|
||||
unet=unet,
|
||||
low_cpu_mem_usage=True,
|
||||
adapter_name=lora.name,
|
||||
_pipeline=None,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
cls.unload_lora_from_model(unet)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_to_text_encoder(
|
||||
cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str
|
||||
):
|
||||
assert prefix in ["text_encoder", "text_encoder_2"]
|
||||
try:
|
||||
for lora, lora_weight in loras:
|
||||
# Filter the state_dict to only include the keys that start with the prefix.
|
||||
text_encoder_state_dict = {
|
||||
key: value for key, value in lora.state_dict.items() if key.startswith(prefix + ".")
|
||||
}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
state_dict=text_encoder_state_dict,
|
||||
network_alphas=lora.network_alphas,
|
||||
text_encoder=text_encoder,
|
||||
low_cpu_mem_usage=True,
|
||||
adapter_name=lora.name,
|
||||
_pipeline=None,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
cls.unload_lora_from_model(text_encoder)
|
||||
@@ -1,66 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers.loaders.lora import LoraLoaderMixin
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class LoRAModelRaw:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
network_alphas: Optional[dict[str, float]],
|
||||
):
|
||||
self._name = name
|
||||
self.state_dict = state_dict
|
||||
self.network_alphas = network_alphas
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
for key, layer in self.state_dict.items():
|
||||
self.state_dict[key] = layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
"""Calculate the size of the model in bytes."""
|
||||
model_size = 0
|
||||
for layer in self.state_dict.values():
|
||||
model_size += layer.numel() * layer.element_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
|
||||
) -> Self:
|
||||
"""This function is based on diffusers LoraLoaderMixin.load_lora_weights()."""
|
||||
|
||||
file_path = Path(file_path)
|
||||
if file_path.is_dir():
|
||||
raise NotImplementedError("LoRA models from directories are not yet supported.")
|
||||
|
||||
dir_path = file_path.parent
|
||||
file_name = file_path.name
|
||||
|
||||
state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=str(file_path), local_files_only=True, weight_name=str(file_name)
|
||||
)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
model = cls(
|
||||
# TODO(ryand): Handle both files and directories here?
|
||||
name=Path(file_path).stem,
|
||||
state_dict=state_dict,
|
||||
network_alphas=network_alphas,
|
||||
)
|
||||
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
model.to(device=device, dtype=dtype)
|
||||
return model
|
||||
@@ -33,42 +33,3 @@ __all__ = [
|
||||
"SchedulerPredictionType",
|
||||
"SubModelType",
|
||||
]
|
||||
|
||||
########## to help populate the openapi_schema with format enums for each config ###########
|
||||
# This code is no longer necessary?
|
||||
# leave it here just in case
|
||||
#
|
||||
# import inspect
|
||||
# from enum import Enum
|
||||
# from typing import Any, Iterable, Dict, get_args, Set
|
||||
# def _expand(something: Any) -> Iterable[type]:
|
||||
# if isinstance(something, type):
|
||||
# yield something
|
||||
# else:
|
||||
# for x in get_args(something):
|
||||
# for y in _expand(x):
|
||||
# yield y
|
||||
|
||||
# def _find_format(cls: type) -> Iterable[Enum]:
|
||||
# if hasattr(inspect, "get_annotations"):
|
||||
# fields = inspect.get_annotations(cls)
|
||||
# else:
|
||||
# fields = cls.__annotations__
|
||||
# if "format" in fields:
|
||||
# for x in get_args(fields["format"]):
|
||||
# yield x
|
||||
# for parent_class in cls.__bases__:
|
||||
# for x in _find_format(parent_class):
|
||||
# yield x
|
||||
# return None
|
||||
|
||||
# def get_model_config_formats() -> Dict[str, Set[Enum]]:
|
||||
# result: Dict[str, Set[Enum]] = {}
|
||||
# for model_config in _expand(AnyModelConfig):
|
||||
# for field in _find_format(model_config):
|
||||
# if field is None:
|
||||
# continue
|
||||
# if not result.get(model_config.__qualname__):
|
||||
# result[model_config.__qualname__] = set()
|
||||
# result[model_config.__qualname__].add(field)
|
||||
# return result
|
||||
|
||||
@@ -31,13 +31,12 @@ from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel]
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
"""Conversion script for the Stable Diffusion checkpoints."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
@@ -15,11 +15,14 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
)
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from . import AnyModel
|
||||
|
||||
|
||||
def convert_ldm_vae_to_diffusers(
|
||||
checkpoint: Dict[str, torch.Tensor],
|
||||
checkpoint: torch.Tensor | dict[str, torch.Tensor],
|
||||
vae_config: DictConfig,
|
||||
image_size: int,
|
||||
dump_path: Optional[Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
) -> AutoencoderKL:
|
||||
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
|
||||
@@ -28,16 +31,21 @@ def convert_ldm_vae_to_diffusers(
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae.to(precision)
|
||||
vae.to(precision)
|
||||
|
||||
if dump_path:
|
||||
vae.save_pretrained(dump_path, safe_serialization=True)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
def convert_ckpt_to_diffusers(
|
||||
checkpoint_path: str | Path,
|
||||
dump_path: str | Path,
|
||||
dump_path: Optional[str | Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
use_safetensors: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
@@ -47,18 +55,20 @@ def convert_ckpt_to_diffusers(
|
||||
pipe = pipe.to(precision)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=use_safetensors,
|
||||
)
|
||||
if dump_path:
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=use_safetensors,
|
||||
)
|
||||
return pipe
|
||||
|
||||
|
||||
def convert_controlnet_to_diffusers(
|
||||
checkpoint_path: Path,
|
||||
dump_path: Path,
|
||||
dump_path: Optional[Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Takes all the arguments of download_controlnet_from_original_ckpt(),
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
@@ -68,4 +78,6 @@ def convert_controlnet_to_diffusers(
|
||||
pipe = pipe.to(precision)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
if dump_path:
|
||||
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
return pipe
|
||||
|
||||
@@ -19,11 +19,20 @@ class ModelConvertCache(ModelConvertCacheBase):
|
||||
self._cache_path = cache_path
|
||||
self._max_size = max_size
|
||||
|
||||
# adjust cache size at startup in case it has been changed
|
||||
if self._cache_path.exists():
|
||||
self.make_room(0.0)
|
||||
|
||||
@property
|
||||
def max_size(self) -> float:
|
||||
"""Return the maximum size of this cache directory (GB)."""
|
||||
return self._max_size
|
||||
|
||||
@max_size.setter
|
||||
def max_size(self, value: float) -> None:
|
||||
"""Set the maximum size of this cache directory (GB)."""
|
||||
self._max_size = value
|
||||
|
||||
def cache_path(self, key: str) -> Path:
|
||||
"""Return the path for a model with the indicated key."""
|
||||
return self._cache_path / key
|
||||
|
||||
@@ -83,3 +83,15 @@ class ModelLoaderBase(ABC):
|
||||
) -> int:
|
||||
"""Return size in bytes of the model, calculated before loading."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the convert cache associated with this loader."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
pass
|
||||
|
||||
@@ -3,14 +3,13 @@
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
@@ -54,51 +53,43 @@ class ModelLoader(ModelLoaderBase):
|
||||
if model_config.type is ModelType.Main and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||
model_path = self._get_model_path(model_config)
|
||||
|
||||
if not model_path.exists():
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
|
||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||
with skip_torch_weight_init():
|
||||
locker = self._convert_and_load(model_config, model_path, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the convert cache associated with this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
model_base = self._app_config.models_path
|
||||
result = (model_base / config.path).resolve(), config, submodel_type
|
||||
return result
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _convert_if_needed(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> Path:
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
|
||||
if not self._needs_conversion(config, model_path, cache_path):
|
||||
return cache_path if cache_path.exists() else model_path
|
||||
|
||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
return self._convert_model(config, model_path, cache_path)
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
def _load_if_needed(
|
||||
def _convert_and_load(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> ModelLockerBase:
|
||||
# TO DO: This is not thread safe!
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
model_variant = getattr(config, "repo_variant", None)
|
||||
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
|
||||
# This is where the model is actually loaded!
|
||||
with skip_torch_weight_init():
|
||||
loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
if self._needs_conversion(config, model_path, cache_path):
|
||||
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
|
||||
else:
|
||||
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
config.key,
|
||||
@@ -123,15 +114,34 @@ class ModelLoader(ModelLoaderBase):
|
||||
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
||||
)
|
||||
|
||||
def _do_convert(
|
||||
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> AnyModel:
|
||||
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
|
||||
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
|
||||
if submodel_type:
|
||||
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
|
||||
# the entire pipeline every time a new submodel is needed.
|
||||
for subtype in SubModelType:
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(
|
||||
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
|
||||
)
|
||||
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
||||
# This needs to be implemented in the subclass
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -122,6 +122,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@@ -157,8 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
assert key not in self._cached_models
|
||||
|
||||
if key in self._cached_models:
|
||||
return
|
||||
self.make_room(size)
|
||||
cache_record = CacheRecord(key, model, size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
@@ -405,6 +411,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
"""Class for ControlNet model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
@@ -33,7 +35,7 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
image_size = (
|
||||
512
|
||||
@@ -44,8 +46,8 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
)
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
with open(self._app_config.root_path / config.config_path, "r") as config_stream:
|
||||
convert_controlnet_to_diffusers(
|
||||
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
|
||||
result = convert_controlnet_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
original_config_file=config_stream,
|
||||
@@ -53,4 +55,4 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
precision=self._torch_dtype,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
)
|
||||
return output_path
|
||||
return result
|
||||
|
||||
@@ -10,13 +10,14 @@ from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@@ -28,14 +29,15 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
model_path = Path(config.path)
|
||||
model_class = self.get_hf_load_class(model_path)
|
||||
if submodel_type is not None:
|
||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
try:
|
||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
|
||||
except OSError as e:
|
||||
|
||||
@@ -9,13 +9,14 @@ import torch
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
@@ -24,13 +25,13 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
model_path = Path(config.path)
|
||||
model: RawModel = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
|
||||
@@ -3,16 +3,15 @@
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@@ -41,26 +40,24 @@ class LoRALoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
self._model_base = (
|
||||
config.base
|
||||
) # cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
self._model_base = config.base
|
||||
|
||||
model_base_path = self._app_config.models_path
|
||||
model_path = model_base_path / config.path
|
||||
@@ -72,5 +69,4 @@ class LoRALoader(ModelLoader):
|
||||
model_path = path
|
||||
break
|
||||
|
||||
result = model_path.resolve(), config, submodel_type
|
||||
return result
|
||||
return model_path.resolve()
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@@ -25,18 +25,19 @@ class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading onnx pipelines.")
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = getattr(config, "repo_variant", None)
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
variant=variant,
|
||||
) # type: ignore
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -9,12 +9,16 @@ from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
MainCheckpointConfig,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@@ -41,14 +45,15 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
try:
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
@@ -78,7 +83,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
base = config.base
|
||||
|
||||
@@ -94,11 +99,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
|
||||
convert_ckpt_to_diffusers(
|
||||
loaded_model = convert_ckpt_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
model_type=self.model_base_to_model_type[base],
|
||||
original_config_file=self._app_config.root_path / config.config_path,
|
||||
original_config_file=self._app_config.legacy_conf_path / config.config_path,
|
||||
extract_ema=True,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
precision=self._torch_dtype,
|
||||
@@ -108,4 +113,4 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
load_safety_checker=False,
|
||||
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
|
||||
)
|
||||
return output_path
|
||||
return loaded_model
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
"""Class for TI model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@@ -27,22 +26,19 @@ class TextualInversionLoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a TI model.")
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
file_path=config.path,
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
model_path = self._app_config.models_path / config.path
|
||||
|
||||
if config.format == ModelFormat.EmbeddingFolder:
|
||||
@@ -53,4 +49,4 @@ class TextualInversionLoader(ModelLoader):
|
||||
if not path.exists():
|
||||
raise OSError(f"The embedding file at {path} was not found")
|
||||
|
||||
return path, config, submodel_type
|
||||
return path
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
"""Class for VAE model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
@@ -13,7 +14,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@@ -38,13 +39,13 @@ class VAELoader(GenericDiffusersLoader):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
# TODO(MM2): check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"VAE conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = self._app_config.root_path / config.config_path
|
||||
config_file = self._app_config.legacy_conf_path / config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
@@ -63,6 +64,6 @@ class VAELoader(GenericDiffusersLoader):
|
||||
vae_config=ckpt_config,
|
||||
image_size=512,
|
||||
precision=self._torch_dtype,
|
||||
dump_path=output_path,
|
||||
)
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||
return output_path
|
||||
return vae_model
|
||||
|
||||
@@ -17,7 +17,7 @@ from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
from .lora_model_raw import LoRAModelRaw
|
||||
from .lora import LoRAModelRaw
|
||||
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
|
||||
"""
|
||||
|
||||
@@ -6,16 +6,17 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
|
||||
|
||||
# NOTE FROM LS: This was copied from Stalker's original implementation.
|
||||
# I have not yet gone through and fixed all the type hints
|
||||
class IAIOnnxRuntimeModel(torch.nn.Module):
|
||||
class IAIOnnxRuntimeModel(RawModel):
|
||||
class _tensor_access:
|
||||
def __init__(self, model): # type: ignore
|
||||
self.model = model
|
||||
|
||||
15
invokeai/backend/raw_model.py
Normal file
15
invokeai/backend/raw_model.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Base class for 'Raw' models.
|
||||
|
||||
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
|
||||
and is used for type checking of calls to the model patcher. Its main purpose
|
||||
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
||||
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
||||
from lora.py.
|
||||
|
||||
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||
that adds additional methods and attributes.
|
||||
"""
|
||||
|
||||
|
||||
class RawModel:
|
||||
"""Base class for 'Raw' model wrappers."""
|
||||
@@ -28,6 +28,10 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
|
||||
@contextmanager
|
||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
|
||||
if not seamless_axes:
|
||||
yield
|
||||
return
|
||||
|
||||
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
|
||||
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
|
||||
try:
|
||||
|
||||
@@ -9,8 +9,10 @@ from safetensors.torch import load_file
|
||||
from transformers import CLIPTokenizer
|
||||
from typing_extensions import Self
|
||||
|
||||
from .raw_model import RawModel
|
||||
|
||||
class TextualInversionModelRaw(torch.nn.Module):
|
||||
|
||||
class TextualInversionModelRaw(RawModel):
|
||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
||||
|
||||
|
||||
@@ -31,6 +31,9 @@ class ConfigMapper:
|
||||
YAML_FILENAME = "invokeai.yaml"
|
||||
DATABASE_FILENAME = "invokeai.db"
|
||||
|
||||
DEFAULT_OUTDIR = "outputs"
|
||||
DEFAULT_DB_DIR = "databases"
|
||||
|
||||
database_path = None
|
||||
database_backup_dir = None
|
||||
outputs_path = None
|
||||
@@ -50,12 +53,18 @@ class ConfigMapper:
|
||||
def __load_from_root_config(self, invoke_root):
|
||||
"""Validate a yaml path exists, confirm the user wants to use it and load config."""
|
||||
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
|
||||
if not os.path.exists(yaml_path):
|
||||
print(f"Unable to find invokeai.yaml at {yaml_path}!")
|
||||
return False
|
||||
if os.path.exists(yaml_path):
|
||||
db_dir, outdir = self.__load_paths_from_yaml_file(yaml_path)
|
||||
|
||||
if db_dir is None or outdir is None:
|
||||
print("The invokeai.yaml file was found but is missing the db_dir and/or outdir setting!")
|
||||
return False
|
||||
if db_dir is None:
|
||||
db_dir = self.DEFAULT_DB_DIR
|
||||
print(f"The invokeai.yaml file was found but is missing the db_dir setting! Defaulting to {db_dir}")
|
||||
if outdir is None:
|
||||
outdir = self.DEFAULT_OUTDIR
|
||||
print(f"The invokeai.yaml file was found but is missing the outdir setting! Defaulting to {outdir}")
|
||||
|
||||
if os.path.isabs(db_dir):
|
||||
self.database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
|
||||
|
||||
@@ -94,6 +94,7 @@
|
||||
"reactflow": "^11.10.4",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-remember": "^5.1.0",
|
||||
"rfdc": "^1.3.1",
|
||||
"roarr": "^7.21.1",
|
||||
"serialize-error": "^11.0.3",
|
||||
"socket.io-client": "^4.7.5",
|
||||
|
||||
7
invokeai/frontend/web/pnpm-lock.yaml
generated
7
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -137,6 +137,9 @@ dependencies:
|
||||
redux-remember:
|
||||
specifier: ^5.1.0
|
||||
version: 5.1.0(redux@5.0.1)
|
||||
rfdc:
|
||||
specifier: ^1.3.1
|
||||
version: 1.3.1
|
||||
roarr:
|
||||
specifier: ^7.21.1
|
||||
version: 7.21.1
|
||||
@@ -12128,6 +12131,10 @@ packages:
|
||||
resolution: {integrity: sha512-/x8uIPdTafBqakK0TmPNJzgkLP+3H+yxpUJhCQHsLBg1rYEVNR2D8BRYNWQhVBjyOd7oo1dZRVzIkwMY2oqfYQ==}
|
||||
dev: true
|
||||
|
||||
/rfdc@1.3.1:
|
||||
resolution: {integrity: sha512-r5a3l5HzYlIC68TpmYKlxWjmOP6wiPJ1vWv2HeLhNsRZMrCkxeqxiHlQ21oXmQ4F3SiryXBHhAD7JZqvOJjFmg==}
|
||||
dev: false
|
||||
|
||||
/rimraf@2.6.3:
|
||||
resolution: {integrity: sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA==}
|
||||
hasBin: true
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"reportBugLabel": "Fehler melden",
|
||||
"settingsLabel": "Einstellungen",
|
||||
"img2img": "Bild zu Bild",
|
||||
"nodes": "Knoten Editor",
|
||||
"nodes": "Arbeitsabläufe",
|
||||
"upload": "Hochladen",
|
||||
"load": "Laden",
|
||||
"statusDisconnected": "Getrennt",
|
||||
@@ -74,7 +74,8 @@
|
||||
"updated": "Aktualisiert",
|
||||
"copy": "Kopieren",
|
||||
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
|
||||
"toResolve": "Lösen"
|
||||
"toResolve": "Lösen",
|
||||
"add": "Hinzufügen"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -104,11 +105,16 @@
|
||||
"dropToUpload": "$t(gallery.drop) zum hochladen",
|
||||
"dropOrUpload": "$t(gallery.drop) oder hochladen",
|
||||
"drop": "Ablegen",
|
||||
"problemDeletingImages": "Problem beim Löschen der Bilder"
|
||||
"problemDeletingImages": "Problem beim Löschen der Bilder",
|
||||
"bulkDownloadRequested": "Download vorbereiten",
|
||||
"bulkDownloadRequestedDesc": "Dein Download wird vorbereitet. Dies kann ein paar Momente dauern.",
|
||||
"bulkDownloadRequestFailed": "Problem beim Download vorbereiten",
|
||||
"bulkDownloadFailed": "Download fehlgeschlagen",
|
||||
"alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Tastenkürzel",
|
||||
"appHotkeys": "App-Tastenkombinationen",
|
||||
"appHotkeys": "App",
|
||||
"generalHotkeys": "Allgemein",
|
||||
"galleryHotkeys": "Galerie",
|
||||
"unifiedCanvasHotkeys": "Leinwand",
|
||||
@@ -757,7 +763,9 @@
|
||||
"scheduler": "Planer",
|
||||
"noRecallParameters": "Es wurden keine Parameter zum Abrufen gefunden",
|
||||
"recallParameters": "Parameter wiederherstellen",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"allPrompts": "Alle Prompts",
|
||||
"imageDimensions": "Bilder Auslösungen"
|
||||
},
|
||||
"popovers": {
|
||||
"noiseUseCPU": {
|
||||
@@ -1068,5 +1076,10 @@
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"showDynamicPrompts": "Dynamische Prompts anzeigen"
|
||||
},
|
||||
"prompt": {
|
||||
"noMatchingTriggers": "Keine passenden Auslöser",
|
||||
"addPromptTrigger": "Auslöse Text hinzufügen",
|
||||
"compatibleEmbeddings": "Kompatible Einbettungen"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +73,8 @@
|
||||
"ai": "ia",
|
||||
"file": "File",
|
||||
"toResolve": "Da risolvere",
|
||||
"add": "Aggiungi"
|
||||
"add": "Aggiungi",
|
||||
"loglevel": "Livello di log"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -934,7 +935,9 @@
|
||||
"base": "Base",
|
||||
"lineart": "Linea",
|
||||
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
||||
"mediapipeFace": "Mediapipe Volto"
|
||||
"mediapipeFace": "Mediapipe Volto",
|
||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
|
||||
},
|
||||
"queue": {
|
||||
"queueFront": "Aggiungi all'inizio della coda",
|
||||
@@ -1490,7 +1493,8 @@
|
||||
"title": "Generazione"
|
||||
},
|
||||
"advanced": {
|
||||
"title": "Avanzate"
|
||||
"title": "Avanzate",
|
||||
"options": "Opzioni $t(accordions.advanced.title)"
|
||||
},
|
||||
"image": {
|
||||
"title": "Immagine"
|
||||
|
||||
@@ -75,7 +75,8 @@
|
||||
"copy": "Копировать",
|
||||
"localSystem": "Локальная система",
|
||||
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
|
||||
"add": "Добавить"
|
||||
"add": "Добавить",
|
||||
"loglevel": "Уровень логов"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Размер изображений",
|
||||
@@ -1505,7 +1506,8 @@
|
||||
"title": "Генерация"
|
||||
},
|
||||
"advanced": {
|
||||
"title": "Расширенные"
|
||||
"title": "Расширенные",
|
||||
"options": "Опции $t(accordions.advanced.title)"
|
||||
},
|
||||
"image": {
|
||||
"title": "Изображение"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { UnknownAction } from '@reduxjs/toolkit';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
import type { Graph } from 'services/api/types';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
@@ -33,7 +33,7 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
|
||||
}
|
||||
|
||||
if (socketGeneratorProgress.match(action)) {
|
||||
const sanitized = cloneDeep(action);
|
||||
const sanitized = deepClone(action);
|
||||
if (sanitized.payload.data.progress_image) {
|
||||
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { cloneDeep, merge } from 'lodash-es';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { merge } from 'lodash-es';
|
||||
import { ClickScrollPlugin, OverlayScrollbars } from 'overlayscrollbars';
|
||||
import type { UseOverlayScrollbarsParams } from 'overlayscrollbars-react';
|
||||
|
||||
@@ -22,7 +23,7 @@ export const getOverlayScrollbarsParams = (
|
||||
overflowX: 'hidden' | 'scroll' = 'hidden',
|
||||
overflowY: 'hidden' | 'scroll' = 'scroll'
|
||||
) => {
|
||||
const params = cloneDeep(overlayScrollbarsParams);
|
||||
const params = deepClone(overlayScrollbarsParams);
|
||||
merge(params, { options: { overflow: { y: overflowY, x: overflowX } } });
|
||||
return params;
|
||||
};
|
||||
|
||||
15
invokeai/frontend/web/src/common/util/deepClone.ts
Normal file
15
invokeai/frontend/web/src/common/util/deepClone.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import rfdc from 'rfdc';
|
||||
const _rfdc = rfdc();
|
||||
|
||||
/**
|
||||
* Deep-clones an object using Really Fast Deep Clone.
|
||||
* This is the fastest deep clone library on Chrome, but not the fastest on FF. Still, it's much faster than lodash
|
||||
* and structuredClone, so it's the best all-around choice.
|
||||
*
|
||||
* Simple Benchmark: https://www.measurethat.net/Benchmarks/Show/30358/0/lodash-clonedeep-vs-jsonparsejsonstringify-vs-recursive
|
||||
* Repo: https://github.com/davidmarkclements/rfdc
|
||||
*
|
||||
* @param obj The object to deep-clone
|
||||
* @returns The cloned object
|
||||
*/
|
||||
export const deepClone = <T>(obj: T): T => _rfdc(obj);
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import calculateCoordinates from 'features/canvas/util/calculateCoordinates';
|
||||
import calculateScale from 'features/canvas/util/calculateScale';
|
||||
@@ -13,7 +14,7 @@ import { modelChanged } from 'features/parameters/store/generationSlice';
|
||||
import type { PayloadActionWithOptimalDimension } from 'features/parameters/store/types';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { clamp, cloneDeep } from 'lodash-es';
|
||||
import { clamp } from 'lodash-es';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -36,7 +37,7 @@ import { CANVAS_GRID_SIZE_FINE } from './constants';
|
||||
/**
|
||||
* The maximum history length to keep in the past/future layer states.
|
||||
*/
|
||||
const MAX_HISTORY = 128;
|
||||
const MAX_HISTORY = 100;
|
||||
|
||||
const initialLayerState: CanvasLayerState = {
|
||||
objects: [],
|
||||
@@ -121,7 +122,7 @@ export const canvasSlice = createSlice({
|
||||
state.brushSize = action.payload;
|
||||
},
|
||||
clearMask: (state) => {
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
pushToPrevLayerStates(state);
|
||||
state.layerState.objects = state.layerState.objects.filter((obj) => !isCanvasMaskLine(obj));
|
||||
state.futureLayerStates = [];
|
||||
state.shouldPreserveMaskedArea = false;
|
||||
@@ -163,10 +164,10 @@ export const canvasSlice = createSlice({
|
||||
state.boundingBoxDimensions = newBoundingBoxDimensions;
|
||||
state.boundingBoxCoordinates = newBoundingBoxCoordinates;
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
state.layerState = {
|
||||
...cloneDeep(initialLayerState),
|
||||
...deepClone(initialLayerState),
|
||||
objects: [
|
||||
{
|
||||
kind: 'image',
|
||||
@@ -261,11 +262,7 @@ export const canvasSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
state.layerState.stagingArea.images.push({
|
||||
kind: 'image',
|
||||
@@ -279,13 +276,9 @@ export const canvasSlice = createSlice({
|
||||
state.futureLayerStates = [];
|
||||
},
|
||||
discardStagedImages: (state) => {
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
|
||||
state.layerState.stagingArea = cloneDeep(cloneDeep(initialLayerState)).stagingArea;
|
||||
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
|
||||
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
@@ -294,11 +287,7 @@ export const canvasSlice = createSlice({
|
||||
},
|
||||
discardStagedImage: (state) => {
|
||||
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
if (!images.length) {
|
||||
return;
|
||||
@@ -320,11 +309,7 @@ export const canvasSlice = createSlice({
|
||||
addFillRect: (state) => {
|
||||
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
state.layerState.objects.push({
|
||||
kind: 'fillRect',
|
||||
@@ -339,11 +324,7 @@ export const canvasSlice = createSlice({
|
||||
addEraseRect: (state) => {
|
||||
const { boundingBoxCoordinates, boundingBoxDimensions } = state;
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
state.layerState.objects.push({
|
||||
kind: 'eraseRect',
|
||||
@@ -367,11 +348,7 @@ export const canvasSlice = createSlice({
|
||||
// set & then spread this to only conditionally add the "color" key
|
||||
const newColor = layer === 'base' && tool === 'brush' ? { color: brushColor } : {};
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
const newLine: CanvasMaskLine | CanvasBaseLine = {
|
||||
kind: 'line',
|
||||
@@ -409,11 +386,7 @@ export const canvasSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
state.futureLayerStates.unshift(cloneDeep(state.layerState));
|
||||
|
||||
if (state.futureLayerStates.length > MAX_HISTORY) {
|
||||
state.futureLayerStates.pop();
|
||||
}
|
||||
pushToFutureLayerStates(state);
|
||||
|
||||
state.layerState = targetState;
|
||||
},
|
||||
@@ -424,11 +397,7 @@ export const canvasSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
state.layerState = targetState;
|
||||
},
|
||||
@@ -445,8 +414,8 @@ export const canvasSlice = createSlice({
|
||||
state.shouldShowIntermediates = action.payload;
|
||||
},
|
||||
resetCanvas: (state) => {
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
state.layerState = cloneDeep(initialLayerState);
|
||||
pushToPrevLayerStates(state);
|
||||
state.layerState = deepClone(initialLayerState);
|
||||
state.futureLayerStates = [];
|
||||
state.batchIds = [];
|
||||
state.boundingBoxCoordinates = {
|
||||
@@ -540,11 +509,7 @@ export const canvasSlice = createSlice({
|
||||
|
||||
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
const imageToCommit = images[selectedImageIndex];
|
||||
|
||||
@@ -553,7 +518,7 @@ export const canvasSlice = createSlice({
|
||||
...imageToCommit,
|
||||
});
|
||||
}
|
||||
state.layerState.stagingArea = cloneDeep(initialLayerState).stagingArea;
|
||||
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
|
||||
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
@@ -623,7 +588,7 @@ export const canvasSlice = createSlice({
|
||||
};
|
||||
},
|
||||
setMergedCanvas: (state, action: PayloadAction<CanvasImage>) => {
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
state.futureLayerStates = [];
|
||||
|
||||
@@ -743,3 +708,17 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
|
||||
migrate: migrateCanvasState,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
const pushToPrevLayerStates = (state: CanvasState) => {
|
||||
state.pastLayerStates.push(deepClone(state.layerState));
|
||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
||||
state.pastLayerStates = state.pastLayerStates.slice(-MAX_HISTORY);
|
||||
}
|
||||
};
|
||||
|
||||
const pushToFutureLayerStates = (state: CanvasState) => {
|
||||
state.futureLayerStates.unshift(deepClone(state.layerState));
|
||||
if (state.futureLayerStates.length > MAX_HISTORY) {
|
||||
state.futureLayerStates = state.futureLayerStates.slice(0, MAX_HISTORY);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -2,10 +2,11 @@ import type { PayloadAction, Update } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { cloneDeep, merge, uniq } from 'lodash-es';
|
||||
import { merge, uniq } from 'lodash-es';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
@@ -114,7 +115,7 @@ export const controlAdaptersSlice = createSlice({
|
||||
if (!controlAdapter) {
|
||||
return;
|
||||
}
|
||||
const newControlAdapter = merge(cloneDeep(controlAdapter), {
|
||||
const newControlAdapter = merge(deepClone(controlAdapter), {
|
||||
id: newId,
|
||||
isEnabled: true,
|
||||
});
|
||||
@@ -270,7 +271,7 @@ export const controlAdaptersSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
const processorNode = merge(cloneDeep(cn.processorNode), params);
|
||||
const processorNode = merge(deepClone(cn.processorNode), params);
|
||||
|
||||
caAdapter.updateOne(state, {
|
||||
id,
|
||||
@@ -293,7 +294,7 @@ export const controlAdaptersSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
const processorNode = cloneDeep(
|
||||
const processorNode = deepClone(
|
||||
CONTROLNET_PROCESSORS[processorType].buildDefaults(cn.model?.base)
|
||||
) as RequiredControlAdapterProcessorNode;
|
||||
|
||||
@@ -333,7 +334,7 @@ export const controlAdaptersSlice = createSlice({
|
||||
caAdapter.updateOne(state, update);
|
||||
},
|
||||
controlAdaptersReset: () => {
|
||||
return cloneDeep(initialControlAdaptersState);
|
||||
return deepClone(initialControlAdaptersState);
|
||||
},
|
||||
pendingControlImagesCleared: (state) => {
|
||||
state.pendingControlImages = [];
|
||||
@@ -406,7 +407,7 @@ const migrateControlAdaptersState = (state: any): any => {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
state = cloneDeep(initialControlAdaptersState);
|
||||
state = deepClone(initialControlAdaptersState);
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type {
|
||||
ControlAdapterConfig,
|
||||
@@ -7,7 +8,7 @@ import type {
|
||||
RequiredCannyImageProcessorInvocation,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlAdapters/store/types';
|
||||
import { cloneDeep, merge } from 'lodash-es';
|
||||
import { merge } from 'lodash-es';
|
||||
|
||||
export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
|
||||
type: 'controlnet',
|
||||
@@ -57,11 +58,11 @@ export const buildControlAdapter = (
|
||||
): ControlAdapterConfig => {
|
||||
switch (type) {
|
||||
case 'controlnet':
|
||||
return merge(cloneDeep(initialControlNet), { id, ...overrides });
|
||||
return merge(deepClone(initialControlNet), { id, ...overrides });
|
||||
case 't2i_adapter':
|
||||
return merge(cloneDeep(initialT2IAdapter), { id, ...overrides });
|
||||
return merge(deepClone(initialT2IAdapter), { id, ...overrides });
|
||||
case 'ip_adapter':
|
||||
return merge(cloneDeep(initialIPAdapter), { id, ...overrides });
|
||||
return merge(deepClone(initialIPAdapter), { id, ...overrides });
|
||||
default:
|
||||
throw new Error(`Unknown control adapter type: ${type}`);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
export type LoRA = {
|
||||
@@ -58,7 +58,7 @@ export const loraSlice = createSlice({
|
||||
}
|
||||
lora.isEnabled = isEnabled;
|
||||
},
|
||||
lorasReset: () => cloneDeep(initialLoraState),
|
||||
lorasReset: () => deepClone(initialLoraState),
|
||||
},
|
||||
});
|
||||
|
||||
@@ -74,7 +74,7 @@ const migrateLoRAState = (state: any): any => {
|
||||
}
|
||||
if (state._version === 1) {
|
||||
// Model type has changed, so we need to reset the state - too risky to migrate
|
||||
state = cloneDeep(initialLoraState);
|
||||
state = deepClone(initialLoraState);
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
@@ -44,7 +45,7 @@ import {
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import { forEach } from 'lodash-es';
|
||||
import type {
|
||||
Connection,
|
||||
Edge,
|
||||
@@ -571,8 +572,23 @@ export const nodesSlice = createSlice({
|
||||
);
|
||||
},
|
||||
selectionCopied: (state) => {
|
||||
state.nodesToCopy = state.nodes.filter((n) => n.selected).map(cloneDeep);
|
||||
state.edgesToCopy = state.edges.filter((e) => e.selected).map(cloneDeep);
|
||||
const nodesToCopy: AnyNode[] = [];
|
||||
const edgesToCopy: Edge[] = [];
|
||||
|
||||
for (const node of state.nodes) {
|
||||
if (node.selected) {
|
||||
nodesToCopy.push(deepClone(node));
|
||||
}
|
||||
}
|
||||
|
||||
for (const edge of state.edges) {
|
||||
if (edge.selected) {
|
||||
edgesToCopy.push(deepClone(edge));
|
||||
}
|
||||
}
|
||||
|
||||
state.nodesToCopy = nodesToCopy;
|
||||
state.edgesToCopy = edgesToCopy;
|
||||
|
||||
if (state.nodesToCopy.length > 0) {
|
||||
const averagePosition = { x: 0, y: 0 };
|
||||
@@ -594,11 +610,21 @@ export const nodesSlice = createSlice({
|
||||
},
|
||||
selectionPasted: (state, action: PayloadAction<{ cursorPosition?: XYPosition }>) => {
|
||||
const { cursorPosition } = action.payload;
|
||||
const newNodes = state.nodesToCopy.map(cloneDeep);
|
||||
const newNodes: AnyNode[] = [];
|
||||
|
||||
for (const node of state.nodesToCopy) {
|
||||
newNodes.push(deepClone(node));
|
||||
}
|
||||
|
||||
const oldNodeIds = newNodes.map((n) => n.data.id);
|
||||
const newEdges = state.edgesToCopy
|
||||
.filter((e) => oldNodeIds.includes(e.source) && oldNodeIds.includes(e.target))
|
||||
.map(cloneDeep);
|
||||
|
||||
const newEdges: Edge[] = [];
|
||||
|
||||
for (const edge of state.edgesToCopy) {
|
||||
if (oldNodeIds.includes(edge.source) && oldNodeIds.includes(edge.target)) {
|
||||
newEdges.push(deepClone(edge));
|
||||
}
|
||||
}
|
||||
|
||||
newEdges.forEach((e) => (e.selected = true));
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
@@ -11,7 +12,7 @@ import type {
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es';
|
||||
import { isEqual, omit, uniqBy } from 'lodash-es';
|
||||
|
||||
const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
|
||||
name: '',
|
||||
@@ -131,8 +132,8 @@ export const workflowSlice = createSlice({
|
||||
});
|
||||
|
||||
return {
|
||||
...cloneDeep(initialWorkflowState),
|
||||
...cloneDeep(workflowExtra),
|
||||
...deepClone(initialWorkflowState),
|
||||
...deepClone(workflowExtra),
|
||||
originalExposedFieldValues,
|
||||
mode: state.mode,
|
||||
};
|
||||
@@ -144,7 +145,7 @@ export const workflowSlice = createSlice({
|
||||
});
|
||||
});
|
||||
|
||||
builder.addCase(nodeEditorReset, () => cloneDeep(initialWorkflowState));
|
||||
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
|
||||
|
||||
builder.addCase(nodesChanged, (state, action) => {
|
||||
// Not all changes to nodes should result in the workflow being marked touched
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { satisfies } from 'compare-versions';
|
||||
import { NodeUpdateError } from 'features/nodes/types/error';
|
||||
import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { zParsedSemver } from 'features/nodes/types/semver';
|
||||
import { cloneDeep, defaultsDeep, keys, pick } from 'lodash-es';
|
||||
import { defaultsDeep, keys, pick } from 'lodash-es';
|
||||
|
||||
import { buildInvocationNode } from './buildInvocationNode';
|
||||
|
||||
@@ -50,7 +51,7 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
|
||||
// The updateability of a node, via semver comparison, relies on the this kind of recursive merge
|
||||
// being valid. We rely on the template's major version to be majorly incremented if this kind of
|
||||
// merge would result in an invalid node.
|
||||
const clone = cloneDeep(node);
|
||||
const clone = deepClone(node);
|
||||
clone.data.version = template.version;
|
||||
defaultsDeep(clone, defaults); // mutates!
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import type { NodesState, WorkflowsState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import i18n from 'i18n';
|
||||
import { cloneDeep, pick } from 'lodash-es';
|
||||
import { pick } from 'lodash-es';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
|
||||
export type BuildWorkflowArg = {
|
||||
@@ -30,7 +31,7 @@ const workflowKeys = [
|
||||
type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3;
|
||||
|
||||
export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => {
|
||||
const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys);
|
||||
const clonedWorkflow = pick(deepClone(workflow), workflowKeys);
|
||||
|
||||
const newWorkflow: WorkflowV3 = {
|
||||
...clonedWorkflow,
|
||||
@@ -43,14 +44,14 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
|
||||
newWorkflow.nodes.push({
|
||||
id: node.id,
|
||||
type: node.type,
|
||||
data: cloneDeep(node.data),
|
||||
data: deepClone(node.data),
|
||||
position: { ...node.position },
|
||||
});
|
||||
} else if (isNotesNode(node) && node.type) {
|
||||
newWorkflow.nodes.push({
|
||||
id: node.id,
|
||||
type: node.type,
|
||||
data: cloneDeep(node.data),
|
||||
data: deepClone(node.data),
|
||||
position: { ...node.position },
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { $store } from 'app/store/nanostores/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
@@ -11,7 +12,7 @@ import { zWorkflowV2 } from 'features/nodes/types/v2/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { t } from 'i18next';
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { z } from 'zod';
|
||||
|
||||
/**
|
||||
@@ -89,7 +90,7 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
|
||||
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
|
||||
}
|
||||
|
||||
let workflow = cloneDeep(data) as WorkflowV1 | WorkflowV2 | WorkflowV3;
|
||||
let workflow = deepClone(data) as WorkflowV1 | WorkflowV2 | WorkflowV3;
|
||||
|
||||
if (workflow.meta.version === '1.0.0') {
|
||||
const v1 = zWorkflowV1.parse(workflow);
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
|
||||
__version__ = "4.0.0rc6"
|
||||
__version__ = "4.0.1"
|
||||
|
||||
@@ -44,7 +44,6 @@ dependencies = [
|
||||
"onnx==1.15.0",
|
||||
"onnxruntime==1.16.3",
|
||||
"opencv-python==4.9.0.80",
|
||||
"peft==0.9.0",
|
||||
"pytorch-lightning==2.1.3",
|
||||
"safetensors==0.4.2",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
@@ -74,7 +73,7 @@ dependencies = [
|
||||
"easing-functions",
|
||||
"einops",
|
||||
"facexlib",
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"npyscreen",
|
||||
"omegaconf",
|
||||
"picklescan",
|
||||
|
||||
@@ -43,8 +43,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
|
||||
assert model_record is not None
|
||||
assert model_record.name == "test_embedding"
|
||||
assert model_record.type == ModelType.TextualInversion
|
||||
assert model_record.path.endswith(embedding_file.as_posix())
|
||||
assert Path(model_record.path).is_absolute()
|
||||
assert Path(model_record.path) == embedding_file
|
||||
assert Path(model_record.path).exists()
|
||||
assert model_record.base == BaseModelType("sd-1")
|
||||
assert model_record.description is not None
|
||||
@@ -77,8 +76,7 @@ def test_install(
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
|
||||
assert Path(model_record.path).is_absolute()
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert model_record.source == embedding_file.as_posix()
|
||||
|
||||
|
||||
@@ -147,10 +145,7 @@ def test_background_install(
|
||||
model_record = mm2_installer.record_store.get_model(key)
|
||||
assert model_record is not None
|
||||
assert model_record.path.endswith(destination)
|
||||
assert Path(model_record.path).is_absolute()
|
||||
assert Path(model_record.path).exists()
|
||||
assert model_record.key != "<NOKEY>"
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
# see if metadata was properly passed through
|
||||
assert model_record.description == description
|
||||
@@ -172,7 +167,7 @@ def test_not_inplace_install(
|
||||
assert job is not None
|
||||
assert job.config_out is not None
|
||||
assert Path(job.config_out.path) != embedding_file
|
||||
assert Path(job.config_out.path).exists()
|
||||
assert (mm2_app_config.models_path / job.config_out.path).exists()
|
||||
|
||||
|
||||
def test_inplace_install(
|
||||
@@ -184,16 +179,21 @@ def test_inplace_install(
|
||||
assert job is not None
|
||||
assert job.config_out is not None
|
||||
assert Path(job.config_out.path) == embedding_file
|
||||
assert Path(job.config_out.path).exists()
|
||||
|
||||
|
||||
def test_delete_install(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
def test_delete_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert embedding_file.exists() # original should still be there after installation
|
||||
mm2_installer.delete(key)
|
||||
assert not Path(model_record.path).exists() # after deletion, installed copy should not exist
|
||||
assert not (
|
||||
mm2_app_config.models_path / model_record.path
|
||||
).exists() # after deletion, installed copy should not exist
|
||||
assert embedding_file.exists() # but original should still be there
|
||||
with pytest.raises(UnknownModelException):
|
||||
store.get_model(key)
|
||||
@@ -232,7 +232,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 4
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
@@ -261,7 +261,7 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora_model_raw import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
|
||||
@@ -98,6 +98,32 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
|
||||
assert not hasattr(config, "esrgan")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"legacy_conf_dir,expected_value,expected_is_set",
|
||||
[
|
||||
# not set, expected value is the default value
|
||||
("configs/stable-diffusion", Path("configs"), False),
|
||||
# not set, expected value is the default value
|
||||
("configs\\stable-diffusion", Path("configs"), False),
|
||||
# set, best-effort resolution of the path
|
||||
("partial_custom_path/stable-diffusion", Path("partial_custom_path"), True),
|
||||
# set, exact path
|
||||
("full/custom/path", Path("full/custom/path"), True),
|
||||
],
|
||||
)
|
||||
def test_migrate_v3_legacy_conf_dir_defaults(
|
||||
tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
|
||||
):
|
||||
"""Test reading configuration from a file."""
|
||||
config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}"
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(config_content)
|
||||
|
||||
config = load_and_migrate_config(temp_config_file)
|
||||
assert config.legacy_conf_dir == expected_value
|
||||
assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set
|
||||
|
||||
|
||||
def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
|
||||
"""Test the backup of the config file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
|
||||
@@ -250,6 +250,32 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
|
||||
db.conn.close()
|
||||
|
||||
|
||||
def test_migrator_backs_up_db(logger: Logger) -> None:
|
||||
with TemporaryDirectory() as tempdir:
|
||||
original_db_path = Path(tempdir) / "invokeai.db"
|
||||
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
|
||||
# Write some data to the db to test for successful backup
|
||||
temp_cursor = db.conn.cursor()
|
||||
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.conn.commit()
|
||||
# Set up the migrator
|
||||
migrator = SqliteMigrator(db=db)
|
||||
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
|
||||
for migration in migrations:
|
||||
migrator.register_migration(migration)
|
||||
migrator.run_migrations()
|
||||
# Must manually close else we get an error on Windows
|
||||
db.conn.close()
|
||||
assert original_db_path.exists()
|
||||
# We should have a backup file when we migrated a file db
|
||||
assert migrator._backup_path
|
||||
# Check that the test table exists as a proxy for successful backup
|
||||
with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn:
|
||||
backup_db_cursor = backup_db_conn.cursor()
|
||||
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
|
||||
assert backup_db_cursor.fetchone() is not None
|
||||
|
||||
|
||||
def test_migrator_makes_no_changes_on_failed_migration(
|
||||
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user