Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-17 08:07:59 -05:00)

Compare commits: lstein/fea ... ryan/fix-c

165 Commits
SHA1:
0781fdf3b0 8d7ca9c1b7 a61e0bd2dd 798e73969c 44f62944ee e9936c27fb
3752509066 a1b7dbfa54 79640ba14e 4075a81676 4d39976909 d14894b3ae
6f5c5b0757 93caa23ef8 977a77f4e6 57c0fcb93d 8b55900035 b1cc413bbd
face94ce33 f0b1f0e5b6 390dc47db5 20d5c3a8bf 134d831ebf b65ed8e8f2
93951dcf82 da05034e20 d579aefb3e 5d1f6db414 f9961eceb7 10076fb1e8
d6e85e5f67 1ce459198c 17d337169d 1468f4d37e 2b744480d6 abb8d34b56
9e664d7c58 c96ccae70b f268fe126e 6109a06f04 5df2a79549 10b9088312
41f46b846b 6dfc406c52 0d4b80780b 15b9ece411 89fcab34d0 132289de55
9f93e9d120 b5f23292d4 a63dbb2c2d 740bf80f3e dc90de600d 5709f82e5f
20042d99ec 8fce168dc5 a7ea096b28 29eb3c8b62 071e8bcee4 68c0aa898f
5120a76ce5 38a948ac9f c33111468e 3e0fb45dd7 aba16085a5 14775cc9c4
c7562dd6c0 a0a0c57789 32ebf82d1a 2dd172c2c6 280ec9d4b3 fde8fc7575
6dcdc87eb1 93ffcb642e 4c914ef2e8 c0ad5bc4a4 8c58a180de 715dd983b0
84ffd36071 9f30f1bfec bdff5c4e87 afb0651f91 66e25628c3 3a531a3c88
f01df49128 7bbe236107 719c066ac4 689dc30f87 1f22f6ae02 9c931d9ca0
e0a241fa4f 6a4b4ee340 488bf21925 c9c39c02b6 5101dc4bef 98c77a3ed1
4fca62680d f76282a5ff 9a3b8c6fcb bd74b84cc5 dc23bebebf 38b6f90c02
cd9dfefe3c b9946e50f9 06f49a30f6 e1af78c702 c5588e1ff7 07ac292680
7c032ea604 c5ee415607 fa40061eca 7cafd78d6e 8a43656cf9 bd3b6ca11b
ceae5fe1db 25067e4f0d fb0aaa3e6d c22526b9d0 c881882f73 36473fc52a
b9964ecc4a 051af802fe 3ff2e558d9 fc187c9253 605f460c7d 60d1e686d8
22704dd542 875673c9ba f604575862 80a67572f1 60ac937698 1e41949a02
5f0e330ed2 9dd779b414 fa183025ac d3c85aa91a 82619602a5 196f3b721d
244c28859d 40ae174c41 afaebdf151 d661517d94 82a69a54ac ffc28176fe
230e205541 7e94350351 c4e8549c73 350a210835 ed781dbb0c b41ea963e7
da5d105049 5301770525 d08e405017 534640ccde d5ab8cab5c 4767301ad3
21d7ca45e6 020e8eb413 3d49541c09 1ef266845a a37589ca5f 171a505f5e
8004a0d5f5 610a1fd611 43108eec13
README.md (16)
@@ -12,12 +12,24 @@

Invoke is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. Invoke offers an industry leading web-based UI, and serves as the foundation for multiple commercial products.

[Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs]

Invoke is available in two editions:

| **Community Edition** | **Professional Edition** |
|---|---|
| **For users looking for a locally installed, self-hosted and self-managed service** | **For users or teams looking for a cloud-hosted, fully managed service** |
| - Free to use under a commercially-friendly license | - Monthly subscription fee with three different plan levels |
| - Download and install on compatible hardware | - Offers additional benefits, including multi-user support, improved model training, and more |
| - Includes all core studio features: generate, refine, iterate on images, and build workflows | - Hosted in the cloud for easy, secure model access and scalability |
| Quick Start -> [Installation and Updates][installation docs] | More Information -> [www.invoke.com/pricing](https://www.invoke.com/pricing) |

<div align="center">



# Documentation

| **Quick Links** |
|---|
| [Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs] |

</div>

## Quick Start
@@ -73,15 +73,6 @@ model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
diffusers model.

`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
imported by, and can be reexported from,
`invokeai.app.services.model_manager.model_records`:

```
from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
```

The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.
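To make the relative-path rule above concrete, here is a minimal sketch (illustrative only; the exact config import path and the `models_path` accessor are assumptions drawn from calls that appear elsewhere in this diff):

```python
# Sketch: resolve a model record's `path` against the configured models directory.
# The import path and `models_path` attribute are assumptions, not a documented API contract.
from pathlib import Path

from invokeai.app.services.config import InvokeAIAppConfig


def resolve_model_path(record_path: str) -> Path:
    config = InvokeAIAppConfig.get_config()
    p = Path(record_path)
    # Absolute paths are used as-is; relative paths are anchored at models_dir.
    return p if p.is_absolute() else Path(config.models_path) / p
```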
@@ -1328,7 +1319,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist

config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
    max_cache_size=config.ram_cache_size, logger=logger
    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
@@ -118,15 +118,13 @@ async def list_boards(
    all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
    offset: Optional[int] = Query(default=None, description="The page offset"),
    limit: Optional[int] = Query(default=None, description="The number of boards per page"),
    include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
    """Gets a list of boards"""
    if all:
        return ApiDependencies.invoker.services.boards.get_all()
        return ApiDependencies.invoker.services.boards.get_all(include_archived)
    elif offset is not None and limit is not None:
        return ApiDependencies.invoker.services.boards.get_many(
            offset,
            limit,
        )
        return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
    else:
        raise HTTPException(
            status_code=400,
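For illustration, a hypothetical client call against the endpoint above (the host, port, and route prefix are assumptions, not taken from this diff):

```python
# Hypothetical usage of the boards listing endpoint shown above.
import requests

BASE_URL = "http://127.0.0.1:9090/api/v1"  # assumed default local server address

# Paged listing; archived boards are excluded unless include_archived=True.
page = requests.get(f"{BASE_URL}/boards/", params={"offset": 0, "limit": 10}).json()

# Unpaged listing of every board, including archived ones.
all_boards = requests.get(
    f"{BASE_URL}/boards/", params={"all": True, "include_archived": True}
).json()
```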
@@ -9,9 +9,14 @@ from PIL import Image
from pydantic import BaseModel, Field, JsonValue

from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.image_records.image_records_common import (
    ImageCategory,
    ImageRecordChanges,
    ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

from ..dependencies import ApiDependencies

@@ -316,16 +321,14 @@ async def list_image_dtos(
    ),
    offset: int = Query(default=0, description="The page offset"),
    limit: int = Query(default=10, description="The number of images per page"),
    order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
    starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
    search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
    """Gets a list of image DTOs"""

    image_dtos = ApiDependencies.invoker.services.images.get_many(
        offset,
        limit,
        image_origin,
        categories,
        is_intermediate,
        board_id,
        offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
    )

    return image_dtos
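Similarly, a hypothetical call to the image-listing endpoint above, exercising the new sort and search parameters (the URL, prefix, and enum string value are assumptions):

```python
# Hypothetical usage of list_image_dtos with the new starred_first / order_dir /
# search_term query parameters.
import requests

BASE_URL = "http://127.0.0.1:9090/api/v1"  # assumed default local server address

image_page = requests.get(
    f"{BASE_URL}/images/",
    params={
        "offset": 0,
        "limit": 10,
        "starred_first": True,    # starred images sort to the front
        "order_dir": "DESC",      # assumed wire value for SQLiteDirection.Descending
        "search_term": "sunset",  # free-text filter
    },
).json()
```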
@@ -3,9 +3,9 @@

import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Type

from fastapi import Body, Path, Query, Response, UploadFile

@@ -19,7 +19,6 @@ from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
    DuplicateModelException,
    InvalidModelException,
    ModelRecordChanges,
    UnknownModelException,

@@ -30,7 +29,6 @@ from invokeai.backend.model_manager.config import (
    MainCheckpointConfig,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException

@@ -174,18 +172,6 @@ async def get_model_record(
        raise HTTPException(status_code=404, detail=str(e))


# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
#     page: int = Query(default=0, description="The page to get"),
#     per_page: int = Query(default=10, description="The number of models per page"),
#     order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
#     """Gets a page of model summary data."""
#     record_store = ApiDependencies.invoker.services.model_manager.store
#     results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
#     return results


class FoundModel(BaseModel):
    path: str = Field(description="Path to the model")
    is_installed: bool = Field(description="Whether or not the model is already installed")
@@ -746,39 +732,36 @@ async def convert_model(
        logger.error(f"The model with key {key} is not a main checkpoint model.")
        raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

    # loading the model will convert it into a cached diffusers file
    try:
        cc_size = loader.convert_cache.max_size
        if cc_size == 0:  # temporary set the convert cache to a positive number so that cached model is written
            loader._convert_cache.max_size = 1.0
        loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
    finally:
        loader._convert_cache.max_size = cc_size
    with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
        convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
        converted_model = loader.load_model(model_config)
        # write the converted file to the convert path
        raw_model = converted_model.model
        assert hasattr(raw_model, "save_pretrained")
        raw_model.save_pretrained(convert_path)
        assert convert_path.exists()

        # Get the path of the converted model from the loader
        cache_path = loader.convert_cache.cache_path(key)
        assert cache_path.exists()
        # temporarily rename the original safetensors file so that there is no naming conflict
        original_name = model_config.name
        model_config.name = f"{original_name}.DELETE"
        changes = ModelRecordChanges(name=model_config.name)
        store.update_model(key, changes=changes)

        # temporarily rename the original safetensors file so that there is no naming conflict
        original_name = model_config.name
        model_config.name = f"{original_name}.DELETE"
        changes = ModelRecordChanges(name=model_config.name)
        store.update_model(key, changes=changes)

        # install the diffusers
        try:
            new_key = installer.install_path(
                cache_path,
                config={
                    "name": original_name,
                    "description": model_config.description,
                    "hash": model_config.hash,
                    "source": model_config.source,
                },
            )
        except DuplicateModelException as e:
            logger.error(str(e))
            raise HTTPException(status_code=409, detail=str(e))
        # install the diffusers
        try:
            new_key = installer.install_path(
                convert_path,
                config={
                    "name": original_name,
                    "description": model_config.description,
                    "hash": model_config.hash,
                    "source": model_config.source,
                },
            )
        except Exception as e:
            logger.error(str(e))
            store.update_model(key, changes=ModelRecordChanges(name=original_name))
            raise HTTPException(status_code=409, detail=str(e))

        # Update the model image if the model had one
        try:

@@ -791,8 +774,8 @@ async def convert_model(
        # delete the original safetensors file
        installer.delete(key)

        # delete the cached version
        shutil.rmtree(cache_path)
        # delete the temporary directory
        # shutil.rmtree(cache_path)

        # return the config record for the new diffusers directory
        new_config = store.get_model(new_key)
@@ -103,7 +103,6 @@ class CompelInvocation(BaseInvocation):
            textual_inversion_manager=ti_manager,
            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
            truncate_long_prompts=False,
            device=TorchDevice.choose_torch_device(),
        )

        conjunction = Compel.parse_prompt_string(self.prompt)

@@ -118,7 +117,6 @@ class CompelInvocation(BaseInvocation):
        conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])

        conditioning_name = context.conditioning.save(conditioning_data)

        return ConditioningOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,

@@ -205,7 +203,6 @@ class SDXLPromptInvocationBase:
            truncate_long_prompts=False,  # TODO:
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
            requires_pooled=get_pooled,
            device=TorchDevice.choose_torch_device(),
        )

        conjunction = Compel.parse_prompt_string(prompt)

@@ -316,6 +313,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                )
            ]
        )

        conditioning_name = context.conditioning.save(conditioning_data)

        return ConditioningOutput(
@@ -1,6 +1,5 @@
from typing import Literal

from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice

LATENT_SCALE_FACTOR = 8

@@ -11,9 +10,6 @@ factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
"""

SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""

IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""
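As a quick illustration of the 8:1 image-to-latent ratio documented above (illustrative numbers only):

```python
# A 512x512 image corresponds to a 64x64 latent under the 8:1 scaling.
LATENT_SCALE_FACTOR = 8

image_height, image_width = 512, 512
latent_height = image_height // LATENT_SCALE_FACTOR  # 64
latent_width = image_width // LATENT_SCALE_FACTOR    # 64
```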
@@ -19,8 +19,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@@ -1,5 +1,4 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import copy
|
||||
import inspect
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
@@ -18,7 +17,7 @@ from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
@@ -54,8 +53,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
from invokeai.backend.util.mask import to_standard_float_mask
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
@@ -66,6 +66,9 @@ def get_scheduler(
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
"""Load a scheduler and apply some scheduler-specific overrides."""
|
||||
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
|
||||
# possible.
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
@@ -183,8 +186,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def _get_text_embeddings_and_masks(
|
||||
self,
|
||||
cond_list: list[ConditioningField],
|
||||
context: InvocationContext,
|
||||
device: torch.device,
|
||||
@@ -194,8 +197,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||
for cond in cond_list:
|
||||
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
|
||||
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||
|
||||
mask = cond.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.tensor_name)
|
||||
@@ -203,8 +207,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return text_embeddings, text_embeddings_masks
|
||||
|
||||
@staticmethod
|
||||
def _preprocess_regional_prompt_mask(
|
||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
@@ -226,11 +231,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
resized_mask = tf(mask)
|
||||
assert isinstance(resized_mask, torch.Tensor)
|
||||
return resized_mask
|
||||
|
||||
@staticmethod
|
||||
def _concat_regional_text_embeddings(
|
||||
self,
|
||||
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||
masks: Optional[list[Optional[torch.Tensor]]],
|
||||
latent_height: int,
|
||||
@@ -280,7 +284,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
)
|
||||
processed_masks.append(
|
||||
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
|
||||
mask, latent_height, latent_width, dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||
@@ -302,36 +308,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
return BasicConditioningInfo(embeds=text_embedding), regions
|
||||
|
||||
@staticmethod
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
unet: UNet2DConditionModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
cfg_scale: float | list[float],
|
||||
steps: int,
|
||||
cfg_rescale_multiplier: float,
|
||||
) -> TextConditioningData:
|
||||
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
|
||||
cond_list = self.positive_conditioning
|
||||
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
|
||||
cond_list = positive_conditioning_field
|
||||
if not isinstance(cond_list, list):
|
||||
cond_list = [cond_list]
|
||||
uncond_list = self.negative_conditioning
|
||||
uncond_list = negative_conditioning_field
|
||||
if not isinstance(uncond_list, list):
|
||||
uncond_list = [uncond_list]
|
||||
|
||||
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
cond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
uncond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
|
||||
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
|
||||
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
text_conditionings=cond_text_embeddings,
|
||||
masks=cond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
|
||||
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
text_conditionings=uncond_text_embeddings,
|
||||
masks=uncond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
@@ -339,23 +350,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
if isinstance(self.cfg_scale, list):
|
||||
assert (
|
||||
len(self.cfg_scale) == self.steps
|
||||
), "cfg_scale (list) must have the same length as the number of steps"
|
||||
if isinstance(cfg_scale, list):
|
||||
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
|
||||
|
||||
conditioning_data = TextConditioningData(
|
||||
uncond_text=uncond_text_embedding,
|
||||
cond_text=cond_text_embedding,
|
||||
uncond_regions=uncond_regions,
|
||||
cond_regions=cond_regions,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
guidance_scale=cfg_scale,
|
||||
guidance_rescale_multiplier=cfg_rescale_multiplier,
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
@@ -378,38 +387,38 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
control_input: Optional[Union[ControlField, List[ControlField]]],
|
||||
control_input: ControlField | list[ControlField] | None,
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> Optional[List[ControlNetData]]:
|
||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
|
||||
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
|
||||
if control_input is None:
|
||||
control_list = None
|
||||
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||
control_list = None
|
||||
elif isinstance(control_input, ControlField):
|
||||
) -> list[ControlNetData] | None:
|
||||
# Normalize control_input to a list.
|
||||
control_list: list[ControlField]
|
||||
if isinstance(control_input, ControlField):
|
||||
control_list = [control_input]
|
||||
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
|
||||
elif isinstance(control_input, list):
|
||||
control_list = control_input
|
||||
elif control_input is None:
|
||||
control_list = []
|
||||
else:
|
||||
control_list = None
|
||||
if control_list is None:
|
||||
return None
|
||||
# After above handling, any control that is not None should now be of type list[ControlField].
|
||||
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
if len(control_list) == 0:
|
||||
return None
|
||||
|
||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||
_, _, latent_height, latent_width = latents_shape
|
||||
control_height_resize = latent_height * LATENT_SCALE_FACTOR
|
||||
control_width_resize = latent_width * LATENT_SCALE_FACTOR
|
||||
|
||||
controlnet_data: list[ControlNetData] = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
assert isinstance(control_model, ControlNetModel)
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.images.get_pil(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
@@ -430,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model, # model object
|
||||
model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
@@ -584,15 +593,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
@staticmethod
|
||||
def init_scheduler(
|
||||
self,
|
||||
scheduler: Union[Scheduler, ConfigMixin],
|
||||
device: torch.device,
|
||||
steps: int,
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
seed: int,
|
||||
) -> Tuple[int, List[int], int, Dict[str, Any]]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
scheduler.set_timesteps(steps, device="cpu")
|
||||
@@ -618,7 +627,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
|
||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||
num_inference_steps = len(timesteps) // scheduler.order
|
||||
|
||||
scheduler_step_kwargs: Dict[str, Any] = {}
|
||||
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||
@@ -640,7 +648,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if isinstance(scheduler, TCDScheduler):
|
||||
scheduler_step_kwargs.update({"eta": 1.0})
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||
return timesteps, init_timestep, scheduler_step_kwargs
|
||||
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
@@ -657,31 +665,52 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed = None
|
||||
@staticmethod
|
||||
def prepare_noise_and_latents(
|
||||
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
|
||||
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
|
||||
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
|
||||
function handles preparing these values accordingly.
|
||||
|
||||
Expected workflows:
|
||||
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
|
||||
- Image-to-Image Denoising: `noise` and `latents` are both provided.
|
||||
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
|
||||
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
|
||||
|
||||
NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
|
||||
I haven't considered.
|
||||
"""
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.tensors.load(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
||||
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
||||
if noise_field is not None:
|
||||
noise = context.tensors.load(noise_field.latents_name)
|
||||
|
||||
if latents_field is not None:
|
||||
latents = context.tensors.load(latents_field.latents_name)
|
||||
elif noise is not None:
|
||||
latents = torch.zeros_like(noise)
|
||||
else:
|
||||
raise Exception("'latents' or 'noise' must be provided!")
|
||||
raise ValueError("'latents' or 'noise' must be provided!")
|
||||
|
||||
if seed is None:
|
||||
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
||||
raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
||||
|
||||
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
|
||||
seed = 0
|
||||
if noise_field is not None and noise_field.seed is not None:
|
||||
seed = noise_field.seed
|
||||
elif latents_field is not None and latents_field.seed is not None:
|
||||
seed = latents_field.seed
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
return seed, noise, latents
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
@@ -707,7 +736,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# The image prompts are then passed to prep_ip_adapter_data().
|
||||
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
|
||||
|
||||
# get the unet's config so that we can pass the base to dispatch_progress()
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
@@ -755,7 +784,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
controlnet_data = self.prep_control_data(
|
||||
@@ -777,7 +814,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
@@ -794,8 +831,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed=seed,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
gradient_mask=gradient_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
is_gradient_mask=gradient_mask,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data,
|
||||
|
||||
@@ -160,6 +160,8 @@ class FieldDescriptions:
    fp32 = "Whether or not to use full float32 precision"
    precision = "Precision to use"
    tiled = "Processing using overlapping tiles (reduce memory consumption)"
    vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the "
    "model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage."
    detect_res = "Pixel resolution for detection"
    image_res = "Pixel resolution for output image"
    safe_mode = "Whether or not to use safe mode"
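For context on how the new `vae_tile_size` field is consumed, a condensed sketch of the pattern that appears later in this diff (the `patch_vae_tiling_params` call and the 0.25 overlap factor are taken from that code; the wrapper function itself is illustrative):

```python
# Condensed sketch: map a pixel-space tile size onto VAE tiling parameters.
from contextlib import nullcontext

from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params

LATENT_SCALE_FACTOR = 8


def tiling_context_for(vae, tile_size: int):
    # tile_size == 0 is the special value meaning "use the model's default tiling".
    if tile_size <= 0:
        return nullcontext()
    return patch_vae_tiling_params(
        vae,
        tile_sample_min_size=tile_size,                         # pixel space
        tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,  # latent space
        tile_overlap_factor=0.25,
    )
```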
@@ -1,3 +1,4 @@
|
||||
from contextlib import nullcontext
|
||||
from functools import singledispatchmethod
|
||||
|
||||
import einops
|
||||
@@ -12,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@@ -22,8 +23,9 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -31,7 +33,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
|
||||
title="Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.1.0",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@@ -44,12 +46,17 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
|
||||
# offer a way to directly set None values.
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
def vae_encode(
|
||||
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
|
||||
) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
vae.to(dtype=torch.float32)
|
||||
@@ -81,9 +88,18 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
if tile_size > 0:
|
||||
tiling_context = patch_vae_tiling_params(
|
||||
vae,
|
||||
tile_sample_min_size=tile_size,
|
||||
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
|
||||
tile_overlap_factor=0.25,
|
||||
)
|
||||
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
with torch.inference_mode(), tiling_context:
|
||||
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
@@ -101,7 +117,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
|
||||
)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.attention_processor import (
|
||||
@@ -8,10 +10,9 @@ from diffusers.models.attention_processor import (
|
||||
)
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -24,6 +25,7 @@ from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@@ -32,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.2.2",
|
||||
version="1.3.0",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@@ -46,6 +48,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
|
||||
# offer a way to directly set None values.
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -53,9 +58,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
@@ -87,10 +92,19 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
if self.tile_size > 0:
|
||||
tiling_context = patch_vae_tiling_params(
|
||||
vae,
|
||||
tile_sample_min_size=self.tile_size,
|
||||
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
|
||||
tile_overlap_factor=0.25,
|
||||
)
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
with torch.inference_mode(), tiling_context:
|
||||
# copied from diffusers pipeline
|
||||
latents = latents / vae.config.scaling_factor
|
||||
image = vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -1,5 +1,4 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    InputField,

@@ -7,6 +6,7 @@ from invokeai.app.invocations.fields import (
    UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES


@invocation_output("scheduler_output")
@@ -0,0 +1,282 @@
|
||||
import copy
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
MultiDiffusionPipeline,
|
||||
MultiDiffusionRegionConditioning,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.tiles.tiles import (
|
||||
calc_tiles_min_overlap,
|
||||
)
|
||||
from invokeai.backend.tiles.utils import TBLR
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
|
||||
"""Crop a ControlNetData object to a region."""
|
||||
# Create a shallow copy of the control_data object.
|
||||
control_data_copy = copy.copy(control_data)
|
||||
# The ControlNet reference image is the only attribute that needs to be cropped.
|
||||
control_data_copy.image_tensor = control_data.image_tensor[
|
||||
:,
|
||||
:,
|
||||
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
|
||||
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
|
||||
]
|
||||
return control_data_copy
|
||||
|
||||
|
||||
@invocation(
|
||||
"tiled_multi_diffusion_denoise_latents",
|
||||
title="Tiled Multi-Diffusion Denoise Latents",
|
||||
tags=["upscale", "denoise"],
|
||||
category="latents",
|
||||
classification=Classification.Beta,
|
||||
version="1.0.0",
|
||||
)
|
||||
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
"""Tiled Multi-Diffusion denoising.
|
||||
|
||||
This node handles automatically tiling the input image, and is primarily intended for global refinement of images
|
||||
in tiled upscaling workflows. Future Multi-Diffusion nodes should allow the user to specify custom regions with
|
||||
different parameters for each region to harness the full power of Multi-Diffusion.
|
||||
|
||||
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
|
||||
T2I-Adapter, masking, etc.).
|
||||
"""
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
noise: LatentsField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
)
|
||||
latents: LatentsField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
tile_height: int = InputField(
|
||||
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
|
||||
)
|
||||
tile_width: int = InputField(
|
||||
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
|
||||
)
|
||||
tile_overlap: int = InputField(
|
||||
default=32,
|
||||
multiple_of=LATENT_SCALE_FACTOR,
|
||||
gt=0,
|
||||
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
|
||||
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
|
||||
"amount.",
|
||||
)
|
||||
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
cfg_rescale_multiplier: float = InputField(
|
||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
control: ControlField | list[ControlField] | None = InputField(
|
||||
default=None,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||
"""Validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline(
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
) -> MultiDiffusionPipeline:
|
||||
# TODO(ryand): Get rid of this FakeVae hack.
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self) -> None:
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return MultiDiffusionPipeline(
|
||||
vae=FakeVae(),
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
# Convert tile image-space dimensions to latent-space dimensions.
|
||||
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
|
||||
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
|
||||
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
|
||||
|
||||
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# Calculate the tile locations to cover the latent-space image.
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=latent_height,
|
||||
image_width=latent_width,
|
||||
tile_height=latent_tile_height,
|
||||
tile_width=latent_tile_width,
|
||||
min_overlap=latent_tile_overlap,
|
||||
)
|
||||
|
||||
# Get the unet's config so that we can pass the base to sd_step_callback().
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
|
||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=latent_tile_height,
|
||||
latent_width=latent_tile_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=list(latents.shape),
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# Split the controlnet_data into tiles.
|
||||
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
|
||||
controlnet_data_tiles: list[list[ControlNetData]] = []
|
||||
for tile in tiles:
|
||||
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
|
||||
controlnet_data_tiles.append(tile_controlnet_data)
|
||||
|
||||
# Prepare the MultiDiffusionRegionConditioning list.
|
||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
|
||||
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
|
||||
multi_diffusion_conditioning.append(
|
||||
MultiDiffusionRegionConditioning(
|
||||
region=tile,
|
||||
text_conditioning_data=conditioning_data,
|
||||
control_data=tile_controlnet_data,
|
||||
)
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# Run Multi-Diffusion denoising.
|
||||
result_latents = pipeline.multi_diffusion_denoise(
|
||||
multi_diffusion_conditioning=multi_diffusion_conditioning,
|
||||
target_overlap=latent_tile_overlap,
|
||||
latents=latents,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
result_latents = result_latents.to("cpu")
|
||||
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
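To make the tile geometry used by this node concrete, a small worked example with illustrative numbers (only the 8:1 latent scaling and the field defaults shown above are taken from the diff):

```python
# Illustrative tile math for a 2048x2048 input with the node's default settings.
LATENT_SCALE_FACTOR = 8

tile_height, tile_width, tile_overlap = 1024, 1024, 32      # image space (node defaults)
latent_tile_height = tile_height // LATENT_SCALE_FACTOR     # 128
latent_tile_width = tile_width // LATENT_SCALE_FACTOR       # 128
latent_tile_overlap = tile_overlap // LATENT_SCALE_FACTOR   # 4

image_height, image_width = 2048, 2048
latent_height = image_height // LATENT_SCALE_FACTOR         # 256
latent_width = image_width // LATENT_SCALE_FACTOR           # 256
# calc_tiles_min_overlap() then covers the 256x256 latent grid with 128x128 tiles
# that overlap by at least 4 latent pixels.
```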
|
||||
@@ -40,16 +40,12 @@ class BoardRecordStorageBase(ABC):

@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass

@abstractmethod
def get_all(
self,
) -> list[BoardRecord]:
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
"""Gets all board records."""
pass

@@ -22,6 +22,8 @@ class BoardRecord(BaseModelExcludeNull):
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
"""The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""

def deserialize_board_record(board_dict: dict) -> BoardRecord:
@@ -35,6 +37,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)

return BoardRecord(
board_id=board_id,
@@ -43,12 +46,14 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
archived=archived,
)

class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.")
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")

class BoardRecordNotFoundException(Exception):
@@ -125,6 +125,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(changes.cover_image_name, board_id),
)

# Change the archived status of a board
if changes.archived is not None:
self._cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)

self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@@ -134,35 +145,49 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
return self.get(board_id)

def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()

# Get all the boards
self._cursor.execute(
"""--sql
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
"""

# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"

final_query = base_query.format(archived_filter=archived_filter)

# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))

result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
)
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""

# Execute count query
self._cursor.execute(count_query)

count = cast(int, self._cursor.fetchone()[0])

@@ -174,20 +199,25 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
finally:
self._lock.release()

def get_all(
self,
) -> list[BoardRecord]:
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
try:
self._lock.acquire()

# Get all the boards
self._cursor.execute(
"""--sql
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
)
"""

if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"

final_query = base_query.format(archived_filter=archived_filter)

self._cursor.execute(final_query)

result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
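The queries above interpolate only the fixed, code-controlled `{archived_filter}` literal into the SQL, while `limit` and `offset` remain bound parameters. A minimal, self-contained sketch of the same composition pattern (connection handling simplified; this is not the InvokeAI class itself):

```python
# Only the fixed filter literal is formatted into the query string; user-facing
# values stay as bound parameters, so no user input reaches the SQL text.
import sqlite3


def fetch_boards(conn: sqlite3.Connection, limit: int, offset: int, include_archived: bool) -> list[sqlite3.Row]:
    conn.row_factory = sqlite3.Row
    archived_filter = "" if include_archived else "WHERE archived = 0"
    query = f"""
        SELECT *
        FROM boards
        {archived_filter}
        ORDER BY created_at DESC
        LIMIT ? OFFSET ?;
    """
    return conn.execute(query, (limit, offset)).fetchall()
```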
@@ -44,16 +44,12 @@ class BoardServiceABC(ABC):

@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass

@abstractmethod
def get_all(
self,
) -> list[BoardDTO]:
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
"""Gets all boards."""
pass

@@ -48,8 +48,10 @@ class BoardService(BoardServiceABC):
def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id)

def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -63,8 +65,8 @@ class BoardService(BoardServiceABC):

return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))

def get_all(self) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all()
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
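A hypothetical caller sketch showing how `include_archived` flows through the service layer; `board_service` is assumed to be an already-constructed `BoardService`, and this is not the actual InvokeAI router code.

```python
# Archived boards are hidden unless the caller explicitly opts in.
def list_boards(board_service, include_archived: bool = False, offset: int = 0, limit: int = 10):
    return board_service.get_many(offset=offset, limit=limit, include_archived=include_archived)
```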
@@ -3,6 +3,7 @@

from __future__ import annotations

import copy
import locale
import os
import re
@@ -25,9 +26,8 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
@@ -85,7 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
log_tokenization: Enable logging of parsed prompt tokens.
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
convert_cache_dir: Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).
download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
@@ -102,19 +102,16 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
devices: List of execution devices; will override default device selected.
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
@@ -150,7 +147,7 @@ class InvokeAIAppConfig(BaseSettings):

# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
@@ -172,15 +169,13 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")

# CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")

# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")

# GENERATION
@@ -190,7 +185,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")

# NODES
@@ -361,14 +355,14 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)

def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a v4.0.0.

Args:
config_dict: A dictionary of settings from a v3 config file.

Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
An `InvokeAIAppConfig` config dict.

"""
parsed_config_dict: dict[str, Any] = {}
@@ -380,6 +374,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
@@ -399,55 +396,43 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)

return config
parsed_config_dict["schema_version"] = "4.0.0"
return parsed_config_dict

def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
def migrate_v4_0_0_to_4_0_1_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to a v4.0.1 config dictionary

Args:
config_dict: A dictionary of settings from a v4.0.0 config file.

Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
A config dict with the settings migrated to v4.0.1.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# precision "autocast" was replaced by "auto" in v4.0.1
if parsed_config_dict.get("precision") == "autocast":
parsed_config_dict["precision"] = "auto"
parsed_config_dict["schema_version"] = "4.0.1"
return parsed_config_dict

def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.1 config dictionary to a current config object.

A few new multi-GPU options were added in 4.0.2, and this simply
updates the schema label.
def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.1 config dictionary to a v4.0.2 config dictionary.

Args:
config_dict: A dictionary of settings from a v4.0.1 config file.

Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
An config dict with the settings migrated to v4.0.2.
"""
parsed_config_dict: dict[str, Any] = {}
for k, _ in config_dict.items():
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# convert_cache was removed in 4.0.2
parsed_config_dict.pop("convert_cache", None)
parsed_config_dict["schema_version"] = "4.0.2"
return parsed_config_dict

# TO DO: replace this with a formal registration and migration system
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.

@@ -459,31 +444,31 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
loaded_config_dict: dict[str, Any] = yaml.safe_load(file)

assert isinstance(loaded_config_dict, dict)

migrated = False
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
migrated = True
loaded_config_dict = migrate_v3_config_dict(loaded_config_dict)  # pyright: ignore [reportUnknownArgumentType]
if loaded_config_dict["schema_version"] == "4.0.0":
migrated = True
loaded_config_dict = migrate_v4_0_0_to_4_0_1_config_dict(loaded_config_dict)
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)

if migrated:
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict)  # pyright: ignore [reportUnknownArgumentType]
# load and write without environment variables
migrated_config = DefaultInvokeAIAppConfig.model_validate(loaded_config_dict)
migrated_config.write_file(config_path)
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config

if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)

elif loaded_config_dict["schema_version"] == "4.0.1":
loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)

# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
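With each migration helper now taking and returning a plain dict stamped with the next `schema_version`, `load_and_migrate_config` can chain the steps and validate the result once at the end. The sketch below shows that chaining pattern in isolation, with a simplified registry; the `register`/`migrate` helpers are illustrative and not the InvokeAI API.

```python
# Dict-to-dict migration chaining: each step advances schema_version by one hop,
# and the driver keeps applying steps until no migration matches.
from typing import Any, Callable

Migration = Callable[[dict[str, Any]], dict[str, Any]]

# Maps a starting schema_version to the migration that advances it one step.
MIGRATIONS: dict[str, Migration] = {}


def register(from_version: str) -> Callable[[Migration], Migration]:
    def wrap(fn: Migration) -> Migration:
        MIGRATIONS[from_version] = fn
        return fn

    return wrap


@register("4.0.0")
def to_4_0_1(config: dict[str, Any]) -> dict[str, Any]:
    migrated = dict(config)
    if migrated.get("precision") == "autocast":
        migrated["precision"] = "auto"  # "autocast" was removed in 4.0.1
    migrated["schema_version"] = "4.0.1"
    return migrated


@register("4.0.1")
def to_4_0_2(config: dict[str, Any]) -> dict[str, Any]:
    migrated = dict(config)
    migrated.pop("convert_cache", None)  # convert_cache setting was removed in 4.0.2
    migrated["schema_version"] = "4.0.2"
    return migrated


def migrate(config: dict[str, Any]) -> dict[str, Any]:
    # Apply migrations until the current version has no registered successor.
    while config.get("schema_version") in MIGRATIONS:
        config = MIGRATIONS[config["schema_version"]](config)
    return config
```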
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
|
||||
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
||||
|
||||
@@ -37,10 +38,13 @@ class ImageRecordStorageBase(ABC):
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Union, cast
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .image_records_base import ImageRecordStorageBase
|
||||
@@ -144,10 +145,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@@ -208,9 +212,21 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND images.metadata LIKE ?
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
|
||||
@@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
@@ -116,10 +117,13 @@ class ImageServiceABC(ABC):
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@@ -5,6 +5,7 @@ from PIL.Image import Image as PILImageType
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
|
||||
from ..image_files.image_files_common import (
|
||||
ImageFileDeleteException,
|
||||
@@ -73,7 +74,12 @@ class ImageService(ImageServiceABC):
|
||||
session_id=session_id,
|
||||
)
|
||||
if board_id is not None:
|
||||
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
try:
|
||||
self.__invoker.services.board_image_records.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
|
||||
self.__invoker.services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
|
||||
)
|
||||
@@ -202,19 +208,25 @@ class ImageService(ImageServiceABC):
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self.__invoker.services.image_records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
starred_first,
|
||||
order_dir,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
search_term,
|
||||
)
|
||||
|
||||
image_dtos = [
|
||||
|
||||
@@ -53,11 +53,11 @@ class InvocationServices:
|
||||
model_images: "ModelImageFileStorageBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
names: "NameServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
@@ -77,11 +77,11 @@ class InvocationServices:
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.performance_statistics = performance_statistics
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.invocation_cache = invocation_cache
|
||||
self.names = names
|
||||
self.performance_statistics = performance_statistics
|
||||
self.urls = urls
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
|
||||
@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
self._stats.pop(graph_execution_state_id)
|
||||
self._cache_stats.pop(graph_execution_state_id)
|
||||
def reset_stats(self):
|
||||
self._stats = {}
|
||||
self._cache_stats = {}
|
||||
|
||||
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
|
||||
|
||||
@@ -284,14 +284,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
|
||||
self._install_jobs = unfinished_jobs
|
||||
|
||||
def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None:
|
||||
def _migrate_yaml(self) -> None:
|
||||
db_models = self.record_store.all_models()
|
||||
|
||||
if overwrite_db:
|
||||
for model in db_models:
|
||||
self.record_store.del_model(model.key)
|
||||
db_models = self.record_store.all_models()
|
||||
|
||||
legacy_models_yaml_path = (
|
||||
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
|
||||
)
|
||||
@@ -341,8 +336,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
||||
|
||||
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
||||
if rename_yaml:
|
||||
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||
|
||||
# Unset the path - we are done with it either way
|
||||
self._app_config.legacy_models_yaml_path = None
|
||||
|
||||
@@ -6,8 +6,7 @@ from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
|
||||
@@ -28,16 +27,6 @@ class ModelLoadServiceBase(ABC):
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs we are configured to use."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""Implementation of model loader service."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Type
|
||||
from typing import Callable, Optional
|
||||
|
||||
from picklescan.scanner import scan_file_path
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
@@ -11,14 +11,9 @@ from torch import load as torch_load
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
LoadedModelWithoutConfig,
|
||||
ModelLoaderRegistry,
|
||||
ModelLoaderRegistryBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -33,8 +28,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
||||
registry: ModelLoaderRegistry,
|
||||
):
|
||||
"""Initialize the model load service."""
|
||||
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
@@ -42,11 +36,9 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._logger = logger
|
||||
self._app_config = app_config
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._registry = registry
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Start the service."""
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
@@ -54,16 +46,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@property
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs available for our uses."""
|
||||
return len(self._ram_cache.execution_devices)
|
||||
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
@@ -82,7 +64,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
app_config=self._app_config,
|
||||
logger=self._logger,
|
||||
ram_cache=self._ram_cache,
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if hasattr(self, "_invoker"):
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Initialization file for model manager service."""
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
|
||||
__all__ = [
|
||||
"ModelManagerServiceBase",
|
||||
"ModelManagerService",
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
"LoadedModel",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Set
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
@@ -32,7 +31,7 @@ class ModelManagerServiceBase(ABC):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_devices: Optional[Set[torch.device]] = None,
|
||||
execution_device: torch.device,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
@@ -65,6 +70,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
@@ -77,13 +83,13 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram,
|
||||
max_vram_cache_size=app_config.vram,
|
||||
lazy_offloading=app_config.lazy_offload,
|
||||
logger=logger,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
)
|
||||
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
|
||||
loader = ModelLoadService(
|
||||
app_config=app_config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry,
|
||||
)
|
||||
installer = ModelInstallService(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
@@ -10,7 +9,6 @@ import torch
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
@@ -72,10 +70,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
return self._output_dir / name
|
||||
|
||||
def _new_name(self) -> str:
|
||||
tid = threading.current_thread().ident
|
||||
# Add tid to the object name because uuid4 not thread-safe on windows
|
||||
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
|
||||
return f"{self._obj_class_name}_{tid}-{uuid_string()}"
|
||||
return f"{self._obj_class_name}_{uuid_string()}"
|
||||
|
||||
def _tempdir_cleanup(self) -> None:
|
||||
"""Calls `cleanup` on the temporary directory, if it exists."""
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from queue import Queue
|
||||
from threading import BoundedSemaphore, Lock, Thread
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional, Set
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.events.events_common import (
|
||||
@@ -27,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import SessionQueu
|
||||
from invokeai.app.services.shared.graph import NodeInputError
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
|
||||
@@ -59,11 +57,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
||||
self._on_node_error_callbacks = on_node_error_callbacks or []
|
||||
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
||||
self._process_lock = Lock()
|
||||
|
||||
def start(
|
||||
self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
|
||||
) -> None:
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||
self._services = services
|
||||
self._cancel_event = cancel_event
|
||||
self._profiler = profiler
|
||||
@@ -81,8 +76,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while True:
|
||||
try:
|
||||
with self._process_lock:
|
||||
invocation = queue_item.session.next()
|
||||
invocation = queue_item.session.next()
|
||||
# Anything other than a `NodeInputError` is handled as a processor error
|
||||
except NodeInputError as e:
|
||||
error_type = e.__class__.__name__
|
||||
@@ -114,7 +108,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
try:
|
||||
# Any unhandled exception in this scope is an invocation error & will fail the graph
|
||||
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
@@ -216,7 +210,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self._services.performance_statistics.reset_stats(queue_item.session.id)
|
||||
self._services.performance_statistics.reset_stats()
|
||||
|
||||
for callback in self._on_after_run_session_callbacks:
|
||||
callback(queue_item=queue_item)
|
||||
@@ -330,7 +324,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._active_queue_items: Set[SessionQueueItem] = set()
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
@@ -356,14 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
else None
|
||||
)
|
||||
|
||||
self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
|
||||
TorchDevice.execution_devices()
|
||||
)
|
||||
|
||||
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
|
||||
|
||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
|
||||
# Session processor - singlethreaded
|
||||
self._thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
@@ -376,16 +363,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
# Session processor workers - multithreaded
|
||||
self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
|
||||
for _i in range(0, self._worker_thread_count):
|
||||
worker = Thread(
|
||||
name="session_worker",
|
||||
target=self._process_next_session,
|
||||
daemon=True,
|
||||
)
|
||||
worker.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self._stop_event.set()
|
||||
|
||||
@@ -393,7 +370,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now_event.set()
|
||||
|
||||
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
|
||||
if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
|
||||
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
@@ -401,7 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now()
|
||||
|
||||
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
|
||||
if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
|
||||
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
|
||||
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
|
||||
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
|
||||
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
|
||||
@@ -426,7 +403,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self._resume_event.is_set(),
|
||||
is_processing=len(self._active_queue_items) > 0,
|
||||
is_processing=self._queue_item is not None,
|
||||
)
|
||||
|
||||
def _process(
|
||||
@@ -451,22 +428,30 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
resume_event.wait()
|
||||
|
||||
# Get the next session to process
|
||||
queue_item = self._invoker.services.session_queue.dequeue()
|
||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||
|
||||
if queue_item is None:
|
||||
if self._queue_item is None:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
self._session_worker_queue.put(queue_item)
|
||||
self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# Run the graph
|
||||
# self.session_runner.run(queue_item=self._queue_item)
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
error_type = e.__class__.__name__
|
||||
error_message = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
self._on_non_fatal_processor_error(
|
||||
queue_item=self._queue_item,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
# Wait for next polling interval or event to try again
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
@@ -481,25 +466,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
finally:
|
||||
stop_event.clear()
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
|
||||
def _process_next_session(self) -> None:
|
||||
while True:
|
||||
self._resume_event.wait()
|
||||
queue_item = self._session_worker_queue.get()
|
||||
if queue_item.status == "canceled":
|
||||
continue
|
||||
try:
|
||||
self._active_queue_items.add(queue_item)
|
||||
# reserve a GPU for this session - may block
|
||||
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
|
||||
# Run the session on the reserved GPU
|
||||
self.session_runner.run(queue_item=queue_item)
|
||||
except Exception:
|
||||
continue
|
||||
finally:
|
||||
self._active_queue_items.remove(queue_item)
|
||||
|
||||
def _on_non_fatal_processor_error(
|
||||
self,
|
||||
queue_item: Optional[SessionQueueItem],
|
||||
|
||||
@@ -236,9 +236,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
}
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.item_id
|
||||
|
||||
|
||||
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||
pass
|
||||
|
||||
@@ -2,7 +2,6 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from torch import Tensor
|
||||
@@ -27,13 +26,11 @@ from invokeai.backend.model_manager.config import (
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
"""
|
||||
The InvocationContext provides access to various services and data about the current invocation.
|
||||
@@ -326,6 +323,7 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
The loaded conditioning data.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
|
||||
|
||||
@@ -559,28 +557,6 @@ class UtilInterface(InvocationContextInterface):
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
def torch_device(self) -> torch.device:
|
||||
"""
|
||||
Return a torch device to use in the current invocation.
|
||||
|
||||
Returns:
|
||||
A torch.device not currently in use by the system.
|
||||
"""
|
||||
ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
|
||||
return ram_cache.get_execution_device()
|
||||
|
||||
def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
|
||||
"""
|
||||
Return a precision type to use with the current invocation and torch device.
|
||||
|
||||
Args:
|
||||
device: Optional device.
|
||||
|
||||
Returns:
|
||||
A torch.dtype suited for the current device.
|
||||
"""
|
||||
return TorchDevice.choose_torch_dtype(device)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
@@ -14,6 +14,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -45,6 +47,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_9())
|
||||
migrator.register_migration(build_migration_10())
|
||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_12(app_config=config))
|
||||
migrator.register_migration(build_migration_13())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,35 @@
import shutil
import sqlite3

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

class Migration12Callback:
def __init__(self, app_config: InvokeAIAppConfig) -> None:
self._app_config = app_config

def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_model_convert_cache_dir()

def _remove_model_convert_cache_dir(self) -> None:
"""
Removes unused model convert cache directory
"""
convert_cache = self._app_config.convert_cache_path
shutil.rmtree(convert_cache, ignore_errors=True)

def build_migration_12(app_config: InvokeAIAppConfig) -> Migration:
"""
Build the migration from database version 11 to 12.

This migration removes the now-unused model convert cache directory.
"""
migration_12 = Migration(
from_version=11,
to_version=12,
callback=Migration12Callback(app_config),
)

return migration_12

@@ -0,0 +1,31 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

class Migration13Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_archived_col(cursor)

def _add_archived_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `archived` columns to the board table.
"""

cursor.execute("ALTER TABLE boards ADD COLUMN archived BOOLEAN DEFAULT FALSE;")

def build_migration_13() -> Migration:
"""
Build the migration from database version 12 to 13..

This migration does the following:
- Adds `archived` columns to the board table.
"""
migration_13 = Migration(
from_version=12,
to_version=13,
callback=Migration13Callback(),
)

return migration_13
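A small, self-contained check of the behaviour this migration relies on: rows that exist before the `ALTER TABLE` pick up the column's FALSE default, so the `WHERE archived = 0` filter used by the board queries above still returns them. The example uses an in-memory database and reduces the table to what the check needs.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE boards (board_id TEXT PRIMARY KEY, created_at TEXT DEFAULT CURRENT_TIMESTAMP);")
conn.execute("INSERT INTO boards (board_id) VALUES ('existing-board');")

# Equivalent of Migration13Callback._add_archived_col
conn.execute("ALTER TABLE boards ADD COLUMN archived BOOLEAN DEFAULT FALSE;")

rows = conn.execute("SELECT board_id, archived FROM boards WHERE archived = 0;").fetchall()
print(rows)  # [('existing-board', 0)] -- pre-existing boards are not archived
```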
@@ -289,7 +289,7 @@ def prepare_control_image(
    width: int,
    height: int,
    num_channels: int = 3,
    device: str = "cuda",
    device: str | torch.device = "cuda",
    dtype: torch.dtype = torch.float16,
    control_mode: CONTROLNET_MODE_VALUES = "balanced",
    resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
@@ -304,7 +304,7 @@ def prepare_control_image(
        num_channels (int, optional): The target number of image channels. This is achieved by converting the input
            image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
            RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
        device (str, optional): The target device for the output image. Defaults to "cuda".
        device (str | torch.Device, optional): The target device for the output image. Defaults to "cuda".
        dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
        do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
            Defaults to True.
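With this change callers may pass either a device string or a `torch.device`. A minimal sketch of a call under that assumption; the input-image parameter name and the import path are not shown in this hunk and are guesses:

```python
import torch
from PIL import Image

# Assumed import path for prepare_control_image; adjust to the actual module.
from invokeai.app.util.controlnet_utils import prepare_control_image

image = Image.open("pose.png")  # hypothetical control image
control = prepare_control_image(
    image,                         # assumed positional image argument
    width=512,
    height=512,
    num_channels=3,
    device=torch.device("cuda"),   # a torch.device is now accepted directly
    dtype=torch.float16,
)
```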
@@ -11,6 +11,7 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_manager.load.model_size_utils import calc_module_size

from ..raw_model import RawModel
from .resampler import Resampler
@@ -137,10 +138,7 @@ class IPAdapter(RawModel):
        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)

    def calc_size(self):
        # workaround for circular import
        from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data

        return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
        return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)

    def _init_image_proj_model(
        self, state_dict: dict[str, torch.Tensor]
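`calc_module_size` itself is not shown in this diff. As an assumption about what such a helper computes, a module's in-memory footprint is typically the sum of its parameter and buffer sizes:

```python
import torch

def calc_module_size(module: torch.nn.Module) -> int:
    """Approximate in-memory size of a module in bytes (parameters + buffers).

    Sketch only; the actual calc_module_size in model_size_utils is not shown here.
    """
    param_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in module.buffers())
    return param_bytes + buffer_bytes
```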
@@ -10,6 +10,7 @@ from safetensors.torch import load_file
from typing_extensions import Self

from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.util.devices import TorchDevice

from .raw_model import RawModel

@@ -521,7 +522,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
            # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()

            layer.to(device=device, dtype=dtype, non_blocking=True)
            layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
            model.layers[layer_key] = layer

        return model
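`TorchDevice.get_non_blocking` is referenced but not defined in this excerpt. The usual rationale for gating `non_blocking` on the target device is that asynchronous copies are only well-behaved for CUDA destinations; a plausible sketch of such a helper (an assumption, not the actual InvokeAI code):

```python
import torch

def get_non_blocking(device: torch.device | str) -> bool:
    """Return True only when a non-blocking transfer to `device` is safe.

    Non-blocking host-to-device copies are generally only meaningful for CUDA;
    for CPU and MPS targets a blocking copy avoids synchronization surprises.
    """
    return torch.device(device).type == "cuda"
```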
@@ -12,7 +12,9 @@ def validate_hash(hash: str):
        map = json.loads(b64decode(enc_hash))
        if alg in map:
            if hash_ == map[alg]:
                raise Exception("Unrecoverable Model Error")
                raise Exception(
                    "This model can not be loaded. If you're looking for help, consider visiting https://www.redirectionprogram.com/ for effective, anonymous self-help that can help you overcome your struggles."
                )


hashes: list[str] = [
@@ -13,7 +13,6 @@ from .config import (
    SchedulerPredictionType,
    SubModelType,
)
from .load import LoadedModel
from .probe import ModelProbe
from .search import ModelSearch

@@ -23,7 +22,6 @@ __all__ = [
    "BaseModelType",
    "ModelRepoVariant",
    "InvalidModelConfigException",
    "LoadedModel",
    "ModelConfigFactory",
    "ModelFormat",
    "ModelProbe",
@@ -24,21 +24,21 @@ import time
from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union

import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES

from ..raw_model import RawModel

# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]


class InvalidModelConfigException(Exception):
@@ -178,7 +178,6 @@ class ModelConfigBase(BaseModel):

    @staticmethod
    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
        """Extend the pydantic schema from a json."""
        schema["required"].extend(["key", "type", "format"])

    model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
@@ -445,7 +444,7 @@ class ModelConfigFactory(object):
            model = dest_class.model_validate(model_data)
        else:
            # mypy doesn't typecheck TypeAdapters well?
            model = AnyModelConfigValidator.validate_python(model_data)
            model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore
        assert model is not None
        if key:
            model.key = key
@@ -1,83 +0,0 @@
|
||||
# Adapted for use in InvokeAI by Lincoln Stein, July 2023
|
||||
#
|
||||
"""Conversion script for the Stable Diffusion checkpoints."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
download_controlnet_from_original_ckpt,
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from . import AnyModel
|
||||
|
||||
|
||||
def convert_ldm_vae_to_diffusers(
|
||||
checkpoint: torch.Tensor | dict[str, torch.Tensor],
|
||||
vae_config: DictConfig,
|
||||
image_size: int,
|
||||
dump_path: Optional[Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
) -> AutoencoderKL:
|
||||
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
|
||||
vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
vae.to(precision)
|
||||
|
||||
if dump_path:
|
||||
vae.save_pretrained(dump_path, safe_serialization=True)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
def convert_ckpt_to_diffusers(
|
||||
checkpoint_path: str | Path,
|
||||
dump_path: Optional[str | Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
use_safetensors: bool = True,
|
||||
**kwargs,
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
model to be written.
|
||||
"""
|
||||
pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
|
||||
pipe = pipe.to(precision)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
if dump_path:
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=use_safetensors,
|
||||
)
|
||||
return pipe
|
||||
|
||||
|
||||
def convert_controlnet_to_diffusers(
|
||||
checkpoint_path: Path,
|
||||
dump_path: Optional[Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Takes all the arguments of download_controlnet_from_original_ckpt(),
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
model to be written.
|
||||
"""
|
||||
pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
|
||||
pipe = pipe.to(precision)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
if dump_path:
|
||||
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
return pipe
|
||||
@@ -1,29 +1 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Init file for the model loader.
|
||||
"""
|
||||
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||
from .load_default import ModelLoader
|
||||
from .model_cache.model_cache_default import ModelCache
|
||||
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
|
||||
# This registers the subclasses that implement loaders of specific model types
|
||||
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
||||
for module in loaders:
|
||||
import_module(f"{__package__}.model_loaders.{module}")
|
||||
|
||||
__all__ = [
|
||||
"LoadedModel",
|
||||
"LoadedModelWithoutConfig",
|
||||
"ModelCache",
|
||||
"ModelConvertCache",
|
||||
"ModelLoaderBase",
|
||||
"ModelLoader",
|
||||
"ModelLoaderRegistryBase",
|
||||
"ModelLoaderRegistry",
|
||||
]
|
||||
|
||||
@@ -1,4 +0,0 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache

__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]
@@ -1,28 +0,0 @@
|
||||
"""
|
||||
Disk-based converted model cache.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ModelConvertCacheBase(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_size(self) -> float:
|
||||
"""Return the maximum size of this cache directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def make_room(self, size: float) -> None:
|
||||
"""
|
||||
Make sufficient room in the cache directory for a model of max_size.
|
||||
|
||||
:param size: Size required (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_path(self, key: str) -> Path:
|
||||
"""Return the path for a model with the indicated key."""
|
||||
pass
|
||||
@@ -1,83 +0,0 @@
|
||||
"""
|
||||
Placeholder for convert cache implementation.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.backend.util import GIG, directory_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import safe_filename
|
||||
|
||||
from .convert_cache_base import ModelConvertCacheBase
|
||||
|
||||
|
||||
class ModelConvertCache(ModelConvertCacheBase):
|
||||
def __init__(self, cache_path: Path, max_size: float = 10.0):
|
||||
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
|
||||
if not cache_path.exists():
|
||||
cache_path.mkdir(parents=True)
|
||||
self._cache_path = cache_path
|
||||
self._max_size = max_size
|
||||
|
||||
# adjust cache size at startup in case it has been changed
|
||||
if self._cache_path.exists():
|
||||
self.make_room(0.0)
|
||||
|
||||
@property
|
||||
def max_size(self) -> float:
|
||||
"""Return the maximum size of this cache directory (GB)."""
|
||||
return self._max_size
|
||||
|
||||
@max_size.setter
|
||||
def max_size(self, value: float) -> None:
|
||||
"""Set the maximum size of this cache directory (GB)."""
|
||||
self._max_size = value
|
||||
|
||||
def cache_path(self, key: str) -> Path:
|
||||
"""Return the path for a model with the indicated key."""
|
||||
key = safe_filename(self._cache_path, key)
|
||||
return self._cache_path / key
|
||||
|
||||
def make_room(self, size: float) -> None:
|
||||
"""
|
||||
Make sufficient room in the cache directory for a model of max_size.
|
||||
|
||||
:param size: Size required (GB)
|
||||
"""
|
||||
size_needed = directory_size(self._cache_path) + size
|
||||
max_size = int(self.max_size) * GIG
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
if size_needed <= max_size:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
|
||||
)
|
||||
|
||||
# For this to work, we make the assumption that the directory contains
|
||||
# a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level.
|
||||
# This should be true for any diffusers model.
|
||||
def by_atime(path: Path) -> float:
|
||||
for config in ["model_index.json", "unet/config.json", "config.json"]:
|
||||
sentinel = path / config
|
||||
if sentinel.exists():
|
||||
return sentinel.stat().st_atime
|
||||
|
||||
# no sentinel file found! - pick the most recent file in the directory
|
||||
try:
|
||||
atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True)
|
||||
return atimes[0]
|
||||
except IndexError:
|
||||
return 0.0
|
||||
|
||||
# sort by last access time - least accessed files will be at the end
|
||||
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
|
||||
logger.debug(f"cached models in descending atime order: {lru_models}")
|
||||
while size_needed > max_size and len(lru_models) > 0:
|
||||
next_victim = lru_models.pop()
|
||||
victim_size = directory_size(next_victim)
|
||||
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
|
||||
shutil.rmtree(next_victim)
|
||||
size_needed -= victim_size
|
||||
@@ -0,0 +1,8 @@
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry


def _build_model_loader_registry():
    return ModelLoaderRegistry()


MODEL_LOADER_REGISTRY = _build_model_loader_registry()
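With the registry exposed as a module-level singleton rather than class-level state, loader classes are registered against the instance. A rough usage sketch based on the instance-method signatures that appear later in this diff; the loader class name here is hypothetical:

```python
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType

# Register a hypothetical loader class with the singleton registry.
MODEL_LOADER_REGISTRY.register(
    MyControlNetLoader,           # hypothetical ModelLoaderBase subclass
    type=ModelType.ControlNet,
    format=ModelFormat.Diffusers,
    base=BaseModelType.Any,
)

# Resolve the loader class for a concrete model config later on.
loader_cls, config, submodel_type = MODEL_LOADER_REGISTRY.get_implementation(model_config, submodel_type=None)
```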
@@ -18,7 +18,6 @@ from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase


@@ -65,7 +64,8 @@ class LoadedModelWithoutConfig:

    def __enter__(self) -> AnyModel:
        """Context entry."""
        return self._locker.lock()
        self._locker.lock()
        return self.model

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Context exit."""
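After this change `__enter__` locks the model and then returns `self.model` instead of whatever `lock()` returned, so callers keep using the object as a context manager. A hedged usage sketch; how `loaded_model` is obtained is outside this hunk:

```python
# `loaded_model` would come from the model manager's load service.
with loaded_model as model:
    # The model has been locked into the execution device (e.g. CUDA) here;
    # on exit the lock is released and the model may be offloaded.
    output = model(sample_input)  # illustrative call, depends on the model type
```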
@@ -111,7 +111,6 @@ class ModelLoaderBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
pass
|
||||
@@ -137,12 +136,6 @@ class ModelLoaderBase(ABC):
|
||||
"""Return size in bytes of the model, calculated before loading."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the convert cache associated with this loader."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
|
||||
@@ -12,11 +12,10 @@ from invokeai.backend.model_manager import (
|
||||
InvalidModelConfigException,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.model_size_utils import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -30,13 +29,11 @@ class ModelLoader(ModelLoaderBase):
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._torch_dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
@@ -50,23 +47,15 @@ class ModelLoader(ModelLoaderBase):
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
if model_config.type is ModelType.Main and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path = self._get_model_path(model_config)
|
||||
|
||||
if not model_path.exists():
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
with skip_torch_weight_init():
|
||||
locker = self._convert_and_load(model_config, model_path, submodel_type)
|
||||
locker = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the convert cache associated with this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
@@ -76,20 +65,14 @@ class ModelLoader(ModelLoaderBase):
|
||||
model_base = self._app_config.models_path
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _convert_and_load(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> ModelLockerBase:
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
cache_path: Path = self._convert_cache.cache_path(str(model_path))
|
||||
if self._needs_conversion(config, model_path, cache_path):
|
||||
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
|
||||
else:
|
||||
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
config.path = str(self._get_model_path(config))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
config.key,
|
||||
@@ -113,28 +96,6 @@ class ModelLoader(ModelLoaderBase):
|
||||
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
||||
)
|
||||
|
||||
def _do_convert(
|
||||
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> AnyModel:
|
||||
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
|
||||
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
|
||||
if submodel_type:
|
||||
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
|
||||
# the entire pipeline every time a new submodel is needed.
|
||||
for subtype in SubModelType:
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
|
||||
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
||||
# This needs to be implemented in the subclass
|
||||
def _load_model(
|
||||
self,
|
||||
|
||||
@@ -8,10 +8,9 @@ model will be cleared and (re)loaded from disk when next needed.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from typing import Dict, Generator, Generic, Optional, Set, TypeVar
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
@@ -52,13 +51,44 @@ class CacheRecord(Generic[T]):
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Read-only copy of the model *without weights* residing in the "meta device"
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
|
||||
The state_dict should be treated as a read-only attribute. Do not attempt
|
||||
to patch or otherwise modify it. Instead, patch the copy of the state_dict
|
||||
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
|
||||
context manager call `model_on_device()`.
|
||||
"""
|
||||
|
||||
key: str
|
||||
size: int
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock this record."""
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock this record."""
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
return self._locks > 0
|
||||
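The docstring above describes the RAM-resident state dict acting as a read-only template: it is copied tensor-by-tensor into VRAM and injected into the model, and on offload the RAM template is re-injected so the VRAM copies can be dropped. A condensed sketch of that pattern, mirroring the `move_model_to_device` logic shown later in this diff:

```python
import torch

def inject_state_dict(model: torch.nn.Module, template: dict[str, torch.Tensor], device: torch.device) -> None:
    # Copy (not move) each tensor of the RAM template onto the target device...
    device_copy = {k: v.to(device, copy=True) for k, v in template.items()}
    # ...then swap the copies into the model without cloning them again.
    model.load_state_dict(device_copy, assign=True)
    model.to(device)

# Offloading is the reverse: re-inject the untouched RAM template and let the
# device-side copies be garbage-collected.
```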
|
||||
|
||||
@dataclass
|
||||
@@ -85,27 +115,14 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def execution_devices(self) -> Set[torch.device]:
|
||||
"""Return the set of available execution devices."""
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@property
|
||||
@abstractmethod
|
||||
def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
|
||||
"""Reserve an execution device (GPU) under the current thread id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_execution_device(self) -> torch.device:
|
||||
"""
|
||||
Return an execution device that has been reserved for current thread.
|
||||
|
||||
Note that reservations are done using the current thread's TID.
|
||||
It might be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
|
||||
May generate a ValueError if no GPU has been reserved.
|
||||
"""
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@@ -114,6 +131,16 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload from VRAM any models not actively in use."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
@@ -175,11 +202,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
|
||||
"""Move a copy of the model into the indicated device and return it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
||||
@@ -18,19 +18,17 @@ context. Use like this:
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import contextmanager, suppress
|
||||
import math
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
from threading import BoundedSemaphore
|
||||
from typing import Dict, Generator, List, Optional, Set
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -41,7 +39,9 @@ from .model_locker import ModelLocker
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
@@ -57,8 +57,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
@@ -66,19 +70,23 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._ram_lock = threading.Lock()
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
@@ -86,87 +94,25 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
# device to thread id
|
||||
self._device_lock = threading.Lock()
|
||||
self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
|
||||
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
|
||||
|
||||
self.logger.info(
|
||||
f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
|
||||
)
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
return self._lazy_offloading
|
||||
|
||||
@property
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
return self._storage_device
|
||||
|
||||
@property
|
||||
def execution_devices(self) -> Set[torch.device]:
|
||||
"""Return the set of available execution devices."""
|
||||
devices = self._execution_devices.keys()
|
||||
return set(devices)
|
||||
|
||||
def get_execution_device(self) -> torch.device:
|
||||
"""
|
||||
Return an execution device that has been reserved for current thread.
|
||||
|
||||
Note that reservations are done using the current thread's TID.
|
||||
It would be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
|
||||
May generate a ValueError if no GPU has been reserved.
|
||||
"""
|
||||
current_thread = threading.current_thread().ident
|
||||
assert current_thread is not None
|
||||
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
|
||||
if not assigned:
|
||||
raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
|
||||
return assigned[0]
|
||||
|
||||
@contextmanager
|
||||
def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
|
||||
"""Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
|
||||
|
||||
Note that the reservation is done using the current thread's TID.
|
||||
It would be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
"""
|
||||
device = None
|
||||
with self._device_lock:
|
||||
current_thread = threading.current_thread().ident
|
||||
assert current_thread is not None
|
||||
|
||||
# look for a device that has already been assigned to this thread
|
||||
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
|
||||
if assigned:
|
||||
device = assigned[0]
|
||||
|
||||
# no device already assigned. Get one.
|
||||
if device is None:
|
||||
self._free_execution_device.acquire(timeout=timeout)
|
||||
with self._device_lock:
|
||||
free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
|
||||
self._execution_devices[free_device[0]] = current_thread
|
||||
device = free_device[0]
|
||||
|
||||
# we are outside the lock region now
|
||||
self.logger.info(f"{current_thread} Reserved torch device {device}")
|
||||
|
||||
# Tell TorchDevice to use this object to get the torch device.
|
||||
TorchDevice.set_model_cache(self)
|
||||
try:
|
||||
yield device
|
||||
finally:
|
||||
with self._device_lock:
|
||||
self.logger.info(f"{current_thread} Released torch device {device}")
|
||||
self._execution_devices[device] = 0
|
||||
self._free_execution_device.release()
|
||||
torch.cuda.empty_cache()
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
return self._execution_device
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
@@ -211,16 +157,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
with self._ram_lock:
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
|
||||
cache_record = CacheRecord(key=key, model=model, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
@@ -238,37 +184,36 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
with self._ram_lock:
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
@@ -280,34 +225,129 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
else:
|
||||
return model_key
|
||||
|
||||
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
|
||||
"""Move a copy of the model into the indicated device and return it.
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
|
||||
:param cache_entry: The CacheRecord for the model
|
||||
:param target_device: The torch.device to move the model into
|
||||
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
with self._ram_lock:
|
||||
self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Some models don't have a state dictionary, in which case the
|
||||
# stored model will still reside in CPU
|
||||
if hasattr(cache_entry.model, "to"):
|
||||
model_in_gpu = copy.deepcopy(cache_entry.model)
|
||||
assert hasattr(model_in_gpu, "to")
|
||||
model_in_gpu.to(target_device)
|
||||
return model_in_gpu
|
||||
else:
|
||||
return cache_entry.model # what happens in CPU stays in CPU
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
if not hasattr(cache_entry.model, "to"):
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(
|
||||
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
|
||||
)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
|
||||
in_ram_models = len(self._cached_models)
|
||||
self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
locked_in_vram_models = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
if hasattr(cache_record.model, "device"):
|
||||
if cache_record.model.device == self.storage_device:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
@@ -330,14 +370,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if refs <= (3 if "onnx" in model_key else 2):
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
@@ -364,26 +402,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
|
||||
if target_device.type != "cuda":
|
||||
return
|
||||
vram_device = ( # mem_get_info() needs an indexed device
|
||||
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
|
||||
)
|
||||
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
|
||||
if needed_size > free_mem:
|
||||
raise torch.cuda.OutOfMemoryError
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
try:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _device_name(device: torch.device) -> str:
|
||||
return f"{device.type}:{device.index}"
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
|
||||
@@ -10,8 +10,6 @@ from invokeai.backend.model_manager import AnyModel
|
||||
|
||||
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
|
||||
|
||||
MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
|
||||
|
||||
|
||||
class ModelLocker(ModelLockerBase):
|
||||
"""Internal class that mediates movement in and out of GPU."""
|
||||
@@ -31,29 +29,33 @@ class ModelLocker(ModelLockerBase):
|
||||
"""Return the model without moving it around."""
|
||||
return self._cache_entry.model
|
||||
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return self._cache_entry.state_dict
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
self._cache_entry.lock()
|
||||
try:
|
||||
device = self._cache.get_execution_device()
|
||||
model_on_device = self._cache.model_to_device(self._cache_entry, device)
|
||||
self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
||||
self._cache_entry.loaded = True
|
||||
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
||||
self._cache.print_cuda_stats()
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
|
||||
return model_on_device
|
||||
return self.model
|
||||
|
||||
# It is no longer necessary to move the model out of VRAM
|
||||
# because it will be removed when it goes out of scope
|
||||
# in the caller's context
|
||||
def unlock(self) -> None:
|
||||
"""Call upon exit from context."""
|
||||
self._cache.print_cuda_stats()
|
||||
|
||||
# This is no longer in use in MGPU.
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return None
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(0)
|
||||
self._cache.print_cuda_stats()
|
||||
|
||||
@@ -1,48 +1,34 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
"""
|
||||
This module implements a system in which model loaders register the
|
||||
type, base and format of models that they know how to load.
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
Use like this:
|
||||
|
||||
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model = cls(
|
||||
app_config=app_config,
|
||||
logger=logger,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from ..config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from . import ModelLoaderBase
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelConfigBase, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelConfig, ModelLoaderBase, SubModelType
|
||||
|
||||
|
||||
class ModelLoaderRegistryBase(ABC):
|
||||
"""This class allows model loaders to register their type, base and format."""
|
||||
class ModelLoaderRegistry:
|
||||
"""A registry that tracks which model loader class to use for a given model type/format/base combination."""
|
||||
|
||||
def __init__(self):
|
||||
self._registry: dict[str, Type[ModelLoaderBase]] = {}
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
self,
|
||||
loader_class: Type[ModelLoaderBase],
|
||||
type: ModelType,
|
||||
format: ModelFormat,
|
||||
base: BaseModelType = BaseModelType.Any,
|
||||
):
|
||||
"""Register a model loader class."""
|
||||
key = self._to_registry_key(base, type, format)
|
||||
if key in self._registry:
|
||||
raise RuntimeError(
|
||||
f"{loader_class.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type "
|
||||
f"of model has already been registered by {self._registry[key].__name__}"
|
||||
)
|
||||
self._registry[key] = loader_class
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""
|
||||
Get subclass of ModelLoaderBase registered to handle base and type.
|
||||
@@ -56,46 +42,13 @@ class ModelLoaderRegistryBase(ABC):
|
||||
in, in the event that a submodel type is provided.
|
||||
"""
|
||||
|
||||
|
||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
||||
|
||||
|
||||
class ModelLoaderRegistry(ModelLoaderRegistryBase):
|
||||
"""
|
||||
This class allows model loaders to register their type, base and format.
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
|
||||
)
|
||||
cls._registry[key] = subclass
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
|
||||
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
key1 = self._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||
key2 = self._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||
implementation = self._registry.get(key1, None) or self._registry.get(key2, None)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, "
|
||||
f"format={config.format}"
|
||||
)
|
||||
return implementation, config, submodel_type
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for ControlNet model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import ControlNetModel
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -11,8 +12,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
@@ -23,36 +23,15 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
class ControlNetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
image_size = (
|
||||
512
|
||||
if config.base == BaseModelType.StableDiffusion1
|
||||
else 768
|
||||
if config.base == BaseModelType.StableDiffusion2
|
||||
else 1024
|
||||
)
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
|
||||
result = convert_controlnet_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
original_config_file=config_stream,
|
||||
image_size=image_size,
|
||||
precision=self._torch_dtype,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, ControlNetCheckpointConfig):
|
||||
return ControlNetModel.from_single_file(
|
||||
config.path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
return super()._load_model(config, submodel_type)
|
||||
|
||||
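The loader now leans on diffusers' single-file loading instead of converting checkpoints to diffusers format on disk. For reference, the same diffusers call works standalone; the checkpoint path below is hypothetical:

```python
import torch
from diffusers import ControlNetModel

# Load a checkpoint-format ControlNet directly; no on-disk conversion step needed.
controlnet = ControlNetModel.from_single_file(
    "/path/to/controlnet.safetensors",  # hypothetical local checkpoint
    torch_dtype=torch.float16,
)
```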
@@ -18,8 +18,8 @@ from invokeai.backend.model_manager import (
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
|
||||
@@ -8,7 +8,8 @@ import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
@@ -32,10 +31,9 @@ class LoRALoader(ModelLoader):
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
super().__init__(app_config, logger, ram_cache, convert_cache)
|
||||
super().__init__(app_config, logger, ram_cache)
|
||||
self._model_base: Optional[BaseModelType] = None
|
||||
|
||||
def _load_model(
|
||||
|
||||
@@ -4,22 +4,28 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import (
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
MainCheckpointConfig,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
@@ -48,8 +54,12 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
if isinstance(config, CheckpointConfigBase):
|
||||
return self._load_from_singlefile(config, submodel_type)
|
||||
|
||||
if submodel_type is None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
@@ -71,46 +81,58 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
return result
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
load_classes = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: StableDiffusionPipeline,
|
||||
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: StableDiffusionPipeline,
|
||||
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
||||
},
|
||||
}
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
base = config.base
|
||||
|
||||
try:
|
||||
load_class = load_classes[config.base][config.variant]
|
||||
except KeyError as e:
|
||||
raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e
|
||||
prediction_type = config.prediction_type.value
|
||||
upcast_attention = config.upcast_attention
|
||||
image_size = (
|
||||
1024
|
||||
if base == BaseModelType.StableDiffusionXL
|
||||
else 768
|
||||
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
|
||||
else 512
|
||||
)
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
# Without SilenceWarnings we get log messages like this:
|
||||
# site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
|
||||
# warnings.warn(
|
||||
# Some weights of the model checkpoint were not used when initializing CLIPTextModel:
|
||||
# ['text_model.embeddings.position_ids']
|
||||
# Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection:
|
||||
# ['text_model.embeddings.position_ids']
|
||||
|
||||
loaded_model = convert_ckpt_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
model_type=self.model_base_to_model_type[base],
|
||||
original_config_file=self._app_config.legacy_conf_path / config.config_path,
|
||||
extract_ema=True,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
precision=self._torch_dtype,
|
||||
prediction_type=prediction_type,
|
||||
image_size=image_size,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=False,
|
||||
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
|
||||
)
|
||||
return loaded_model
|
||||
with SilenceWarnings():
|
||||
pipeline = load_class.from_single_file(
|
||||
config.path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
prediction_type=prediction_type,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=False,
|
||||
)
|
||||
|
||||
if not submodel_type:
|
||||
return pipeline
|
||||
|
||||
# Proactively load the various submodels into the RAM cache so that we don't have to re-load
|
||||
# the entire pipeline every time a new submodel is needed.
|
||||
for subtype in SubModelType:
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
|
||||
return getattr(pipeline, submodel_type.value)
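`_load_from_singlefile` replaces the old convert-then-load flow with diffusers' `from_single_file`, then proactively caches the sibling submodels. Outside the loader framework, the core idea reduces to something like this (checkpoint path and dtype are assumptions):

```python
# Sketch only: load a single-file SD checkpoint once and reuse its submodels.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "models/main/v1-5-pruned-emaonly.safetensors",  # hypothetical path
    torch_dtype=torch.float16,
)
unet = pipe.unet  # the submodel that was actually requested
others = {name: getattr(pipe, name) for name in ("vae", "text_encoder", "tokenizer", "scheduler")}
```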
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for VAE model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
@@ -14,8 +11,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
from invokeai.backend.model_manager.config import AnyModel, SubModelType, VAECheckpointConfig
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
@@ -26,39 +22,15 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
class VAELoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, VAECheckpointConfig):
|
||||
return AutoencoderKL.from_single_file(
|
||||
config.path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = self._app_config.legacy_conf_path / config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
ckpt_config = OmegaConf.load(config_file)
|
||||
assert isinstance(ckpt_config, DictConfig)
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
checkpoint=checkpoint,
|
||||
vae_config=ckpt_config,
|
||||
image_size=512,
|
||||
precision=self._torch_dtype,
|
||||
dump_path=output_path,
|
||||
)
|
||||
return vae_model
|
||||
return super()._load_model(config, submodel_type)
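Same pattern for VAEs: the `convert_ldm_vae_to_diffusers` path is gone and checkpoint VAEs go through `AutoencoderKL.from_single_file`. Standalone, with an illustrative path:

```python
# Sketch only: load a checkpoint VAE directly via diffusers.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_single_file(
    "models/vae/vae-ft-mse-840000-ema-pruned.safetensors",  # hypothetical path
    torch_dtype=torch.float16,
)
```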
invokeai/backend/model_manager/load/model_size_utils.py (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def calc_module_size(model: torch.nn.Module) -> int:
|
||||
"""Estimate the size of a torch.nn.Module in bytes."""
|
||||
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
|
||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
|
||||
mem: int = mem_params + mem_bufs # in bytes
|
||||
return mem
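A quick sanity check of `calc_module_size` on a layer with a known parameter count (pure arithmetic, no InvokeAI state involved):

```python
# A Linear(4, 4) has 16 weights + 4 biases = 20 float32 params and no buffers:
# 20 * 4 bytes = 80 bytes.
import torch

from invokeai.backend.model_manager.load.model_size_utils import calc_module_size

assert calc_module_size(torch.nn.Linear(4, 4)) == 80
```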
|
||||
|
||||
|
||||
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
|
||||
"""Estimate the size of a model on disk in bytes."""
|
||||
if model_path.is_file():
|
||||
return model_path.stat().st_size
|
||||
|
||||
if subfolder is not None:
|
||||
model_path = model_path / subfolder
|
||||
|
||||
# this can happen when, for example, the safety checker is not downloaded.
|
||||
if not model_path.exists():
|
||||
return 0
|
||||
|
||||
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
|
||||
|
||||
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
|
||||
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
|
||||
other_files = set(all_files) - fp16_files - bit8_files
|
||||
|
||||
if not variant:  # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
|
||||
files = other_files
|
||||
elif variant == "fp16":
|
||||
files = fp16_files
|
||||
elif variant == "8bit":
|
||||
files = bit8_files
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown variant: {variant}")
|
||||
|
||||
# try to read the size from an index file, if one exists
|
||||
index_postfix = ".index.json"
|
||||
if variant is not None:
|
||||
index_postfix = f".index.{variant}.json"
|
||||
|
||||
for file in files:
|
||||
if not file.name.endswith(index_postfix):
|
||||
continue
|
||||
try:
|
||||
with open(model_path / file, "r") as f:
|
||||
index_data = json.loads(f.read())
|
||||
return int(index_data["metadata"]["total_size"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# calculate the total file size if there is no index file
|
||||
formats = [
|
||||
(".safetensors",), # safetensors
|
||||
(".bin",), # torch
|
||||
(".onnx", ".pb"), # onnx
|
||||
(".msgpack",), # flax
|
||||
(".ckpt",), # tf
|
||||
(".h5",), # tf2
|
||||
]
|
||||
|
||||
for file_format in formats:
|
||||
model_files = [f for f in files if f.suffix in file_format]
|
||||
if len(model_files) == 0:
|
||||
continue
|
||||
|
||||
model_size = 0
|
||||
for model_file in model_files:
|
||||
file_stats = (model_path / model_file).stat()
|
||||
model_size += file_stats.st_size
|
||||
return model_size
|
||||
|
||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
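Typical usage: point the helper at a model directory, optionally narrowing to a submodel folder and weight variant; it prefers the `*.index.json` totals and falls back to summing the weight files. The path below is an illustrative assumption:

```python
# Sketch only: estimate the on-disk size of a diffusers submodel folder.
from pathlib import Path

from invokeai.backend.model_manager.load.model_size_utils import calc_model_size_by_fs

size_bytes = calc_model_size_by_fs(
    Path("models/sd-1/main/stable-diffusion-v1-5"),  # hypothetical local model dir
    subfolder="unet",
    variant="fp16",
)
print(f"unet (fp16): {size_bytes / 2**30:.2f} GiB")
```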
|
||||
@@ -1,14 +1,11 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development Team
|
||||
"""Various utility functions needed by the loader and caching system."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_size_utils import calc_module_size
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
|
||||
@@ -17,7 +14,7 @@ def calc_model_size_by_data(model: AnyModel) -> int:
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
return _calc_pipeline_by_data(model)
|
||||
elif isinstance(model, torch.nn.Module):
|
||||
return _calc_model_by_data(model)
|
||||
return calc_module_size(model)
|
||||
elif isinstance(model, IAIOnnxRuntimeModel):
|
||||
return _calc_onnx_model_by_data(model)
|
||||
else:
|
||||
@@ -30,84 +27,11 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
|
||||
for submodel_key in pipeline.components.keys():
|
||||
submodel = getattr(pipeline, submodel_key)
|
||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
||||
res += _calc_model_by_data(submodel)
|
||||
res += calc_module_size(submodel)
|
||||
return res
|
||||
|
||||
|
||||
def _calc_model_by_data(model: torch.nn.Module) -> int:
|
||||
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
|
||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
|
||||
mem: int = mem_params + mem_bufs # in bytes
|
||||
return mem
|
||||
|
||||
|
||||
def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
|
||||
tensor_size = model.tensors.size() * 2 # The session doubles this
|
||||
mem = tensor_size # in bytes
|
||||
return mem
|
||||
|
||||
|
||||
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
|
||||
"""Estimate the size of a model on disk in bytes."""
|
||||
if model_path.is_file():
|
||||
return model_path.stat().st_size
|
||||
|
||||
if subfolder is not None:
|
||||
model_path = model_path / subfolder
|
||||
|
||||
# this can happen when, for example, the safety checker is not downloaded.
|
||||
if not model_path.exists():
|
||||
return 0
|
||||
|
||||
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
|
||||
|
||||
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
|
||||
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
|
||||
other_files = set(all_files) - fp16_files - bit8_files
|
||||
|
||||
if not variant:  # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
|
||||
files = other_files
|
||||
elif variant == "fp16":
|
||||
files = fp16_files
|
||||
elif variant == "8bit":
|
||||
files = bit8_files
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown variant: {variant}")
|
||||
|
||||
# try to read the size from an index file, if one exists
|
||||
index_postfix = ".index.json"
|
||||
if variant is not None:
|
||||
index_postfix = f".index.{variant}.json"
|
||||
|
||||
for file in files:
|
||||
if not file.name.endswith(index_postfix):
|
||||
continue
|
||||
try:
|
||||
with open(model_path / file, "r") as f:
|
||||
index_data = json.loads(f.read())
|
||||
return int(index_data["metadata"]["total_size"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# calculate the total file size if there is no index file
|
||||
formats = [
|
||||
(".safetensors",), # safetensors
|
||||
(".bin",), # torch
|
||||
(".onnx", ".pb"), # onnx
|
||||
(".msgpack",), # flax
|
||||
(".ckpt",), # tf
|
||||
(".h5",), # tf2
|
||||
]
|
||||
|
||||
for file_format in formats:
|
||||
model_files = [f for f in files if f.suffix in file_format]
|
||||
if len(model_files) == 0:
|
||||
continue
|
||||
|
||||
model_size = 0
|
||||
for model_file in model_files:
|
||||
file_stats = (model_path / model_file).stat()
|
||||
model_size += file_stats.st_size
|
||||
return model_size
|
||||
|
||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||
|
||||
@@ -312,6 +312,8 @@ class ModelProbe(object):
        config_file = (
            "stable-diffusion/v1-inference.yaml"
            if base_type is BaseModelType.StableDiffusion1
            else "stable-diffusion/sd_xl_base.yaml"
            if base_type is BaseModelType.StableDiffusionXL
            else "stable-diffusion/v2-inference.yaml"
        )
    else:

@@ -294,8 +294,8 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="canny-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlnet-canny-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with canny conditioning.",
|
||||
source="xinsir/controlnet-canny-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
@@ -326,6 +326,20 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-openpose-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-scribble-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
# endregion
|
||||
# region T2I Adapter
|
||||
StarterModel(
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
@@ -17,6 +16,7 @@ from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .lora import LoRAModelRaw
|
||||
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
@@ -35,8 +35,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
_thread_lock = threading.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
@@ -109,7 +107,7 @@ class ModelPatcher:
|
||||
"""
|
||||
original_weights = {}
|
||||
try:
|
||||
with torch.no_grad(), cls._thread_lock:
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
@@ -132,7 +130,9 @@ class ModelPatcher:
|
||||
dtype = module.weight.dtype
|
||||
|
||||
if module_key not in original_weights:
|
||||
if model_state_dict is None: # no CPU copy of the state dict was provided
|
||||
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
|
||||
original_weights[module_key] = model_state_dict[module_key + ".weight"]
|
||||
else:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
@@ -140,12 +140,15 @@ class ModelPatcher:
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device, non_blocking=True)
|
||||
layer.to(dtype=torch.float32, non_blocking=True)
|
||||
layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(device=torch.device("cpu"), non_blocking=True)
|
||||
layer.to(
|
||||
device=TorchDevice.CPU_DEVICE,
|
||||
non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
|
||||
)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
@@ -154,7 +157,7 @@ class ModelPatcher:
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
|
||||
module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
|
||||
yield # wait for context manager exit
|
||||
|
||||
@@ -162,7 +165,9 @@ class ModelPatcher:
|
||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||
with torch.no_grad():
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
|
||||
model.get_submodule(module_key).weight.copy_(
|
||||
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
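This hunk swaps the hard-coded `non_blocking=True` transfers for a per-device decision via `TorchDevice.get_non_blocking(...)`. That helper's implementation is not shown in this diff; a plausible sketch of the idea (an assumption, not the real code) is to allow non-blocking copies only on CUDA:

```python
import torch


def get_non_blocking(to_device: torch.device) -> bool:
    """Assumed behavior: non-blocking host/device copies are only reliable on CUDA;
    fall back to blocking transfers on CPU/MPS."""
    return to_device.type == "cuda"
```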
|
||||
|
||||
@@ -10,12 +10,11 @@ import PIL.Image
|
||||
import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
@@ -26,6 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,56 +38,18 @@ class PipelineIntermediateState:
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskLatents:
|
||||
"""Add the channels required for inpainting model input.
|
||||
|
||||
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
|
||||
and the latent encoding of the base image.
|
||||
|
||||
This class assumes the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
|
||||
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
mask: torch.Tensor
|
||||
initial_image_latents: torch.Tensor
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
text_embeddings: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
model_input = self.add_mask_channels(latents)
|
||||
return self.forward(model_input, t, text_embeddings, **kwargs)
|
||||
|
||||
def add_mask_channels(self, latents):
|
||||
batch_size = latents.size(0)
|
||||
# duplicate mask and latents for each batch
|
||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
# add mask and image as additional channels
|
||||
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
|
||||
return model_input
|
||||
|
||||
|
||||
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
|
||||
return isinstance(b, torch.Tensor) and (a.size() == b.size())
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskGuidance:
|
||||
mask: torch.FloatTensor
|
||||
mask_latents: torch.FloatTensor
|
||||
mask: torch.Tensor
|
||||
mask_latents: torch.Tensor
|
||||
scheduler: SchedulerMixin
|
||||
noise: torch.Tensor
|
||||
gradient_mask: bool
|
||||
is_gradient_mask: bool
|
||||
|
||||
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
return self.apply_mask(latents, t)
|
||||
|
||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if t.dim() == 0:
|
||||
@@ -100,7 +62,7 @@ class AddsMaskGuidance:
|
||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if self.gradient_mask:
|
||||
if self.is_gradient_mask:
|
||||
threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
|
||||
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
|
||||
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||
@@ -200,7 +162,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
safety_checker: Optional[StableDiffusionSafetyChecker],
|
||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||
requires_safety_checker: bool = False,
|
||||
control_model: ControlNetModel = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
@@ -214,8 +175,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||
self.control_model = control_model
|
||||
self.use_ip_adapter = False
|
||||
|
||||
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
||||
"""
|
||||
@@ -280,116 +239,128 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
raise Exception("Should not be called")
|
||||
|
||||
def add_inpainting_channels_to_latents(
|
||||
self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
|
||||
):
|
||||
"""Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
|
||||
|
||||
Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
|
||||
input of shape (N, 9, H, W). The 9 channels are defined as follows:
|
||||
- Channel 0-3: The latents being denoised.
|
||||
- Channel 4: The mask indicating which parts of the image are being inpainted.
|
||||
- Channel 5-8: The latent representation of the masked reference image being inpainted.
|
||||
|
||||
This function assumes that the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
# Validate assumptions about input tensor shapes.
|
||||
batch_size, latent_channels, latent_height, latent_width = latents.shape
|
||||
assert latent_channels == 4
|
||||
assert list(masked_ref_image_latents.shape) == [1, 4, latent_height, latent_width]
|
||||
assert list(inpainting_mask.shape) == [1, 1, latent_height, latent_width]
|
||||
|
||||
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
|
||||
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
|
||||
inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
|
||||
|
||||
# Concatenate along the channel dimension.
|
||||
return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
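The 9-channel layout described in the docstring is easy to verify with plain tensor shapes; this sketch mirrors the concatenation above (dimensions are illustrative):

```python
# 4 latent channels + 1 mask channel + 4 masked-reference channels = 9.
import torch

latents = torch.zeros(2, 4, 64, 64)      # batch of 2 latents being denoised
masked_ref = torch.zeros(1, 4, 64, 64)   # latents of the masked reference image
mask = torch.ones(1, 1, 64, 64)          # inpainting mask

model_input = torch.cat(
    [latents, mask.expand(2, -1, -1, -1), masked_ref.expand(2, -1, -1, -1)], dim=1
)
assert model_input.shape == (2, 9, 64, 64)
```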
|
||||
|
||||
def latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: TextConditioningData,
|
||||
*,
|
||||
noise: Optional[torch.Tensor],
|
||||
seed: int,
|
||||
timesteps: torch.Tensor,
|
||||
init_timestep: torch.Tensor,
|
||||
additional_guidance: List[Callable] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None],
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: int,
|
||||
is_gradient_mask: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Denoise the latents.
|
||||
|
||||
Args:
|
||||
latents: The latent-space image to denoise.
|
||||
- If we are inpainting, this is the initial latent image before noise has been added.
|
||||
- If we are generating a new image, this should be initialized to zeros.
|
||||
- In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
|
||||
scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
|
||||
conditioning_data: Text conditioning data.
|
||||
noise: Noise used for two purposes:
|
||||
1. Used by the scheduler to noise the initial `latents` before denoising.
|
||||
2. Used to noise the `masked_latents` when inpainting.
|
||||
`noise` should be None if the `latents` tensor has already been noised.
|
||||
seed: The seed used to generate the noise for the denoising process.
|
||||
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
|
||||
same noise used earlier in the pipeline. This should really be handled in a clearer way.
|
||||
timesteps: The timestep schedule for the denoising process.
|
||||
init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
|
||||
should be populated if you want noise applied *even* if timesteps is empty.
|
||||
callback: A callback function that is called to report progress during the denoising process.
|
||||
control_data: ControlNet data.
|
||||
ip_adapter_data: IP-Adapter data.
|
||||
t2i_adapter_data: T2I-Adapter data.
|
||||
mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to
|
||||
determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the
|
||||
`latents` tensor.
|
||||
TODO(ryand): Check and document the expected dtype, range, and values used to represent
|
||||
foreground/background.
|
||||
masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only
|
||||
used if an *inpainting* model is being used i.e. this tensor is not used when inpainting with a standard
|
||||
SD UNet model.
|
||||
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
|
||||
"""
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
orig_latents = latents.clone()
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
batched_t = init_timestep.expand(batch_size)
|
||||
batched_init_timestep = init_timestep.expand(batch_size)
|
||||
|
||||
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
if noise is not None:
|
||||
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
|
||||
# full noise. Investigate the history of why this got commented out.
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||
|
||||
if mask is not None:
|
||||
if is_inpainting_model(self.unet):
|
||||
if masked_latents is None:
|
||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
||||
|
||||
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
|
||||
self._unet_forward, mask, masked_latents
|
||||
)
|
||||
else:
|
||||
# if no noise provided, noisify unmasked area based on seed
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
try:
|
||||
latents = self.generate_latents_from_embeddings(
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
callback=callback,
|
||||
)
|
||||
finally:
|
||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||
|
||||
# restore unmasked part after the last step is completed
|
||||
# in-process masking happens before each step
|
||||
if mask is not None:
|
||||
if gradient_mask:
|
||||
latents = torch.where(mask > 0, latents, orig_latents)
|
||||
else:
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def generate_latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
timesteps,
|
||||
conditioning_data: TextConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
) -> torch.Tensor:
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
# Handle mask guidance (a.k.a. inpainting).
|
||||
mask_guidance: AddsMaskGuidance | None = None
|
||||
if mask is not None and not is_inpainting_model(self.unet):
|
||||
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
|
||||
# apply mask guidance to the latents.
|
||||
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents
|
||||
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
# We still need noise for inpainting, so we generate it from the seed here.
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
mask_guidance = AddsMaskGuidance(
|
||||
mask=mask,
|
||||
mask_latents=orig_latents,
|
||||
scheduler=self.scheduler,
|
||||
noise=noise,
|
||||
is_gradient_mask=is_gradient_mask,
|
||||
)
|
||||
|
||||
use_ip_adapter = ip_adapter_data is not None
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
)
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
@@ -402,28 +373,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
)
|
||||
|
||||
# print("timesteps:", timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
step_output = self.step(
|
||||
batched_t,
|
||||
latents,
|
||||
conditioning_data,
|
||||
t=batched_t,
|
||||
latents=latents,
|
||||
conditioning_data=conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
mask_guidance=mask_guidance,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
@@ -431,19 +402,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
)
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
)
|
||||
)
|
||||
|
||||
return latents
|
||||
# restore unmasked part after the last step is completed
|
||||
# in-process masking happens before each step
|
||||
if mask is not None:
|
||||
if is_gradient_mask:
|
||||
latents = torch.where(mask > 0, latents, orig_latents)
|
||||
else:
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(
|
||||
@@ -454,19 +434,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
mask_guidance: AddsMaskGuidance | None,
|
||||
mask: torch.Tensor | None,
|
||||
masked_latents: torch.Tensor | None,
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
# one day we will expand this extension point, but for now it just does denoise masking
|
||||
for guidance in additional_guidance:
|
||||
latents = guidance(latents, timestep)
|
||||
# Handle masked image-to-image (a.k.a inpainting).
|
||||
if mask_guidance is not None:
|
||||
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
|
||||
latents = mask_guidance(latents, timestep)
|
||||
|
||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
@@ -514,6 +495,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
down_intrablock_additional_residuals = accum_adapter_state
|
||||
|
||||
# Handle inpainting models.
|
||||
if is_inpainting_model(self.unet):
|
||||
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
|
||||
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
|
||||
# latents.
|
||||
if mask is not None:
|
||||
if masked_latents is None:
|
||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
|
||||
)
|
||||
else:
|
||||
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
|
||||
# We generate a global mask and empty original image so that we can still generate in this
|
||||
# configuration.
|
||||
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
|
||||
# to do this.
|
||||
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
|
||||
# mask and original image once rather than on every denoising step.
|
||||
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||
latents=latent_model_input,
|
||||
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
|
||||
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
|
||||
)
|
||||
|
||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||
sample=latent_model_input,
|
||||
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
||||
@@ -542,17 +548,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||
|
||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
|
||||
for guidance in additional_guidance:
|
||||
# apply the mask to any "denoised" or "pred_original_sample" fields
|
||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
|
||||
# again.
|
||||
if mask_guidance is not None:
|
||||
# Apply the mask to any "denoised" or "pred_original_sample" fields.
|
||||
if hasattr(step_output, "denoised"):
|
||||
step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
|
||||
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
|
||||
elif hasattr(step_output, "pred_original_sample"):
|
||||
step_output.pred_original_sample = guidance(
|
||||
step_output.pred_original_sample = mask_guidance(
|
||||
step_output.pred_original_sample, self.scheduler.timesteps[-1]
|
||||
)
|
||||
else:
|
||||
step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
|
||||
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
|
||||
|
||||
return step_output
|
||||
|
||||
@@ -575,17 +582,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
**kwargs,
|
||||
):
|
||||
"""predict the noise residual"""
|
||||
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
||||
# Pad out normal non-inpainting inputs for an inpainting model.
|
||||
# FIXME: There are too many layers of functions and we have too many different ways of
|
||||
# overriding things! This should get handled in a way more consistent with the other
|
||||
# use of AddsMaskLatents.
|
||||
latents = AddsMaskLatents(
|
||||
self._unet_forward,
|
||||
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
|
||||
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
|
||||
).add_mask_channels(latents)
|
||||
|
||||
# First three args should be positional, not keywords, so torch hooks can see them.
|
||||
return self.unet(
|
||||
latents,
|
||||
|
||||
@@ -32,11 +32,8 @@ class SDXLConditioningInfo(BasicConditioningInfo):

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        assert self.pooled_embeds.device == device
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        result = super().to(device=device, dtype=dtype)
        assert self.embeds.device == device
        return result
        return super().to(device=device, dtype=dtype)


@dataclass

@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import threading
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -294,31 +293,24 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
added_cond_kwargs = None
|
||||
try:
|
||||
if conditioning_data.is_sdxl():
|
||||
# tid = threading.current_thread().ident
|
||||
# print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
tid = threading.current_thread().ident
|
||||
print(f"DEBUG: {tid} {str(e)}")
|
||||
raise e
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
|
||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
||||
|
||||
invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py (new file, 170 lines)
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
PipelineIntermediateState,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||
from invokeai.backend.tiles.utils import Tile
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiDiffusionRegionConditioning:
|
||||
# Region coords in latent space.
|
||||
region: Tile
|
||||
text_conditioning_data: TextConditioningData
|
||||
control_data: list[ControlNetData]
|
||||
|
||||
|
||||
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
|
||||
|
||||
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
|
||||
"""Validate that regional conditioning is not used."""
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
if (
|
||||
region_conditioning.text_conditioning_data.cond_regions is not None
|
||||
or region_conditioning.text_conditioning_data.uncond_regions is not None
|
||||
):
|
||||
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
||||
|
||||
def multi_diffusion_denoise(
|
||||
self,
|
||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
||||
target_overlap: int,
|
||||
latents: torch.Tensor,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps: torch.Tensor,
|
||||
init_timestep: torch.Tensor,
|
||||
callback: Callable[[PipelineIntermediateState], None],
|
||||
) -> torch.Tensor:
|
||||
self._check_regional_prompting(multi_diffusion_conditioning)
|
||||
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
batch_size, _, latent_height, latent_width = latents.shape
|
||||
batched_init_timestep = init_timestep.expand(batch_size)
|
||||
|
||||
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
if noise is not None:
|
||||
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
|
||||
# full noise. Investigate the history of why this got commented out.
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||
|
||||
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
|
||||
# cropping into regions.
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
|
||||
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
|
||||
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
|
||||
# separate scheduler state for each region batch.
|
||||
# TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
|
||||
# statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
|
||||
# as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
|
||||
# multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
|
||||
# scheduler to determine how its state needs to be updated for compatibility with Multi-Diffusion.
|
||||
region_batch_schedulers: list[SchedulerMixin] = [
|
||||
copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
|
||||
]
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
|
||||
merged_latents = torch.zeros_like(latents)
|
||||
merged_latents_weights = torch.zeros(
|
||||
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
|
||||
)
|
||||
merged_pred_original: torch.Tensor | None = None
|
||||
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
|
||||
# Switch to the scheduler for the region batch.
|
||||
self.scheduler = region_batch_schedulers[region_idx]
|
||||
|
||||
# Crop the inputs to the region.
|
||||
region_latents = latents[
|
||||
:,
|
||||
:,
|
||||
region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
|
||||
region_conditioning.region.coords.left : region_conditioning.region.coords.right,
|
||||
]
|
||||
|
||||
# Run the denoising step on the region.
|
||||
step_output = self.step(
|
||||
t=batched_t,
|
||||
latents=region_latents,
|
||||
conditioning_data=region_conditioning.text_conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
mask_guidance=None,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
control_data=region_conditioning.control_data,
|
||||
)
|
||||
|
||||
# Store the results from the region.
|
||||
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
|
||||
# affected tiles to achieve the target overlap.
|
||||
region = region_conditioning.region
|
||||
top_adjustment = max(0, region.overlap.top - target_overlap)
|
||||
left_adjustment = max(0, region.overlap.left - target_overlap)
|
||||
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
|
||||
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
|
||||
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
|
||||
:, :, top_adjustment:, left_adjustment:
|
||||
]
|
||||
# For now, we treat every region as having the same weight.
|
||||
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
|
||||
|
||||
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
|
||||
if pred_orig_sample is not None:
|
||||
# If one region has pred_original_sample, then we can assume that all regions will have it, because
|
||||
# they all use the same scheduler.
|
||||
if merged_pred_original is None:
|
||||
merged_pred_original = torch.zeros_like(latents)
|
||||
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
|
||||
:, :, top_adjustment:, left_adjustment:
|
||||
]
|
||||
|
||||
# Normalize the merged results.
|
||||
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
|
||||
# For debugging, uncomment this line to visualize the region seams:
|
||||
# latents = torch.where(merged_latents_weights > 1, 0.0, latents)
|
||||
predicted_original = None
|
||||
if merged_pred_original is not None:
|
||||
predicted_original = torch.where(
|
||||
merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
|
||||
)
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
)
|
||||
)
|
||||
|
||||
return latents
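The merge step above is the core of Multi-Diffusion: each region's step output is accumulated into a full-size buffer together with a per-pixel weight map, then normalized. Stripped of the pipeline machinery, the blend reduces to this (tile sizes are illustrative):

```python
# Sketch of the weighted blend of overlapping region outputs.
import torch

merged = torch.zeros(1, 4, 96, 64)
weights = torch.zeros(1, 1, 96, 64)

for top in (0, 32):  # two 64-row regions overlapping by 32 rows
    region_output = torch.randn(1, 4, 64, 64)  # stand-in for step_output.prev_sample
    merged[:, :, top:top + 64, :] += region_output
    weights[:, :, top:top + 64, :] += 1.0

latents = torch.where(weights > 0, merged / weights, merged)
```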
|
||||
@@ -1,3 +0,0 @@
|
||||
from .schedulers import SCHEDULER_MAP # noqa: F401
|
||||
|
||||
__all__ = ["SCHEDULER_MAP"]
|
||||
|
||||
@@ -1,3 +1,5 @@
from typing import Literal

from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
@@ -43,3 +45,9 @@ SCHEDULER_MAP = {
    "lcm": (LCMScheduler, {}),
    "tcd": (TCDScheduler, {}),
}


# HACK(ryand): Passing a tuple of keys to Literal works at runtime, but not at type-check time. See the docs here for
# more info: https://typing.readthedocs.io/en/latest/spec/literal.html#parameters-at-runtime. For now, we are ignoring
# this error. In the future, we should fix this type handling.
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]  # type: ignore
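The HACK note relies on the fact that subscripting Literal with a tuple of keys is unpacked at runtime, so the resulting alias can still drive runtime validation even though static type checkers reject the dynamic construction. A small sketch under that assumption follows; the trimmed-down map and the validate_scheduler_name helper are made up for illustration.

from typing import Literal, get_args

SCHEDULER_MAP = {"ddim": (None, {}), "lcm": (None, {}), "tcd": (None, {})}  # stand-in for the real map
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]  # type: ignore

def validate_scheduler_name(name: str) -> str:
    # get_args() recovers the allowed names from the Literal at runtime.
    if name not in get_args(SCHEDULER_NAME_VALUES):
        raise ValueError(f"Unknown scheduler '{name}'")
    return name

assert validate_scheduler_name("lcm") == "lcm"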
invokeai/backend/stable_diffusion/vae_tiling.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from contextlib import contextmanager

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny


@contextmanager
def patch_vae_tiling_params(
    vae: AutoencoderKL | AutoencoderTiny,
    tile_sample_min_size: int,
    tile_latent_min_size: int,
    tile_overlap_factor: float,
):
    """Patch the parameters that control the VAE tiling tile size and overlap.

    These parameters are not explicitly exposed in the VAE's API, but they have a significant impact on the quality of
    the outputs. As a general rule, bigger tiles produce better results, but this comes at the cost of higher memory
    usage.
    """
    # Record initial config.
    orig_tile_sample_min_size = vae.tile_sample_min_size
    orig_tile_latent_min_size = vae.tile_latent_min_size
    orig_tile_overlap_factor = vae.tile_overlap_factor

    try:
        # Apply target config.
        vae.tile_sample_min_size = tile_sample_min_size
        vae.tile_latent_min_size = tile_latent_min_size
        vae.tile_overlap_factor = tile_overlap_factor
        yield
    finally:
        # Restore initial config.
        vae.tile_sample_min_size = orig_tile_sample_min_size
        vae.tile_latent_min_size = orig_tile_latent_min_size
        vae.tile_overlap_factor = orig_tile_overlap_factor
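A hedged usage sketch for the context manager above: wrap a tiled VAE decode so that the larger tiles apply only to that call, with the original settings restored afterwards. The model id, latent shape, and parameter values are illustrative assumptions; enable_tiling() and decode() are standard diffusers AutoencoderKL methods.

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_tiling()

latents = torch.randn(1, 4, 128, 128)  # a large latent that benefits from tiled decoding
with patch_vae_tiling_params(
    vae,
    tile_sample_min_size=768,        # bigger tiles: fewer seams, more memory
    tile_latent_min_size=768 // 8,   # latent-space equivalent of the sample tile size
    tile_overlap_factor=0.25,
):
    image = vae.decode(latents, return_dict=False)[0]
# Outside the with-block the VAE's original tiling parameters are back in place.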
@@ -1,16 +1,10 @@
"""Torch Device class provides torch device selection services."""

from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
from typing import Dict, Literal, Optional, Union

import torch
from deprecated import deprecated

from invokeai.app.services.config.config_default import get_config

if TYPE_CHECKING:
    from invokeai.backend.model_manager.config import AnyModel
    from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
@@ -48,23 +42,13 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
class TorchDevice:
    """Abstraction layer for torch devices."""

    _model_cache: Optional["ModelCacheBase[AnyModel]"] = None

    @classmethod
    def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
        """Set the current model cache."""
        cls._model_cache = cache
    CPU_DEVICE = torch.device("cpu")
    CUDA_DEVICE = torch.device("cuda")
    MPS_DEVICE = torch.device("mps")

    @classmethod
    def choose_torch_device(cls) -> torch.device:
        """Return the torch.device to use for accelerated inference."""
        if cls._model_cache:
            return cls._model_cache.get_execution_device()
        else:
            return cls._choose_device()

    @classmethod
    def _choose_device(cls) -> torch.device:
        app_config = get_config()
        if app_config.device != "auto":
            device = torch.device(app_config.device)
@@ -76,19 +60,11 @@ class TorchDevice:
            device = CPU_DEVICE
        return cls.normalize(device)

    @classmethod
    def execution_devices(cls) -> Set[torch.device]:
        """Return a list of torch.devices that can be used for accelerated inference."""
        app_config = get_config()
        if app_config.devices is None:
            return cls._lookup_execution_devices()
        return {torch.device(x) for x in app_config.devices}

    @classmethod
    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
        """Return the precision to use for accelerated inference."""
        device = device or cls.choose_torch_device()
        config = get_config()
        device = device or cls._choose_device()
        if device.type == "cuda" and torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(device)
            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
@@ -137,12 +113,14 @@ class TorchDevice:
    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
        return NAME_TO_PRECISION[precision_name]

    @classmethod
    def _lookup_execution_devices(cls) -> Set[torch.device]:
        if torch.cuda.is_available():
            devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
        elif torch.backends.mps.is_available():
            devices = {torch.device("mps")}
        else:
            devices = {torch.device("cpu")}
        return devices
    @staticmethod
    def get_non_blocking(to_device: torch.device) -> bool:
        """Return the non_blocking flag to be used when moving a tensor to a given device.
        MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
        When moving _from_ MPS, we can use non-blocking operations.

        See:
        - https://github.com/pytorch/pytorch/issues/107455
        - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
        """
        return False if to_device.type == "mps" else True
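For context, a short sketch of how the selection helpers and the non_blocking flag above are typically combined when moving a tensor to the chosen device. The tensor is a placeholder, and the import path assumes TorchDevice lives in invokeai.backend.util.devices as on this branch.

import torch

from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype(device)

x = torch.randn(16, 16)
# Blocking copy when targeting MPS, asynchronous copy on CUDA.
x = x.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))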
@@ -5,9 +5,10 @@ from typing import Optional, Union
import pytest
import torch

from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel


@pytest.fixture(scope="session")
@@ -17,7 +17,10 @@
  },
  "boards": {
    "addBoard": "Add Board",
    "archiveBoard": "Archive Board",
    "archived": "Archived",
    "autoAddBoard": "Auto-Add Board",
    "selectedForAutoAdd": "Selected for Auto-Add",
    "bottomMessage": "Deleting this board and its images will reset any features currently using them.",
    "cancel": "Cancel",
    "changeBoard": "Change Board",
@@ -36,8 +39,13 @@
    "searchBoard": "Search Boards...",
    "selectBoard": "Select a Board",
    "topMessage": "This board contains images used in the following features:",
    "unarchiveBoard": "Unarchive Board",
    "uncategorized": "Uncategorized",
    "downloadBoard": "Download Board"
    "downloadBoard": "Download Board",
    "imagesWithCount_one": "{{count}} image",
    "imagesWithCount_other": "{{count}} images",
    "assetsWithCount_one": "{{count}} asset",
    "assetsWithCount_other": "{{count}} assets"
  },
  "accordions": {
    "generation": {
@@ -364,6 +372,10 @@
    "image": "image",
    "loading": "Loading",
    "loadMore": "Load More",
    "newestFirst": "Newest First",
    "oldestFirst": "Oldest First",
    "sortDirection": "Sort Direction",
    "showStarredImagesFirst": "Show Starred Images First",
    "noImageSelected": "No Image Selected",
    "noImagesInGallery": "No Images to Display",
    "setCurrentImage": "Set as Current Image",
@@ -381,6 +393,10 @@
    "viewerImage": "Viewer Image",
    "compareImage": "Compare Image",
    "openInViewer": "Open in Viewer",
    "searchImages": "Search by Metadata",
    "selectAllOnPage": "Select All On Page",
    "selectAllOnBoard": "Select All On Board",
    "showArchivedBoards": "Show Archived Boards",
    "selectForCompare": "Select for Compare",
    "selectAnImageToCompare": "Select an Image to Compare",
    "slider": "Slider",
@@ -23,6 +23,7 @@ import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listener
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
|
||||
import { addRequestedSingleImageDeletionListener } from 'app/store/middleware/listenerMiddleware/listeners/imageDeleted';
|
||||
@@ -51,6 +52,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@@ -77,6 +80,7 @@ addImagesUnstarredListener(startAppListening);
|
||||
|
||||
// Gallery
|
||||
addGalleryImageClickedListener(startAppListening);
|
||||
addGalleryOffsetChangedListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addEnqueueRequestedCanvasListener(startAppListening);
|
||||
@@ -116,6 +120,7 @@ addControlNetAutoProcessListener(startAppListening);
|
||||
addImageAddedToBoardFulfilledListener(startAppListening);
|
||||
addImageRemovedFromBoardFulfilledListener(startAppListening);
|
||||
addBoardIdSelectedListener(startAppListening);
|
||||
addArchivedOrDeletedBoardListener(startAppListening);
|
||||
|
||||
// Node schemas
|
||||
addGetOpenAPISchemaListener(startAppListening);
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
autoAddBoardIdChanged,
|
||||
boardIdSelected,
|
||||
galleryViewChanged,
|
||||
shouldShowArchivedBoardsChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: isAnyOf(
|
||||
// Updating a board may change its archived status
|
||||
boardsApi.endpoints.updateBoard.matchFulfilled,
|
||||
// If the selected/auto-add board was deleted from a different session, we'll only know during the list request,
|
||||
boardsApi.endpoints.listAllBoards.matchFulfilled,
|
||||
// If a board is deleted, we'll need to reset the auto-add board
|
||||
imagesApi.endpoints.deleteBoard.matchFulfilled,
|
||||
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
|
||||
// When we change the visibility of archived boards, we may need to reset the auto-add board
|
||||
shouldShowArchivedBoardsChanged
|
||||
),
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
/**
|
||||
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
|
||||
* a board, or change the archived board visibility flag, we may need to reset the auto-add board.
|
||||
*/
|
||||
|
||||
const state = getState();
|
||||
const queryArgs = selectListBoardsQueryArgs(state);
|
||||
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
|
||||
const autoAddBoardId = state.gallery.autoAddBoardId;
|
||||
|
||||
if (!queryResult.data) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
|
||||
dispatch(autoAddBoardIdChanged('none'));
|
||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -2,8 +2,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageCache } from 'services/api/types';
|
||||
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
|
||||
import { getListImagesUrl } from 'services/api/util';
|
||||
|
||||
export const addFirstListImagesListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -18,13 +17,10 @@ export const addFirstListImagesListener = (startAppListening: AppStartListening)
|
||||
cancelActiveListeners();
|
||||
unsubscribe();
|
||||
|
||||
// TODO: figure out how to type the predicate
|
||||
const data = action.payload as ImageCache;
|
||||
const data = action.payload;
|
||||
|
||||
if (data.ids.length > 0) {
|
||||
// Select the first image
|
||||
const firstImage = imagesSelectors.selectAll(data)[0];
|
||||
dispatch(imageSelected(firstImage ?? null));
|
||||
if (data.items.length > 0) {
|
||||
dispatch(imageSelected(data.items[0] ?? null));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
boardIdSelected,
|
||||
galleryViewChanged,
|
||||
imageSelected,
|
||||
selectionChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
export const addBoardIdSelectedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -14,14 +18,9 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
|
||||
|
||||
const state = getState();
|
||||
|
||||
const board_id = boardIdSelected.match(action) ? action.payload.boardId : state.gallery.selectedBoardId;
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
|
||||
const galleryView = galleryViewChanged.match(action) ? action.payload : state.gallery.galleryView;
|
||||
|
||||
// when a board is selected, we need to wait until the board has loaded *some* images, then select the first one
|
||||
const categories = galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
|
||||
|
||||
const queryArgs = { board_id: board_id ?? 'none', categories };
|
||||
dispatch(selectionChanged([]));
|
||||
|
||||
// wait until the board has some images - maybe it already has some from a previous fetch
|
||||
// must use getState() to ensure we do not have stale state
|
||||
@@ -35,11 +34,12 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
|
||||
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
|
||||
|
||||
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
|
||||
const selectedImage = imagesSelectors.selectById(boardImagesData, action.payload.selectedImageName);
|
||||
const selectedImage = boardImagesData.items.find(
|
||||
(item) => item.image_name === action.payload.selectedImageName
|
||||
);
|
||||
dispatch(imageSelected(selectedImage || null));
|
||||
} else if (boardImagesData) {
|
||||
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
|
||||
dispatch(imageSelected(firstImage || null));
|
||||
dispatch(imageSelected(boardImagesData.items[0] || null));
|
||||
} else {
|
||||
// board has no images - deselect
|
||||
dispatch(imageSelected(null));
|
||||
|
||||
@@ -4,7 +4,6 @@ import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelecto
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
export const galleryImageClicked = createAction<{
|
||||
imageDTO: ImageDTO;
|
||||
@@ -32,14 +31,14 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
|
||||
if (!listImagesData) {
|
||||
if (!queryResult.data) {
|
||||
// Should never happen if we have clicked a gallery image
|
||||
return;
|
||||
}
|
||||
|
||||
const imageDTOs = imagesSelectors.selectAll(listImagesData);
|
||||
const imageDTOs = queryResult.data.items;
|
||||
const selection = state.gallery.selection;
|
||||
|
||||
if (altKey) {
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, offsetChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const addGalleryOffsetChangedListener = (startAppListening: AppStartListening) => {
|
||||
/**
|
||||
* When the user changes pages in the gallery, we need to wait until the next page of images is loaded, then maybe
|
||||
* update the selection.
|
||||
*
|
||||
* There are three scenarios:
|
||||
*
|
||||
* 1. The page is changed by clicking the pagination buttons. No changes to selection are needed.
|
||||
*
|
||||
* 2. The page is changed by using the arrow keys (without alt).
|
||||
* - When going backwards, select the last image.
|
||||
* - When going forwards, select the first image.
|
||||
*
|
||||
* 3. The page is changed by using the arrows keys with alt. This means the user is changing the comparison image.
|
||||
* - When going backwards, select the last image _as the comparison image_.
|
||||
* - When going forwards, select the first image _as the comparison image_.
|
||||
*/
|
||||
startAppListening({
|
||||
actionCreator: offsetChanged,
|
||||
effect: async (action, { dispatch, getState, getOriginalState, take, cancelActiveListeners }) => {
|
||||
// Cancel any active listeners to prevent the selection from changing without user input
|
||||
cancelActiveListeners();
|
||||
|
||||
const { withHotkey } = action.payload;
|
||||
|
||||
if (!withHotkey) {
|
||||
// User changed pages by clicking the pagination buttons - no changes to selection
|
||||
return;
|
||||
}
|
||||
|
||||
const originalState = getOriginalState();
|
||||
const prevOffset = originalState.gallery.offset;
|
||||
const offset = getState().gallery.offset;
|
||||
|
||||
if (offset === prevOffset) {
|
||||
// The page didn't change - bail
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* We need to wait until the next page of images is loaded before updating the selection, so we use the correct
|
||||
* page of images.
|
||||
*
|
||||
* The simplest way to do it would be to use `take` to wait for the next fulfilled action, but RTK-Q doesn't
|
||||
* dispatch an action on cache hits. This means the `take` will only return if the cache is empty. If the user
|
||||
* changes to a cached page - a common situation - the `take` will never resolve.
|
||||
*
|
||||
* So we need to take a two-step approach. First, check if we have data in the cache for the page of images. If
|
||||
* we have data cached, use it to update the selection. If we don't have data cached, wait for the next fulfilled
|
||||
* action, which updates the cache, then use the cache to update the selection.
|
||||
*/
|
||||
|
||||
// Check if we have data in the cache for the page of images
|
||||
const queryArgs = selectListImagesQueryArgs(getState());
|
||||
let { data } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
|
||||
|
||||
// No data yet - wait for the network request to complete
|
||||
if (!data) {
|
||||
const takeResult = await take(imagesApi.endpoints.listImages.matchFulfilled, 5000);
|
||||
if (!takeResult) {
|
||||
// The request didn't complete in time - bail
|
||||
return;
|
||||
}
|
||||
data = takeResult[0].payload;
|
||||
}
|
||||
|
||||
// We awaited a network request - state could have changed, get fresh state
|
||||
const state = getState();
|
||||
const { selection, imageToCompare } = state.gallery;
|
||||
const imageDTOs = data?.items;
|
||||
|
||||
if (!imageDTOs) {
|
||||
// The page didn't load - bail
|
||||
return;
|
||||
}
|
||||
|
||||
if (withHotkey === 'arrow') {
|
||||
// User changed pages by using the arrow keys - selection changes to the first or last image depending on direction
|
||||
if (offset < prevOffset) {
|
||||
// We've gone backwards
|
||||
const lastImage = imageDTOs[imageDTOs.length - 1];
|
||||
if (!selection.some((selectedImage) => selectedImage.image_name === lastImage?.image_name)) {
|
||||
dispatch(selectionChanged(lastImage ? [lastImage] : []));
|
||||
}
|
||||
} else {
|
||||
// We've gone forwards
|
||||
const firstImage = imageDTOs[0];
|
||||
if (!selection.some((selectedImage) => selectedImage.image_name === firstImage?.image_name)) {
|
||||
dispatch(selectionChanged(firstImage ? [firstImage] : []));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (withHotkey === 'alt+arrow') {
|
||||
// User changed pages by using the arrow keys with alt - comparison image changes to the first or last image depending on direction
|
||||
if (offset < prevOffset) {
|
||||
// We've gone backwards
|
||||
const lastImage = imageDTOs[imageDTOs.length - 1];
|
||||
if (lastImage && imageToCompare?.image_name !== lastImage.image_name) {
|
||||
dispatch(imageToCompareChanged(lastImage));
|
||||
}
|
||||
} else {
|
||||
// We've gone forwards
|
||||
const firstImage = imageDTOs[0];
|
||||
if (firstImage && imageToCompare?.image_name !== firstImage.image_name) {
|
||||
dispatch(imageToCompareChanged(firstImage));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -22,11 +22,10 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isImageFieldInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { clamp, forEach } from 'lodash-es';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { api } from 'services/api';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
state.nodes.present.nodes.forEach((node) => {
|
||||
@@ -118,32 +117,7 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
}
|
||||
|
||||
dispatch(isModalOpenChanged(false));
|
||||
|
||||
const state = getState();
|
||||
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
|
||||
|
||||
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
|
||||
const { image_name } = imageDTO;
|
||||
|
||||
const baseQueryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||
|
||||
const cachedImageDTOs = data ? imagesSelectors.selectAll(data) : [];
|
||||
|
||||
const deletedImageIndex = cachedImageDTOs.findIndex((i) => i.image_name === image_name);
|
||||
|
||||
const filteredImageDTOs = cachedImageDTOs.filter((i) => i.image_name !== image_name);
|
||||
|
||||
const newSelectedImageIndex = clamp(deletedImageIndex, 0, filteredImageDTOs.length - 1);
|
||||
|
||||
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
|
||||
|
||||
if (newSelectedImageDTO) {
|
||||
dispatch(imageSelected(newSelectedImageDTO));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
}
|
||||
|
||||
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||
if (imageUsage.isCanvasImage) {
|
||||
@@ -168,6 +142,20 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
if (wasImageDeleted) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
|
||||
}
|
||||
|
||||
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
|
||||
|
||||
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
|
||||
const baseQueryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||
|
||||
if (data && data.items) {
|
||||
const newlySelectedImage = data?.items.find((img) => img.image_name !== imageDTO?.image_name);
|
||||
dispatch(imageSelected(newlySelectedImage || null));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -188,10 +176,8 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
|
||||
const newSelectedImageDTO = data ? imagesSelectors.selectAll(data)[0] : undefined;
|
||||
|
||||
if (newSelectedImageDTO) {
|
||||
dispatch(imageSelected(newSelectedImageDTO));
|
||||
if (data && data.items[0]) {
|
||||
dispatch(imageSelected(data.items[0]));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
|
||||
@@ -15,7 +15,12 @@ import {
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageSelected,
|
||||
imageToCompareChanged,
|
||||
isImageViewerOpenChanged,
|
||||
selectionChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@@ -216,6 +221,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -233,6 +239,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
imageDTO,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -248,6 +255,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -261,6 +269,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
imageDTOs,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
},
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
ipaLayerImageChanged,
|
||||
rgLayerIPAdapterImageChanged,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -62,7 +63,8 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
);
|
||||
|
||||
// Attempt to get the board's name for the toast
|
||||
const { data } = boardsApi.endpoints.listAllBoards.select()(state);
|
||||
const queryArgs = selectListBoardsQueryArgs(state);
|
||||
const { data } = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
|
||||
|
||||
// Fall back to just the board id if we can't find the board for some reason
|
||||
const board = data?.find((b) => b.board_id === autoAddBoardId);
|
||||
|
||||
@@ -8,14 +8,14 @@ import {
|
||||
galleryViewChanged,
|
||||
imageSelected,
|
||||
isImageViewerOpenChanged,
|
||||
offsetChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
|
||||
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
|
||||
@@ -52,24 +52,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
|
||||
if (!imageDTO.is_intermediate) {
|
||||
/**
|
||||
* Cache updates for when an image result is received
|
||||
* - add it to the no_board/images
|
||||
*/
|
||||
|
||||
dispatch(
|
||||
imagesApi.util.updateQueryData(
|
||||
'listImages',
|
||||
{
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
categories: IMAGE_CATEGORIES,
|
||||
},
|
||||
(draft) => {
|
||||
imagesAdapter.addOne(draft, imageDTO);
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
// update the total images for the board
|
||||
dispatch(
|
||||
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
|
||||
@@ -78,7 +60,18 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(imagesApi.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
|
||||
dispatch(
|
||||
imagesApi.util.invalidateTags([
|
||||
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
|
||||
{
|
||||
type: 'ImageList',
|
||||
id: getListImagesUrl({
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
categories: getCategories(imageDTO),
|
||||
}),
|
||||
},
|
||||
])
|
||||
);
|
||||
|
||||
const { shouldAutoSwitch } = gallery;
|
||||
|
||||
@@ -98,6 +91,8 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(offsetChanged({ offset: 0 }));
|
||||
|
||||
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
|
||||
@@ -1,47 +1,37 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import type { MouseEvent, ReactElement } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = {
|
||||
const sx: SystemStyleObject = {
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
|
||||
},
|
||||
};
|
||||
|
||||
type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
|
||||
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
|
||||
tooltip: string;
|
||||
icon?: ReactElement;
|
||||
styleOverrides?: SystemStyleObject;
|
||||
};
|
||||
|
||||
const IAIDndImageIcon = (props: Props) => {
|
||||
const { onClick, tooltip, icon, styleOverrides } = props;
|
||||
|
||||
const sx = useMemo(
|
||||
() => ({
|
||||
position: 'absolute',
|
||||
top: 1,
|
||||
insetInlineEnd: 1,
|
||||
p: 0,
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
|
||||
},
|
||||
...styleOverrides,
|
||||
}),
|
||||
[styleOverrides]
|
||||
);
|
||||
const { onClick, tooltip, icon, ...rest } = props;
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
onClick={onClick}
|
||||
aria-label={tooltip}
|
||||
tooltip={tooltip}
|
||||
icon={icon}
|
||||
size="sm"
|
||||
variant="link"
|
||||
sx={sx}
|
||||
data-testid={tooltip}
|
||||
{...rest}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
/**
|
||||
* Comparator function for sorting dates in ascending order
|
||||
*/
|
||||
export const dateComparator = (a: string, b: string) => {
|
||||
const dateA = new Date(a);
|
||||
const dateB = new Date(b);
|
||||
|
||||
// sort in ascending order
|
||||
if (dateA > dateB) {
|
||||
return 1;
|
||||
}
|
||||
if (dateA < dateB) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
isModalOpenChanged,
|
||||
selectChangeBoardModalSlice,
|
||||
} from 'features/changeBoardModal/store/slice';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
@@ -20,7 +21,8 @@ const selectImagesToChange = createMemoizedSelector(
|
||||
const ChangeBoardModal = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [selectedBoard, setSelectedBoard] = useState<string | null>();
|
||||
const { data: boards, isFetching } = useListAllBoardsQuery();
|
||||
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
|
||||
const { data: boards, isFetching } = useListAllBoardsQuery(queryArgs);
|
||||
const isModalOpen = useAppSelector((s) => s.changeBoardModal.isModalOpen);
|
||||
const imagesToChange = useAppSelector(selectImagesToChange);
|
||||
const [addImagesToBoard] = useAddImagesToBoardMutation();
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
@@ -185,25 +184,25 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSaveControlImage}
|
||||
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.saveControlImage')}
|
||||
styleOverrides={saveControlImageStyleOverrides}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
{controlImage && (
|
||||
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSaveControlImage}
|
||||
icon={<PiFloppyDiskBold size={16} />}
|
||||
tooltip={t('controlnet.saveControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={<PiRulerBold size={16} />}
|
||||
tooltip={t('controlnet.setControlImageDimensions')}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
{pendingControlImages.includes(id) && (
|
||||
<Flex
|
||||
@@ -226,6 +225,3 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
||||
};
|
||||
|
||||
export default memo(ControlAdapterImagePreview);
|
||||
|
||||
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -160,7 +159,7 @@ export const ControlAdapterImagePreview = memo(
|
||||
onMouseEnter={handleMouseEnter}
|
||||
onMouseLeave={handleMouseLeave}
|
||||
position="relative"
|
||||
w="full"
|
||||
w={36}
|
||||
h={36}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
@@ -193,25 +192,27 @@ export const ControlAdapterImagePreview = memo(
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSaveControlImage}
|
||||
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.saveControlImage')}
|
||||
styleOverrides={saveControlImageStyleOverrides}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
{controlImage && (
|
||||
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSaveControlImage}
|
||||
icon={<PiFloppyDiskBold size={16} />}
|
||||
tooltip={t('controlnet.saveControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={<PiRulerBold size={16} />}
|
||||
tooltip={
|
||||
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
{controlAdapter.processorPendingBatchId !== null && (
|
||||
<Flex
|
||||
@@ -235,6 +236,3 @@ export const ControlAdapterImagePreview = memo(
|
||||
);
|
||||
|
||||
ControlAdapterImagePreview.displayName = 'ControlAdapterImagePreview';
|
||||
|
||||
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -82,7 +81,7 @@ export const IPAdapterImagePreview = memo(
|
||||
}, [handleResetControlImage, isConnected, isErrorControlImage]);
|
||||
|
||||
return (
|
||||
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
|
||||
<Flex position="relative" w={36} h={36} alignItems="center">
|
||||
<IAIDndImage
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
@@ -90,24 +89,25 @@ export const IPAdapterImagePreview = memo(
|
||||
postUploadAction={postUploadAction}
|
||||
/>
|
||||
|
||||
<>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
{controlImage && (
|
||||
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={<PiRulerBold size={16} />}
|
||||
tooltip={
|
||||
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
|
||||
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -79,31 +78,34 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
|
||||
}, [onReset, isConnected, isErrorControlImage]);
|
||||
|
||||
return (
|
||||
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
|
||||
<IAIDndImage
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={imageDTO}
|
||||
postUploadAction={postUploadAction}
|
||||
/>
|
||||
<Flex w="full" alignItems="center" justifyContent="center">
|
||||
<Flex position="relative" w={36} h={36} alignItems="center" justifyContent="center">
|
||||
<IAIDndImage
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={imageDTO}
|
||||
postUploadAction={postUploadAction}
|
||||
/>
|
||||
|
||||
<>
|
||||
<IAIDndImageIcon
|
||||
onClick={onReset}
|
||||
icon={imageDTO ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={onUseSize}
|
||||
icon={imageDTO ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={useSizeStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
{imageDTO && (
|
||||
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
|
||||
<IAIDndImageIcon
|
||||
onClick={onReset}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={onUseSize}
|
||||
icon={<PiRulerBold size={16} />}
|
||||
tooltip={
|
||||
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
InitialImagePreview.displayName = 'InitialImagePreview';
|
||||
|
||||
const useSizeStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
|
||||
@@ -11,25 +11,28 @@ const BoardAutoAddSelect = () => {
|
||||
const { t } = useTranslation();
|
||||
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
|
||||
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
|
||||
const { options, hasBoards } = useListAllBoardsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
const options: ComboboxOption[] = [
|
||||
{
|
||||
label: t('controlnet.none'),
|
||||
value: 'none',
|
||||
},
|
||||
].concat(
|
||||
(data ?? []).map(({ board_id, board_name }) => ({
|
||||
label: board_name,
|
||||
value: board_id,
|
||||
}))
|
||||
);
|
||||
return {
|
||||
options,
|
||||
hasBoards: options.length > 1,
|
||||
};
|
||||
},
|
||||
});
|
||||
const { options, hasBoards } = useListAllBoardsQuery(
|
||||
{},
|
||||
{
|
||||
selectFromResult: ({ data }) => {
|
||||
const options: ComboboxOption[] = [
|
||||
{
|
||||
label: t('controlnet.none'),
|
||||
value: 'none',
|
||||
},
|
||||
].concat(
|
||||
(data ?? []).map(({ board_id, board_name }) => ({
|
||||
label: board_name,
|
||||
value: board_id,
|
||||
}))
|
||||
);
|
||||
return {
|
||||
options,
|
||||
hasBoards: options.length > 1,
|
||||
};
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
||||
@@ -3,11 +3,12 @@ import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-librar
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { autoAddBoardIdChanged, selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiDownloadBold, PiPlusBold } from 'react-icons/pi';
|
||||
import { PiArchiveBold, PiArchiveFill, PiDownloadBold, PiPlusBold } from 'react-icons/pi';
|
||||
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
|
||||
import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
import type { BoardDTO } from 'services/api/types';
|
||||
@@ -15,52 +16,85 @@ import type { BoardDTO } from 'services/api/types';
|
||||
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
|
||||
|
||||
type Props = {
|
||||
board?: BoardDTO;
|
||||
board_id: BoardId;
|
||||
board: BoardDTO;
|
||||
children: ContextMenuProps<HTMLDivElement>['children'];
|
||||
setBoardToDelete?: (board?: BoardDTO) => void;
|
||||
setBoardToDelete: (board?: BoardDTO) => void;
|
||||
};
|
||||
|
||||
const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props) => {
|
||||
const BoardContextMenu = ({ board, setBoardToDelete, children }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
|
||||
const selectIsSelectedForAutoAdd = useMemo(
|
||||
() => createSelector(selectGallerySlice, (gallery) => board && board.board_id === gallery.autoAddBoardId),
|
||||
[board]
|
||||
() => createSelector(selectGallerySlice, (gallery) => board.board_id === gallery.autoAddBoardId),
|
||||
[board.board_id]
|
||||
);
|
||||
|
||||
const [updateBoard] = useUpdateBoardMutation();
|
||||
|
||||
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
|
||||
const boardName = useBoardName(board_id);
|
||||
const boardName = useBoardName(board.board_id);
|
||||
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
|
||||
|
||||
const [bulkDownload] = useBulkDownloadImagesMutation();
|
||||
|
||||
const handleSetAutoAdd = useCallback(() => {
|
||||
dispatch(autoAddBoardIdChanged(board_id));
|
||||
}, [board_id, dispatch]);
|
||||
dispatch(autoAddBoardIdChanged(board.board_id));
|
||||
}, [board.board_id, dispatch]);
|
||||
|
||||
const handleBulkDownload = useCallback(() => {
|
||||
bulkDownload({ image_names: [], board_id: board_id });
|
||||
}, [board_id, bulkDownload]);
|
||||
bulkDownload({ image_names: [], board_id: board.board_id });
|
||||
}, [board.board_id, bulkDownload]);
|
||||
|
||||
const handleArchive = useCallback(async () => {
|
||||
try {
|
||||
await updateBoard({
|
||||
board_id: board.board_id,
|
||||
changes: { archived: true },
|
||||
}).unwrap();
|
||||
} catch (error) {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: 'Unable to archive board',
|
||||
});
|
||||
}
|
||||
}, [board.board_id, updateBoard]);
|
||||
|
||||
const handleUnarchive = useCallback(() => {
|
||||
updateBoard({
|
||||
board_id: board.board_id,
|
||||
changes: { archived: false },
|
||||
});
|
||||
}, [board.board_id, updateBoard]);
|
||||
|
||||
const renderMenuFunc = useCallback(
|
||||
() => (
|
||||
<MenuList visibility="visible">
|
||||
<MenuGroup title={boardName}>
|
||||
<MenuItem
|
||||
icon={<PiPlusBold />}
|
||||
isDisabled={isSelectedForAutoAdd || autoAssignBoardOnClick}
|
||||
onClick={handleSetAutoAdd}
|
||||
>
|
||||
{t('boards.menuItemAutoAdd')}
|
||||
</MenuItem>
|
||||
{!autoAssignBoardOnClick && (
|
||||
<MenuItem icon={<PiPlusBold />} isDisabled={isSelectedForAutoAdd} onClick={handleSetAutoAdd}>
|
||||
{isSelectedForAutoAdd ? t('boards.selectedForAutoAdd') : t('boards.menuItemAutoAdd')}
|
||||
</MenuItem>
|
||||
)}
|
||||
{isBulkDownloadEnabled && (
|
||||
<MenuItem icon={<PiDownloadBold />} onClickCapture={handleBulkDownload}>
|
||||
{t('boards.downloadBoard')}
|
||||
</MenuItem>
|
||||
)}
|
||||
{board && <GalleryBoardContextMenuItems board={board} setBoardToDelete={setBoardToDelete} />}
|
||||
|
||||
{board.archived && (
|
||||
<MenuItem icon={<PiArchiveBold />} onClick={handleUnarchive}>
|
||||
{t('boards.unarchiveBoard')}
|
||||
</MenuItem>
|
||||
)}
|
||||
|
||||
{!board.archived && (
|
||||
<MenuItem icon={<PiArchiveFill />} onClick={handleArchive}>
|
||||
{t('boards.archiveBoard')}
|
||||
</MenuItem>
|
||||
)}
|
||||
|
||||
<GalleryBoardContextMenuItems board={board} setBoardToDelete={setBoardToDelete} />
|
||||
</MenuGroup>
|
||||
</MenuList>
|
||||
),
|
||||
@@ -74,6 +108,8 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
|
||||
isSelectedForAutoAdd,
|
||||
setBoardToDelete,
|
||||
t,
|
||||
handleArchive,
|
||||
handleUnarchive,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
|
||||
|
||||
type Props = {
|
||||
board_id: string;
|
||||
isArchived: boolean;
|
||||
};
|
||||
|
||||
export const BoardTotalsTooltip = ({ board_id, isArchived }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { imagesTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { assetsTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}${isArchived ? ` (${t('boards.archived')})` : ''}`;
|
||||
};
|
||||
@@ -2,6 +2,7 @@ import { Collapse, Flex, Grid, GridItem } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useState } from 'react';
|
||||
@@ -26,7 +27,8 @@ const BoardsList = (props: Props) => {
|
||||
const { isOpen } = props;
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
|
||||
const { data: boards } = useListAllBoardsQuery();
|
||||
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
|
||||
const { data: boards } = useListAllBoardsQuery(queryArgs);
|
||||
const filteredBoards = boardSearchText
|
||||
? boards?.filter((board) => board.board_name.toLowerCase().includes(boardSearchText.toLowerCase()))
|
||||
: boards;
|
||||
|
||||
@@ -8,15 +8,12 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||
import type { AddToBoardDropData } from 'features/dnd/types';
|
||||
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
|
||||
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
|
||||
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
|
||||
import { autoAddBoardIdChanged, boardIdSelected, selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImagesSquare } from 'react-icons/pi';
|
||||
import {
|
||||
useGetBoardAssetsTotalQuery,
|
||||
useGetBoardImagesTotalQuery,
|
||||
useUpdateBoardMutation,
|
||||
} from 'services/api/endpoints/boards';
|
||||
import { PiArchiveBold, PiImagesSquare } from 'react-icons/pi';
|
||||
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { BoardDTO } from 'services/api/types';
|
||||
|
||||
@@ -28,6 +25,14 @@ const editableInputStyles: SystemStyleObject = {
|
||||
},
|
||||
};
|
||||
|
||||
const ArchivedIcon = () => {
|
||||
return (
|
||||
<Box position="absolute" top={1} insetInlineEnd={2} p={0} minW={0}>
|
||||
<Icon as={PiArchiveBold} fill="base.300" filter="drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))" />
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
interface GalleryBoardProps {
|
||||
board: BoardDTO;
|
||||
isSelected: boolean;
|
||||
@@ -36,6 +41,7 @@ interface GalleryBoardProps {
|
||||
|
||||
const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
|
||||
const selectIsSelectedForAutoAdd = useMemo(
|
||||
() => createSelector(selectGallerySlice, (gallery) => board.board_id === gallery.autoAddBoardId),
|
||||
@@ -51,17 +57,6 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
setIsHovered(false);
|
||||
}, []);
|
||||
|
||||
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
|
||||
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
|
||||
const tooltip = useMemo(() => {
|
||||
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
|
||||
assetsTotal.total
|
||||
} asset${assetsTotal.total === 1 ? '' : 's'}`;
|
||||
}, [assetsTotal, imagesTotal]);
|
||||
|
||||
const { currentData: coverImage } = useGetImageDTOQuery(board.cover_image_name ?? skipToken);
|
||||
|
||||
const { board_name, board_id } = board;
|
||||
@@ -117,7 +112,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
const handleChange = useCallback((newBoardName: string) => {
|
||||
setLocalBoardName(newBoardName);
|
||||
}, []);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Box w="full" h="full" userSelect="none">
|
||||
<Flex
|
||||
@@ -130,9 +125,12 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
w="full"
|
||||
h="full"
|
||||
>
|
||||
<BoardContextMenu board={board} board_id={board_id} setBoardToDelete={setBoardToDelete}>
|
||||
<BoardContextMenu board={board} setBoardToDelete={setBoardToDelete}>
|
||||
{(ref) => (
|
||||
<Tooltip label={tooltip} openDelay={1000}>
|
||||
<Tooltip
|
||||
label={<BoardTotalsTooltip board_id={board.board_id} isArchived={Boolean(board.archived)} />}
|
||||
openDelay={1000}
|
||||
>
|
||||
<Flex
|
||||
ref={ref}
|
||||
onClick={handleSelectBoard}
|
||||
@@ -145,6 +143,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
cursor="pointer"
|
||||
bg="base.800"
|
||||
>
|
||||
{board.archived && <ArchivedIcon />}
|
||||
{coverImage?.thumbnail_url ? (
|
||||
<Image
|
||||
src={coverImage?.thumbnail_url}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.