mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-20 03:18:05 -05:00
Compare commits
75 Commits
test/test-
...
fix/instal
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c37c48b8c | ||
|
|
0cfe2ccd9d | ||
|
|
b6f356f067 | ||
|
|
a4f1db7c02 | ||
|
|
21206bafcf | ||
|
|
a047bad391 | ||
|
|
909afc266e | ||
|
|
4039dd148d | ||
|
|
ea0f8b8791 | ||
|
|
f412582d60 | ||
|
|
c5672adb6b | ||
|
|
0e5c3a641a | ||
|
|
9015e72e1e | ||
|
|
6b05d27c7a | ||
|
|
19d0673085 | ||
|
|
048b4fe7e8 | ||
|
|
e8b83fecff | ||
|
|
8883ecb2bf | ||
|
|
2f97f1d6d5 | ||
|
|
73d6cc824b | ||
|
|
acc0a29dca | ||
|
|
38c1436f02 | ||
|
|
efbdb75568 | ||
|
|
8929495aeb | ||
|
|
428f0b265f | ||
|
|
7daee41ad2 | ||
|
|
7cdd7b6ad7 | ||
|
|
bc64cde6f9 | ||
|
|
4465f97cdf | ||
|
|
fface2cda7 | ||
|
|
7fcb8959fb | ||
|
|
dcf0dc4274 | ||
|
|
bb52861896 | ||
|
|
f2d26a3a3c | ||
|
|
04d8f2dfea | ||
|
|
355d4cf4e2 | ||
|
|
a3a828779a | ||
|
|
8c71ff37ae | ||
|
|
ddb65e6034 | ||
|
|
3a0ec635c9 | ||
|
|
8afe517204 | ||
|
|
5eaea9dd64 | ||
|
|
ef8dcf5fae | ||
|
|
024a156114 | ||
|
|
7ea2a135f1 | ||
|
|
af2264b6eb | ||
|
|
41bf9ec4a3 | ||
|
|
2b36565e9e | ||
|
|
f2c3b7c317 | ||
|
|
67751a01ab | ||
|
|
cb8cdefd59 | ||
|
|
f1c846ba5c | ||
|
|
3a6ba236f5 | ||
|
|
bd56e9bc81 | ||
|
|
b55fc2935e | ||
|
|
0544917161 | ||
|
|
1161dfe055 | ||
|
|
433f347d7e | ||
|
|
33a412a24f | ||
|
|
9316534d97 | ||
|
|
fdaa661245 | ||
|
|
f1c195afb7 | ||
|
|
3b363d0258 | ||
|
|
36e0faea6b | ||
|
|
927f8a66e6 | ||
|
|
eebc0e7315 | ||
|
|
6b173cc66f | ||
|
|
b4732a7308 | ||
|
|
344a56327a | ||
|
|
ce22c0fbaa | ||
|
|
55f8865524 | ||
|
|
2d051559d1 | ||
|
|
db9cef0092 | ||
|
|
72c34aea75 | ||
|
|
edeea5237b |
2
.github/workflows/style-checks.yml
vendored
2
.github/workflows/style-checks.yml
vendored
@@ -6,7 +6,7 @@ on:
|
||||
branches: main
|
||||
|
||||
jobs:
|
||||
black:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
|
||||
_For Windows/Linux with an NVIDIA GPU:_
|
||||
|
||||
```terminal
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
_For Linux with an AMD GPU:_
|
||||
@@ -175,7 +175,7 @@ the command `npm install -g yarn` if needed)
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
_For Macintoshes, either Intel or M1/M2:_
|
||||
_For Macintoshes, either Intel or M1/M2/M3:_
|
||||
|
||||
```sh
|
||||
pip install InvokeAI --use-pep517
|
||||
|
||||
1213
docs/contributing/MODEL_MANAGER.md
Normal file
1213
docs/contributing/MODEL_MANAGER.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -179,7 +179,7 @@ experimental versions later.
|
||||
you will have the choice of CUDA (NVidia cards), ROCm (AMD cards),
|
||||
or CPU (no graphics acceleration). On Windows, you'll have the
|
||||
choice of CUDA vs CPU, and on Macs you'll be offered CPU only. When
|
||||
you select CPU on M1 or M2 Macintoshes, you will get MPS-based
|
||||
you select CPU on M1/M2/M3 Macintoshes, you will get MPS-based
|
||||
graphics acceleration without installing additional drivers. If you
|
||||
are unsure what GPU you are using, you can ask the installer to
|
||||
guess.
|
||||
@@ -471,7 +471,7 @@ Then type the following commands:
|
||||
|
||||
=== "NVIDIA System"
|
||||
```bash
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
pip install xformers
|
||||
```
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ manager, please follow these steps:
|
||||
=== "CUDA (NVidia)"
|
||||
|
||||
```bash
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
@@ -327,7 +327,7 @@ installation protocol (important!)
|
||||
|
||||
=== "CUDA (NVidia)"
|
||||
```bash
|
||||
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
@@ -375,7 +375,7 @@ you can do so using this unsupported recipe:
|
||||
mkdir ~/invokeai
|
||||
conda create -n invokeai python=3.10
|
||||
conda activate invokeai
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
invokeai-configure --root ~/invokeai
|
||||
invokeai --root ~/invokeai --web
|
||||
```
|
||||
|
||||
@@ -85,7 +85,7 @@ You can find which version you should download from [this link](https://docs.nvi
|
||||
|
||||
When installing torch and torchvision manually with `pip`, remember to provide
|
||||
the argument `--extra-index-url
|
||||
https://download.pytorch.org/whl/cu118` as described in the [Manual
|
||||
https://download.pytorch.org/whl/cu121` as described in the [Manual
|
||||
Installation Guide](020_INSTALL_MANUAL.md).
|
||||
|
||||
## :simple-amd: ROCm
|
||||
|
||||
@@ -30,7 +30,7 @@ methodology for details on why running applications in such a stateless fashion
|
||||
The container is configured for CUDA by default, but can be built to support AMD GPUs
|
||||
by setting the `GPU_DRIVER=rocm` environment variable at Docker image build time.
|
||||
|
||||
Developers on Apple silicon (M1/M2): You
|
||||
Developers on Apple silicon (M1/M2/M3): You
|
||||
[can't access your GPU cores from Docker containers](https://github.com/pytorch/pytorch/issues/81224)
|
||||
and performance is reduced compared with running it directly on macOS but for
|
||||
development purposes it's fine. Once you're done with development tasks on your
|
||||
|
||||
@@ -28,7 +28,7 @@ command line, then just be sure to activate it's virtual environment.
|
||||
Then run the following three commands:
|
||||
|
||||
```sh
|
||||
pip install xformers~=0.0.19
|
||||
pip install xformers~=0.0.22
|
||||
pip install triton # WON'T WORK ON WINDOWS
|
||||
python -m xformers.info output
|
||||
```
|
||||
@@ -42,7 +42,7 @@ If all goes well, you'll see a report like the
|
||||
following:
|
||||
|
||||
```sh
|
||||
xFormers 0.0.20
|
||||
xFormers 0.0.22
|
||||
memory_efficient_attention.cutlassF: available
|
||||
memory_efficient_attention.cutlassB: available
|
||||
memory_efficient_attention.flshattF: available
|
||||
@@ -59,14 +59,14 @@ swiglu.gemm_fused_operand_sum: available
|
||||
swiglu.fused.p.cpp: available
|
||||
is_triton_available: True
|
||||
is_functorch_available: False
|
||||
pytorch.version: 2.0.1+cu118
|
||||
pytorch.version: 2.1.0+cu121
|
||||
pytorch.cuda: available
|
||||
gpu.compute_capability: 8.9
|
||||
gpu.name: NVIDIA GeForce RTX 4070
|
||||
build.info: available
|
||||
build.cuda_version: 1108
|
||||
build.python_version: 3.10.11
|
||||
build.torch_version: 2.0.1+cu118
|
||||
build.torch_version: 2.1.0+cu121
|
||||
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
|
||||
build.env.XFORMERS_BUILD_TYPE: Release
|
||||
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
|
||||
@@ -92,33 +92,22 @@ installed from source. These instructions were written for a system
|
||||
running Ubuntu 22.04, but other Linux distributions should be able to
|
||||
adapt this recipe.
|
||||
|
||||
#### 1. Install CUDA Toolkit 11.8
|
||||
#### 1. Install CUDA Toolkit 12.1
|
||||
|
||||
You will need the CUDA developer's toolkit in order to compile and
|
||||
install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
|
||||
package.** It is out of date and will cause conflicts among the NVIDIA
|
||||
driver and binaries. Instead install the CUDA Toolkit package provided
|
||||
by NVIDIA itself. Go to [CUDA Toolkit 11.8
|
||||
Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
|
||||
by NVIDIA itself. Go to [CUDA Toolkit 12.1
|
||||
Downloads](https://developer.nvidia.com/cuda-12-1-0-download-archive)
|
||||
and use the target selection wizard to choose your platform and Linux
|
||||
distribution. Select an installer type of "runfile (local)" at the
|
||||
last step.
|
||||
|
||||
This will provide you with a recipe for downloading and running a
|
||||
install shell script that will install the toolkit and drivers. For
|
||||
example, the install script recipe for Ubuntu 22.04 running on a
|
||||
x86_64 system is:
|
||||
install shell script that will install the toolkit and drivers.
|
||||
|
||||
```
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
|
||||
sudo sh cuda_11.8.0_520.61.05_linux.run
|
||||
```
|
||||
|
||||
Rather than cut-and-paste this example, We recommend that you walk
|
||||
through the toolkit wizard in order to get the most up to date
|
||||
installer for your system.
|
||||
|
||||
#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
|
||||
#### 2. Confirm/Install pyTorch 2.1.0 with CUDA 12.1 support
|
||||
|
||||
If you are using InvokeAI 3.0.2 or higher, these will already be
|
||||
installed. If not, you can check whether you have the needed libraries
|
||||
@@ -133,7 +122,7 @@ Then run the command:
|
||||
python -c 'exec("import torch\nprint(torch.__version__)")'
|
||||
```
|
||||
|
||||
If it prints __1.13.1+cu118__ you're good. If not, you can install the
|
||||
If it prints __2.1.0+cu121__ you're good. If not, you can install the
|
||||
most up to date libraries with this command:
|
||||
|
||||
```sh
|
||||
|
||||
@@ -244,7 +244,7 @@ class InvokeAiInstance:
|
||||
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
|
||||
"urllib3~=1.26.0",
|
||||
"requests~=2.28.0",
|
||||
"torch~=2.0.0",
|
||||
"torch~=2.1.0",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchvision>=0.14.1",
|
||||
"--force-reinstall",
|
||||
@@ -460,10 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
if device == "cuda":
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
if device == "cuda_and_dml":
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-directml]"
|
||||
|
||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||
|
||||
@@ -24,6 +24,7 @@ from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
@@ -85,6 +86,7 @@ class ApiDependencies:
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
processor = DefaultInvocationProcessor()
|
||||
@@ -111,6 +113,7 @@ class ApiDependencies:
|
||||
latents=latents,
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
|
||||
164
invokeai/app/api/routers/model_records.py
Normal file
164
invokeai/app/api/routers/model_records.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
|
||||
models: list[AnyModelConfig]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
|
||||
else:
|
||||
found_models.extend(record_store.search_by_attr(model_type=model_type))
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
},
|
||||
)
|
||||
async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
try:
|
||||
return record_store.get_model(key)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
try:
|
||||
model_response = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return model_response
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store.del_model(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
return record_store.get_model(config.key)
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
|
||||
|
||||
|
||||
import pathlib
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
board_images,
|
||||
boards,
|
||||
images,
|
||||
model_records,
|
||||
models,
|
||||
session_queue,
|
||||
sessions,
|
||||
@@ -106,6 +107,7 @@ app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
app.include_router(model_records.model_records_router, prefix="/api")
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
@@ -112,10 +112,11 @@ class CompelInvocation(BaseInvocation):
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
@@ -234,10 +235,11 @@ class SDXLPromptInvocationBase:
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
@@ -49,6 +50,7 @@ class InvocationServices:
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
model_records: "ModelRecordServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
@@ -76,6 +78,7 @@ class InvocationServices:
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
@@ -101,6 +104,7 @@ class InvocationServices:
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
||||
8
invokeai/app/services/model_records/__init__.py
Normal file
8
invokeai/app/services/model_records/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Init file for model record services."""
|
||||
from .model_records_base import ( # noqa F401
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordServiceBase,
|
||||
UnknownModelException,
|
||||
)
|
||||
from .model_records_sql import ModelRecordServiceSQL # noqa F401
|
||||
169
invokeai/app/services/model_records/model_records_base.py
Normal file
169
invokeai/app/services/model_records/model_records_base.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Abstract base class for storing and retrieving model configuration records.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||
|
||||
# should match the InvokeAI version when this is first released.
|
||||
CONFIG_FILE_VERSION = "3.2.0"
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
"""Raised on an attempt to add a model with the same key twice."""
|
||||
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
"""Raised when an invalid model is detected."""
|
||||
|
||||
|
||||
class UnknownModelException(Exception):
|
||||
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
|
||||
|
||||
|
||||
class ConfigFileVersionMismatchException(Exception):
|
||||
"""Raised on an attempt to open a config with an incompatible version."""
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def version(self) -> str:
|
||||
"""Return the config file/database schema version."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
:param key: Unique key for the model
|
||||
:param config: Model configuration record, either a dict with the
|
||||
required fields or a ModelConfigBase instance.
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
Delete a model.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the configuration for the indicated model.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_path(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
) -> List[AnyModelConfig]:
|
||||
"""Return the model(s) having the indicated path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_hash(
|
||||
self,
|
||||
hash: str,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""Return the model(s) having the indicated original hash."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_attr(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
pass
|
||||
|
||||
def all_models(self) -> List[AnyModelConfig]:
|
||||
"""Return all the model configs in the database."""
|
||||
return self.search_by_attr()
|
||||
|
||||
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
|
||||
"""
|
||||
Return information about a single model using its name, base type and model type.
|
||||
|
||||
If there are more than one model that match, raises a DuplicateModelException.
|
||||
If no model matches, raises an UnknownModelException
|
||||
"""
|
||||
model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
if len(model_configs) > 1:
|
||||
raise DuplicateModelException(
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
)
|
||||
if len(model_configs) == 0:
|
||||
raise UnknownModelException(
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
)
|
||||
return model_configs[0]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
In some implementations, renaming the model may involve changing where
|
||||
it is stored on the filesystem. So this is broken out.
|
||||
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
config = self.get_model(key)
|
||||
config.name = new_name
|
||||
return self.update_model(key, config)
|
||||
397
invokeai/app/services/model_records/model_records_sql.py
Normal file
397
invokeai/app/services/model_records/model_records_sql.py
Normal file
@@ -0,0 +1,397 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Implementation of the ModelRecordServiceBase API
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.backend.model_manager import ModelConfigStoreSQL
|
||||
store = ModelConfigStoreSQL(sqlite_db)
|
||||
config = dict(
|
||||
path='/tmp/pokemon.bin',
|
||||
name='old name',
|
||||
base_model='sd-1',
|
||||
type='embedding',
|
||||
format='embedding_file',
|
||||
)
|
||||
|
||||
# adding - the key becomes the model's "key" field
|
||||
store.add_model('key1', config)
|
||||
|
||||
# updating
|
||||
config.name='new name'
|
||||
store.update_model('key1', config)
|
||||
|
||||
# checking for existence
|
||||
if store.exists('key1'):
|
||||
print("yes")
|
||||
|
||||
# fetching config
|
||||
new_config = store.get_model('key1')
|
||||
print(new_config.name, new_config.base)
|
||||
assert new_config.key == 'key1'
|
||||
|
||||
# deleting
|
||||
store.del_model('key1')
|
||||
|
||||
# searching
|
||||
configs = store.search_by_path(path='/tmp/pokemon.bin')
|
||||
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
|
||||
configs = store.search_by_attr(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
from ..shared.sqlite import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
CONFIG_FILE_VERSION,
|
||||
DuplicateModelException,
|
||||
ModelRecordServiceBase,
|
||||
UnknownModelException,
|
||||
)
|
||||
|
||||
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
_db: SqliteDatabase
|
||||
_cursor: sqlite3.Cursor
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
with self._db.lock:
|
||||
# Enable foreign keys
|
||||
self._db.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._db.conn.commit()
|
||||
assert (
|
||||
str(self.version) == CONFIG_FILE_VERSION
|
||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Create sqlite3 tables."""
|
||||
# model_config table breaks out the fields that are common to all config objects
|
||||
# and puts class-specific ones in a serialized json object
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# metadata table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_manager_metadata (
|
||||
metadata_key TEXT NOT NULL PRIMARY KEY,
|
||||
metadata_value TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
|
||||
AFTER UPDATE
|
||||
ON model_config FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add indexes for searchable fields
|
||||
for stmt in [
|
||||
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
|
||||
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
|
||||
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
|
||||
]:
|
||||
self._cursor.execute(stmt)
|
||||
|
||||
# Add our version to the metadata table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE into model_manager_metadata (
|
||||
metadata_key,
|
||||
metadata_value
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
("version", CONFIG_FILE_VERSION),
|
||||
)
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
:param key: Unique key for the model
|
||||
:param config: Model configuration record, either a dict with the
|
||||
required fields or a ModelConfigBase instance.
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
base,
|
||||
type,
|
||||
name,
|
||||
path,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?,?,?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.base,
|
||||
record.type,
|
||||
record.name,
|
||||
record.path,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return the version of the database schema."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata_value FROM model_manager_metadata
|
||||
WHERE metadata_key=?;
|
||||
""",
|
||||
("version",),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise KeyError("Models database does not have metadata key 'version'")
|
||||
return rows[0]
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
Delete a model.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
SET base=?,
|
||||
type=?,
|
||||
name=?,
|
||||
path=?,
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(record.base, record.type, record.name, record.path, json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
def get_model(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the ModelConfigBase instance for the indicated model.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]))
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
count = 0
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
return count > 0
|
||||
|
||||
def search_by_attr(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
results = []
|
||||
where_clause = []
|
||||
bindings = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
if base_model:
|
||||
where_clause.append("base=?")
|
||||
bindings.append(base_model)
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config FROM model_config
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
|
||||
"""Return models with the indicated path."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE model_path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
|
||||
"""Return models with the indicated original_hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE original_hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
return results
|
||||
323
invokeai/backend/model_manager/config.py
Normal file
323
invokeai/backend/model_manager/config.py
Normal file
@@ -0,0 +1,323 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Configuration definitions for image generation models.
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.backend.model_manager import ModelConfigFactory
|
||||
raw = dict(path='models/sd-1/main/foo.ckpt',
|
||||
name='foo',
|
||||
base='sd-1',
|
||||
type='main',
|
||||
config='configs/stable-diffusion/v1-inference.yaml',
|
||||
variant='normal',
|
||||
format='checkpoint'
|
||||
)
|
||||
config = ModelConfigFactory.make_config(raw)
|
||||
print(config.name)
|
||||
|
||||
Validation errors will raise an InvalidModelConfigException error.
|
||||
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
"""Base model type."""
|
||||
|
||||
Any = "any"
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
"""Model type."""
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class ModelFormat(str, Enum):
|
||||
"""Storage format of model."""
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
Lycoris = "lycoris"
|
||||
Onnx = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
InvokeAI = "invokeai"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
"""Scheduler prediction type."""
|
||||
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
path: str
|
||||
name: str
|
||||
base: BaseModelType
|
||||
type: ModelType
|
||||
format: ModelFormat
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
original_hash: Optional[str] = Field(
|
||||
description="original fasthash of model contents", default=None
|
||||
) # this is assigned at install time and will not change
|
||||
current_hash: Optional[str] = Field(
|
||||
description="current fasthash of model contents", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(default=None)
|
||||
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
validate_assignment=True,
|
||||
)
|
||||
|
||||
def update(self, attributes: dict):
|
||||
"""Update the object with fields in dict."""
|
||||
for key, value in attributes.items():
|
||||
setattr(self, key, value) # may raise a validation error
|
||||
|
||||
|
||||
class _CheckpointConfig(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config: str = Field(description="path to the checkpoint model config file")
|
||||
|
||||
|
||||
class _DiffusersConfig(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
||||
|
||||
|
||||
class VaeCheckpointConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||
|
||||
|
||||
class _MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
|
||||
vae: Optional[str] = Field(default=None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
ztsnr_training: bool = False
|
||||
|
||||
|
||||
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
# Note that we do not need prediction_type or upcast_attention here
|
||||
# because they are provided in the checkpoint's own config file.
|
||||
|
||||
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD1Config(_MainConfig):
|
||||
"""Model config for ONNX format models based on sd-1."""
|
||||
|
||||
type: Literal[ModelType.ONNX] = ModelType.ONNX
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD2Config(_MainConfig):
|
||||
"""Model config for ONNX format models based on sd-2."""
|
||||
|
||||
type: Literal[ModelType.ONNX] = ModelType.ONNX
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
# No yaml config file for ONNX, so these are part of config
|
||||
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
|
||||
upcast_attention: bool = True
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
|
||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ClipVision."""
|
||||
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
class T2IConfig(ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
|
||||
_ControlNetConfig = Annotated[
|
||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||
|
||||
AnyModelConfig = Union[
|
||||
_MainModelConfig,
|
||||
_ONNXConfig,
|
||||
_VaeConfig,
|
||||
_ControlNetConfig,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
IPAdapterConfig,
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
# IMPLEMENTATION NOTE:
|
||||
# The preferred alternative to the above is a discriminated Union as shown
|
||||
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
|
||||
# This is a known issue. Please see:
|
||||
# https://github.com/tiangolo/fastapi/discussions/9761 and
|
||||
# https://github.com/tiangolo/fastapi/discussions/9287
|
||||
# AnyModelConfig = Annotated[
|
||||
# Union[
|
||||
# _MainModelConfig,
|
||||
# _ONNXConfig,
|
||||
# _VaeConfig,
|
||||
# _ControlNetConfig,
|
||||
# LoRAConfig,
|
||||
# TextualInversionConfig,
|
||||
# IPAdapterConfig,
|
||||
# CLIPVisionDiffusersConfig,
|
||||
# T2IConfig,
|
||||
# ],
|
||||
# Field(discriminator="type"),
|
||||
# ]
|
||||
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls,
|
||||
model_data: Union[dict, AnyModelConfig],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type] = None,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Return the appropriate config object from raw dict values.
|
||||
|
||||
:param model_data: A raw dict corresponding the obect fields to be
|
||||
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
|
||||
object, which will be passed through unchanged.
|
||||
:param dest_class: The config class to be returned. If not provided, will
|
||||
be selected automatically.
|
||||
"""
|
||||
if isinstance(model_data, ModelConfigBase):
|
||||
model = model_data
|
||||
elif dest_class:
|
||||
model = dest_class.validate_python(model_data)
|
||||
else:
|
||||
model = AnyModelConfigValidator.validate_python(model_data)
|
||||
if key:
|
||||
model.key = key
|
||||
return model
|
||||
66
invokeai/backend/model_manager/hash.py
Normal file
66
invokeai/backend/model_manager/hash.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Fast hashing of diffusers and checkpoint-style models.
|
||||
|
||||
Usage:
|
||||
from invokeai.backend.model_managre.model_hash import FastModelHash
|
||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||
'a8e693a126ea5b831c96064dc569956f'
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
from imohash import hashfile
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
|
||||
@classmethod
|
||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Return hexdigest string for model located at model_location.
|
||||
|
||||
:param model_location: Path to the model
|
||||
"""
|
||||
model_location = Path(model_location)
|
||||
if model_location.is_file():
|
||||
return cls._hash_file(model_location)
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||
|
||||
@classmethod
|
||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Fasthash a single file and return its hexdigest.
|
||||
|
||||
:param model_location: Path to the model file
|
||||
"""
|
||||
# we return md5 hash of the filehash to make it shorter
|
||||
# cryptographic security not needed here
|
||||
return hashlib.md5(hashfile(model_location)).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
components: Dict[str, str] = {}
|
||||
|
||||
for root, _dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
continue
|
||||
path = (Path(root) / file).as_posix()
|
||||
fast_hash = cls._hash_file(path)
|
||||
components.update({path: fast_hash})
|
||||
|
||||
# hash all the model hashes together, using alphabetic file order
|
||||
md5 = hashlib.md5()
|
||||
for _path, fast_hash in sorted(components.items()):
|
||||
md5.update(fast_hash.encode("utf-8"))
|
||||
return md5.hexdigest()
|
||||
93
invokeai/backend/model_manager/migrate_to_db.py
Normal file
93
invokeai/backend/model_manager/migrate_to_db.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||
|
||||
from hashlib import sha1
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordServiceSQL,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
class MigrateModelYamlToDb:
|
||||
"""
|
||||
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
|
||||
|
||||
The class has one externally useful method, migrate(), which scans the
|
||||
currently models.yaml file and imports all its entries into invokeai.db.
|
||||
|
||||
Use this way:
|
||||
|
||||
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
|
||||
MigrateModelYamlToDb().migrate()
|
||||
|
||||
"""
|
||||
|
||||
config: InvokeAIAppConfig
|
||||
logger: InvokeAILogger
|
||||
|
||||
def __init__(self):
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.config.parse_args()
|
||||
self.logger = InvokeAILogger.get_logger()
|
||||
|
||||
def get_db(self) -> ModelRecordServiceSQL:
|
||||
"""Fetch the sqlite3 database for this installation."""
|
||||
db = SqliteDatabase(self.config, self.logger)
|
||||
return ModelRecordServiceSQL(db)
|
||||
|
||||
def get_yaml(self) -> DictConfig:
|
||||
"""Fetch the models.yaml DictConfig for this installation."""
|
||||
yaml_path = self.config.model_conf_path
|
||||
return OmegaConf.load(yaml_path)
|
||||
|
||||
def migrate(self):
|
||||
"""Do the migration from models.yaml to invokeai.db."""
|
||||
db = self.get_db()
|
||||
yaml = self.get_yaml()
|
||||
|
||||
for model_key, stanza in yaml.items():
|
||||
if model_key == "__metadata__":
|
||||
assert (
|
||||
stanza["version"] == "3.0.0"
|
||||
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
||||
continue
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
|
||||
new_config = ModelsValidator.validate_python(stanza)
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
try:
|
||||
db.add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
|
||||
|
||||
def main():
|
||||
MigrateModelYamlToDb().migrate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -748,7 +748,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [
|
||||
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)
|
||||
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=False)
|
||||
]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
|
||||
@@ -5,6 +5,7 @@ import math
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from collections import abc
|
||||
from inspect import isfunction
|
||||
from pathlib import Path
|
||||
@@ -14,8 +15,10 @@ from threading import Thread
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from diffusers import logging as diffusers_logging
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from tqdm import tqdm
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
@@ -379,3 +382,21 @@ class Chdir(object):
|
||||
|
||||
def __exit__(self, *args):
|
||||
os.chdir(self.original)
|
||||
|
||||
|
||||
class SilenceWarnings(object):
|
||||
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
|
||||
|
||||
def __enter__(self):
|
||||
"""Set verbosity to error."""
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||
transformers_logging.set_verbosity_error()
|
||||
diffusers_logging.set_verbosity_error()
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
"""Restore logger verbosity to state before context was entered."""
|
||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||
warnings.simplefilter("default")
|
||||
|
||||
@@ -90,6 +90,14 @@ def get_extras():
|
||||
pass
|
||||
return extras
|
||||
|
||||
def get_extra_index() -> str:
|
||||
# parsed_version.local for torch is the platform + version, eg 'cu121' or 'rocm5.6'
|
||||
local = pkg_resources.get_distribution("torch").parsed_version.local
|
||||
if local and 'cu' in local:
|
||||
return "--extra-index-url https://download.pytorch.org/whl/cu121"
|
||||
if local and 'rocm' in local:
|
||||
return "--extra-index-url https://download.pytorch.org/whl/rocm5.6"
|
||||
return ""
|
||||
|
||||
def main():
|
||||
versions = get_versions()
|
||||
@@ -122,14 +130,15 @@ def main():
|
||||
branch = Prompt.ask("Enter an InvokeAI branch name")
|
||||
|
||||
extras = get_extras()
|
||||
extra_index_url = get_extra_index()
|
||||
|
||||
print(f":crossed_fingers: Upgrading to [yellow]{tag or release or branch}[/yellow]")
|
||||
if release:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade {extra_index_url}'
|
||||
elif tag:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade {extra_index_url}'
|
||||
else:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade {extra_index_url}'
|
||||
print("")
|
||||
print("")
|
||||
if os.system(cmd) == 0:
|
||||
|
||||
@@ -24,6 +24,7 @@ module.exports = {
|
||||
root: true,
|
||||
rules: {
|
||||
curly: 'error',
|
||||
'react/jsx-no-bind': ['error', { allowBind: true }],
|
||||
'react/jsx-curly-brace-presence': [
|
||||
'error',
|
||||
{ props: 'never', children: 'never' },
|
||||
|
||||
171
invokeai/frontend/web/dist/assets/App-d620b60d.js
vendored
Normal file
171
invokeai/frontend/web/dist/assets/App-d620b60d.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
invokeai/frontend/web/dist/assets/MantineProvider-17a58e64.js
vendored
Normal file
1
invokeai/frontend/web/dist/assets/MantineProvider-17a58e64.js
vendored
Normal file
File diff suppressed because one or more lines are too long
280
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-58d6b3b6.js
vendored
Normal file
280
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-58d6b3b6.js
vendored
Normal file
@@ -0,0 +1,280 @@
|
||||
import{w as s,ia as T,v as l,a2 as I,ib as R,ae as V,ic as z,id as j,ie as D,ig as F,ih as G,ii as W,ij as K,aG as H,ik as U,il as Y}from"./index-54a1ea80.js";import{M as Z}from"./MantineProvider-17a58e64.js";var P=String.raw,E=P`
|
||||
:root,
|
||||
:host {
|
||||
--chakra-vh: 100vh;
|
||||
}
|
||||
|
||||
@supports (height: -webkit-fill-available) {
|
||||
:root,
|
||||
:host {
|
||||
--chakra-vh: -webkit-fill-available;
|
||||
}
|
||||
}
|
||||
|
||||
@supports (height: -moz-fill-available) {
|
||||
:root,
|
||||
:host {
|
||||
--chakra-vh: -moz-fill-available;
|
||||
}
|
||||
}
|
||||
|
||||
@supports (height: 100dvh) {
|
||||
:root,
|
||||
:host {
|
||||
--chakra-vh: 100dvh;
|
||||
}
|
||||
}
|
||||
`,B=()=>s.jsx(T,{styles:E}),J=({scope:e=""})=>s.jsx(T,{styles:P`
|
||||
html {
|
||||
line-height: 1.5;
|
||||
-webkit-text-size-adjust: 100%;
|
||||
font-family: system-ui, sans-serif;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
text-rendering: optimizeLegibility;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
touch-action: manipulation;
|
||||
}
|
||||
|
||||
body {
|
||||
position: relative;
|
||||
min-height: 100%;
|
||||
margin: 0;
|
||||
font-feature-settings: "kern";
|
||||
}
|
||||
|
||||
${e} :where(*, *::before, *::after) {
|
||||
border-width: 0;
|
||||
border-style: solid;
|
||||
box-sizing: border-box;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
|
||||
main {
|
||||
display: block;
|
||||
}
|
||||
|
||||
${e} hr {
|
||||
border-top-width: 1px;
|
||||
box-sizing: content-box;
|
||||
height: 0;
|
||||
overflow: visible;
|
||||
}
|
||||
|
||||
${e} :where(pre, code, kbd,samp) {
|
||||
font-family: SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
font-size: 1em;
|
||||
}
|
||||
|
||||
${e} a {
|
||||
background-color: transparent;
|
||||
color: inherit;
|
||||
text-decoration: inherit;
|
||||
}
|
||||
|
||||
${e} abbr[title] {
|
||||
border-bottom: none;
|
||||
text-decoration: underline;
|
||||
-webkit-text-decoration: underline dotted;
|
||||
text-decoration: underline dotted;
|
||||
}
|
||||
|
||||
${e} :where(b, strong) {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
${e} small {
|
||||
font-size: 80%;
|
||||
}
|
||||
|
||||
${e} :where(sub,sup) {
|
||||
font-size: 75%;
|
||||
line-height: 0;
|
||||
position: relative;
|
||||
vertical-align: baseline;
|
||||
}
|
||||
|
||||
${e} sub {
|
||||
bottom: -0.25em;
|
||||
}
|
||||
|
||||
${e} sup {
|
||||
top: -0.5em;
|
||||
}
|
||||
|
||||
${e} img {
|
||||
border-style: none;
|
||||
}
|
||||
|
||||
${e} :where(button, input, optgroup, select, textarea) {
|
||||
font-family: inherit;
|
||||
font-size: 100%;
|
||||
line-height: 1.15;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
${e} :where(button, input) {
|
||||
overflow: visible;
|
||||
}
|
||||
|
||||
${e} :where(button, select) {
|
||||
text-transform: none;
|
||||
}
|
||||
|
||||
${e} :where(
|
||||
button::-moz-focus-inner,
|
||||
[type="button"]::-moz-focus-inner,
|
||||
[type="reset"]::-moz-focus-inner,
|
||||
[type="submit"]::-moz-focus-inner
|
||||
) {
|
||||
border-style: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
${e} fieldset {
|
||||
padding: 0.35em 0.75em 0.625em;
|
||||
}
|
||||
|
||||
${e} legend {
|
||||
box-sizing: border-box;
|
||||
color: inherit;
|
||||
display: table;
|
||||
max-width: 100%;
|
||||
padding: 0;
|
||||
white-space: normal;
|
||||
}
|
||||
|
||||
${e} progress {
|
||||
vertical-align: baseline;
|
||||
}
|
||||
|
||||
${e} textarea {
|
||||
overflow: auto;
|
||||
}
|
||||
|
||||
${e} :where([type="checkbox"], [type="radio"]) {
|
||||
box-sizing: border-box;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
${e} input[type="number"]::-webkit-inner-spin-button,
|
||||
${e} input[type="number"]::-webkit-outer-spin-button {
|
||||
-webkit-appearance: none !important;
|
||||
}
|
||||
|
||||
${e} input[type="number"] {
|
||||
-moz-appearance: textfield;
|
||||
}
|
||||
|
||||
${e} input[type="search"] {
|
||||
-webkit-appearance: textfield;
|
||||
outline-offset: -2px;
|
||||
}
|
||||
|
||||
${e} input[type="search"]::-webkit-search-decoration {
|
||||
-webkit-appearance: none !important;
|
||||
}
|
||||
|
||||
${e} ::-webkit-file-upload-button {
|
||||
-webkit-appearance: button;
|
||||
font: inherit;
|
||||
}
|
||||
|
||||
${e} details {
|
||||
display: block;
|
||||
}
|
||||
|
||||
${e} summary {
|
||||
display: list-item;
|
||||
}
|
||||
|
||||
template {
|
||||
display: none;
|
||||
}
|
||||
|
||||
[hidden] {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
${e} :where(
|
||||
blockquote,
|
||||
dl,
|
||||
dd,
|
||||
h1,
|
||||
h2,
|
||||
h3,
|
||||
h4,
|
||||
h5,
|
||||
h6,
|
||||
hr,
|
||||
figure,
|
||||
p,
|
||||
pre
|
||||
) {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
${e} button {
|
||||
background: transparent;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
${e} fieldset {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
${e} :where(ol, ul) {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
${e} textarea {
|
||||
resize: vertical;
|
||||
}
|
||||
|
||||
${e} :where(button, [role="button"]) {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
${e} button::-moz-focus-inner {
|
||||
border: 0 !important;
|
||||
}
|
||||
|
||||
${e} table {
|
||||
border-collapse: collapse;
|
||||
}
|
||||
|
||||
${e} :where(h1, h2, h3, h4, h5, h6) {
|
||||
font-size: inherit;
|
||||
font-weight: inherit;
|
||||
}
|
||||
|
||||
${e} :where(button, input, optgroup, select, textarea) {
|
||||
padding: 0;
|
||||
line-height: inherit;
|
||||
color: inherit;
|
||||
}
|
||||
|
||||
${e} :where(img, svg, video, canvas, audio, iframe, embed, object) {
|
||||
display: block;
|
||||
}
|
||||
|
||||
${e} :where(img, video) {
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
}
|
||||
|
||||
[data-js-focus-visible]
|
||||
:focus:not([data-focus-visible-added]):not(
|
||||
[data-focus-visible-disabled]
|
||||
) {
|
||||
outline: none;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
${e} select::-ms-expand {
|
||||
display: none;
|
||||
}
|
||||
|
||||
${E}
|
||||
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=H(),n=o.dir(),r=l.useMemo(()=>ie({...U,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(Z,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:Y,children:e})})}const ve=l.memo(he);export{ve as default};
|
||||
156
invokeai/frontend/web/dist/assets/index-54a1ea80.js
vendored
Normal file
156
invokeai/frontend/web/dist/assets/index-54a1ea80.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -19,7 +19,7 @@ import sdxlReducer from 'features/sdxl/store/sdxlSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import queueReducer from 'features/queue/store/queueSlice';
|
||||
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
||||
import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
|
||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
|
||||
@@ -8,7 +8,14 @@ import {
|
||||
forwardRef,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
|
||||
import {
|
||||
cloneElement,
|
||||
memo,
|
||||
ReactElement,
|
||||
ReactNode,
|
||||
useCallback,
|
||||
useRef,
|
||||
} from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import IAIButton from './IAIButton';
|
||||
|
||||
@@ -38,15 +45,15 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const cancelRef = useRef<HTMLButtonElement | null>(null);
|
||||
|
||||
const handleAccept = () => {
|
||||
const handleAccept = useCallback(() => {
|
||||
acceptCallback();
|
||||
onClose();
|
||||
};
|
||||
}, [acceptCallback, onClose]);
|
||||
|
||||
const handleCancel = () => {
|
||||
const handleCancel = useCallback(() => {
|
||||
cancelCallback && cancelCallback();
|
||||
onClose();
|
||||
};
|
||||
}, [cancelCallback, onClose]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
import { Box, Flex, Icon } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FaExclamation } from 'react-icons/fa';
|
||||
|
||||
const IAIErrorLoadingImageFallback = () => {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
height: 'full',
|
||||
width: 'full',
|
||||
'::before': {
|
||||
content: "''",
|
||||
display: 'block',
|
||||
pt: '100%',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineStart: 0,
|
||||
height: 'full',
|
||||
width: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.100',
|
||||
color: 'base.500',
|
||||
_dark: {
|
||||
color: 'base.700',
|
||||
bg: 'base.850',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Icon as={FaExclamation} boxSize={16} opacity={0.7} />
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIErrorLoadingImageFallback);
|
||||
@@ -1,8 +0,0 @@
|
||||
import { chakra } from '@chakra-ui/react';
|
||||
|
||||
/**
|
||||
* Chakra-enabled <form />
|
||||
*/
|
||||
const IAIForm = chakra.form;
|
||||
|
||||
export default IAIForm;
|
||||
@@ -1,15 +0,0 @@
|
||||
import { FormErrorMessage, FormErrorMessageProps } from '@chakra-ui/react';
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
type IAIFormErrorMessageProps = FormErrorMessageProps & {
|
||||
children: ReactNode | string;
|
||||
};
|
||||
|
||||
export default function IAIFormErrorMessage(props: IAIFormErrorMessageProps) {
|
||||
const { children, ...rest } = props;
|
||||
return (
|
||||
<FormErrorMessage color="error.400" {...rest}>
|
||||
{children}
|
||||
</FormErrorMessage>
|
||||
);
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
import { FormHelperText, FormHelperTextProps } from '@chakra-ui/react';
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
type IAIFormHelperTextProps = FormHelperTextProps & {
|
||||
children: ReactNode | string;
|
||||
};
|
||||
|
||||
export default function IAIFormHelperText(props: IAIFormHelperTextProps) {
|
||||
const { children, ...rest } = props;
|
||||
return (
|
||||
<FormHelperText margin={0} color="base.400" {...rest}>
|
||||
{children}
|
||||
</FormHelperText>
|
||||
);
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
import { Flex, useColorMode } from '@chakra-ui/react';
|
||||
import { ReactElement } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
export function IAIFormItemWrapper({
|
||||
children,
|
||||
}: {
|
||||
children: ReactElement | ReactElement[];
|
||||
}) {
|
||||
const { colorMode } = useColorMode();
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: mode('base.100', 'base.900')(colorMode),
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
import {
|
||||
Checkbox,
|
||||
CheckboxProps,
|
||||
FormControl,
|
||||
FormControlProps,
|
||||
FormLabel,
|
||||
} from '@chakra-ui/react';
|
||||
import { memo, ReactNode } from 'react';
|
||||
|
||||
type IAIFullCheckboxProps = CheckboxProps & {
|
||||
label: string | ReactNode;
|
||||
formControlProps?: FormControlProps;
|
||||
};
|
||||
|
||||
const IAIFullCheckbox = (props: IAIFullCheckboxProps) => {
|
||||
const { label, formControlProps, ...rest } = props;
|
||||
return (
|
||||
<FormControl {...formControlProps}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Checkbox colorScheme="accent" {...rest} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIFullCheckbox);
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useColorMode } from '@chakra-ui/react';
|
||||
import { TextInput, TextInputProps } from '@mantine/core';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { useCallback } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
type IAIMantineTextInputProps = TextInputProps;
|
||||
@@ -20,26 +21,37 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
|
||||
} = useChakraThemeTokens();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
return (
|
||||
<TextInput
|
||||
styles={() => ({
|
||||
input: {
|
||||
color: mode(base900, base100)(colorMode),
|
||||
backgroundColor: mode(base50, base900)(colorMode),
|
||||
borderColor: mode(base200, base800)(colorMode),
|
||||
borderWidth: 2,
|
||||
outline: 'none',
|
||||
':focus': {
|
||||
borderColor: mode(accent300, accent500)(colorMode),
|
||||
},
|
||||
const stylesFunc = useCallback(
|
||||
() => ({
|
||||
input: {
|
||||
color: mode(base900, base100)(colorMode),
|
||||
backgroundColor: mode(base50, base900)(colorMode),
|
||||
borderColor: mode(base200, base800)(colorMode),
|
||||
borderWidth: 2,
|
||||
outline: 'none',
|
||||
':focus': {
|
||||
borderColor: mode(accent300, accent500)(colorMode),
|
||||
},
|
||||
label: {
|
||||
color: mode(base700, base300)(colorMode),
|
||||
fontWeight: 'normal',
|
||||
marginBottom: 4,
|
||||
},
|
||||
})}
|
||||
{...rest}
|
||||
/>
|
||||
},
|
||||
label: {
|
||||
color: mode(base700, base300)(colorMode),
|
||||
fontWeight: 'normal' as const,
|
||||
marginBottom: 4,
|
||||
},
|
||||
}),
|
||||
[
|
||||
accent300,
|
||||
accent500,
|
||||
base100,
|
||||
base200,
|
||||
base300,
|
||||
base50,
|
||||
base700,
|
||||
base800,
|
||||
base900,
|
||||
colorMode,
|
||||
]
|
||||
);
|
||||
|
||||
return <TextInput styles={stylesFunc} {...rest} />;
|
||||
}
|
||||
|
||||
@@ -98,28 +98,34 @@ const IAINumberInput = forwardRef((props: Props, ref) => {
|
||||
}
|
||||
}, [value, valueAsString]);
|
||||
|
||||
const handleOnChange = (v: string) => {
|
||||
setValueAsString(v);
|
||||
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
|
||||
if (!v.match(numberStringRegex)) {
|
||||
// Cast the value to number. Floor it if it should be an integer.
|
||||
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
|
||||
}
|
||||
};
|
||||
const handleOnChange = useCallback(
|
||||
(v: string) => {
|
||||
setValueAsString(v);
|
||||
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
|
||||
if (!v.match(numberStringRegex)) {
|
||||
// Cast the value to number. Floor it if it should be an integer.
|
||||
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
|
||||
}
|
||||
},
|
||||
[isInteger, onChange]
|
||||
);
|
||||
|
||||
/**
|
||||
* Clicking the steppers allows the value to go outside bounds; we need to
|
||||
* clamp it on blur and floor it if needed.
|
||||
*/
|
||||
const handleBlur = (e: FocusEvent<HTMLInputElement>) => {
|
||||
const clamped = clamp(
|
||||
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
|
||||
min,
|
||||
max
|
||||
);
|
||||
setValueAsString(String(clamped));
|
||||
onChange(clamped);
|
||||
};
|
||||
const handleBlur = useCallback(
|
||||
(e: FocusEvent<HTMLInputElement>) => {
|
||||
const clamped = clamp(
|
||||
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
|
||||
min,
|
||||
max
|
||||
);
|
||||
setValueAsString(String(clamped));
|
||||
onChange(clamped);
|
||||
},
|
||||
[isInteger, max, min, onChange]
|
||||
);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { memo, MouseEvent } from 'react';
|
||||
import { memo, MouseEvent, useCallback } from 'react';
|
||||
import IAIOption from './IAIOption';
|
||||
|
||||
type IAISelectProps = SelectProps & {
|
||||
@@ -33,15 +33,16 @@ const IAISelect = (props: IAISelectProps) => {
|
||||
spaceEvenly,
|
||||
...rest
|
||||
} = props;
|
||||
const handleClick = useCallback((e: MouseEvent<HTMLDivElement>) => {
|
||||
e.stopPropagation();
|
||||
e.nativeEvent.stopImmediatePropagation();
|
||||
e.nativeEvent.stopPropagation();
|
||||
e.nativeEvent.cancelBubble = true;
|
||||
}, []);
|
||||
return (
|
||||
<FormControl
|
||||
isDisabled={isDisabled}
|
||||
onClick={(e: MouseEvent<HTMLDivElement>) => {
|
||||
e.stopPropagation();
|
||||
e.nativeEvent.stopImmediatePropagation();
|
||||
e.nativeEvent.stopPropagation();
|
||||
e.nativeEvent.cancelBubble = true;
|
||||
}}
|
||||
onClick={handleClick}
|
||||
sx={
|
||||
horizontal
|
||||
? {
|
||||
|
||||
@@ -186,6 +186,13 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleMouseEnter = useCallback(() => setShowTooltip(true), []);
|
||||
const handleMouseLeave = useCallback(() => setShowTooltip(false), []);
|
||||
const handleStepperClick = useCallback(
|
||||
() => onChange(Number(localInputValue)),
|
||||
[localInputValue, onChange]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl
|
||||
ref={ref}
|
||||
@@ -219,8 +226,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
|
||||
max={max}
|
||||
step={step}
|
||||
onChange={handleSliderChange}
|
||||
onMouseEnter={() => setShowTooltip(true)}
|
||||
onMouseLeave={() => setShowTooltip(false)}
|
||||
onMouseEnter={handleMouseEnter}
|
||||
onMouseLeave={handleMouseLeave}
|
||||
focusThumbOnChange={false}
|
||||
isDisabled={isDisabled}
|
||||
{...rest}
|
||||
@@ -332,12 +339,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
|
||||
{...sliderNumberInputFieldProps}
|
||||
/>
|
||||
<NumberInputStepper {...sliderNumberInputStepperProps}>
|
||||
<NumberIncrementStepper
|
||||
onClick={() => onChange(Number(localInputValue))}
|
||||
/>
|
||||
<NumberDecrementStepper
|
||||
onClick={() => onChange(Number(localInputValue))}
|
||||
/>
|
||||
<NumberIncrementStepper onClick={handleStepperClick} />
|
||||
<NumberDecrementStepper onClick={handleStepperClick} />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
)}
|
||||
|
||||
@@ -146,16 +146,15 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
};
|
||||
}, [inputRef]);
|
||||
|
||||
const handleKeyDown = useCallback((e: KeyboardEvent) => {
|
||||
// Bail out if user hits spacebar - do not open the uploader
|
||||
if (e.key === ' ') {
|
||||
return;
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Box
|
||||
{...getRootProps({ style: {} })}
|
||||
onKeyDown={(e: KeyboardEvent) => {
|
||||
// Bail out if user hits spacebar - do not open the uploader
|
||||
if (e.key === ' ') {
|
||||
return;
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Box {...getRootProps({ style: {} })} onKeyDown={handleKeyDown}>
|
||||
<input {...getInputProps()} />
|
||||
{children}
|
||||
<AnimatePresence>
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import { Flex, Icon } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
|
||||
const SelectImagePlaceholder = () => {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
// bg: 'base.800',
|
||||
borderRadius: 'base',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SelectImagePlaceholder);
|
||||
@@ -1,24 +0,0 @@
|
||||
import { useBreakpoint } from '@chakra-ui/react';
|
||||
|
||||
export default function useResolution():
|
||||
| 'mobile'
|
||||
| 'tablet'
|
||||
| 'desktop'
|
||||
| 'unknown' {
|
||||
const breakpointValue = useBreakpoint();
|
||||
|
||||
const mobileResolutions = ['base', 'sm'];
|
||||
const tabletResolutions = ['md', 'lg'];
|
||||
const desktopResolutions = ['xl', '2xl'];
|
||||
|
||||
if (mobileResolutions.includes(breakpointValue)) {
|
||||
return 'mobile';
|
||||
}
|
||||
if (tabletResolutions.includes(breakpointValue)) {
|
||||
return 'tablet';
|
||||
}
|
||||
if (desktopResolutions.includes(breakpointValue)) {
|
||||
return 'desktop';
|
||||
}
|
||||
return 'unknown';
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
import dateFormat from 'dateformat';
|
||||
|
||||
/**
|
||||
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||
*/
|
||||
export const getTimestamp = () =>
|
||||
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);
|
||||
@@ -1,71 +0,0 @@
|
||||
// TODO: Restore variations
|
||||
// Support code from v2.3 in here.
|
||||
|
||||
// export const stringToSeedWeights = (
|
||||
// string: string
|
||||
// ): InvokeAI.SeedWeights | boolean => {
|
||||
// const stringPairs = string.split(',');
|
||||
// const arrPairs = stringPairs.map((p) => p.split(':'));
|
||||
// const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
|
||||
// return { seed: Number(p[0]), weight: Number(p[1]) };
|
||||
// });
|
||||
|
||||
// if (!validateSeedWeights(pairs)) {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
// return pairs;
|
||||
// };
|
||||
|
||||
// export const validateSeedWeights = (
|
||||
// seedWeights: InvokeAI.SeedWeights | string
|
||||
// ): boolean => {
|
||||
// return typeof seedWeights === 'string'
|
||||
// ? Boolean(stringToSeedWeights(seedWeights))
|
||||
// : Boolean(
|
||||
// seedWeights.length &&
|
||||
// !seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
|
||||
// const { seed, weight } = pair;
|
||||
// const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
|
||||
// const isWeightValid =
|
||||
// !isNaN(parseInt(weight.toString(), 10)) &&
|
||||
// weight >= 0 &&
|
||||
// weight <= 1;
|
||||
// return !(isSeedValid && isWeightValid);
|
||||
// })
|
||||
// );
|
||||
// };
|
||||
|
||||
// export const seedWeightsToString = (
|
||||
// seedWeights: InvokeAI.SeedWeights
|
||||
// ): string => {
|
||||
// return seedWeights.reduce((acc, pair, i, arr) => {
|
||||
// const { seed, weight } = pair;
|
||||
// acc += `${seed}:${weight}`;
|
||||
// if (i !== arr.length - 1) {
|
||||
// acc += ',';
|
||||
// }
|
||||
// return acc;
|
||||
// }, '');
|
||||
// };
|
||||
|
||||
// export const seedWeightsToArray = (
|
||||
// seedWeights: InvokeAI.SeedWeights
|
||||
// ): Array<Array<number>> => {
|
||||
// return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
|
||||
// pair.seed,
|
||||
// pair.weight,
|
||||
// ]);
|
||||
// };
|
||||
|
||||
// export const stringToSeedWeightsArray = (
|
||||
// string: string
|
||||
// ): Array<Array<number>> => {
|
||||
// const stringPairs = string.split(',');
|
||||
// const arrPairs = stringPairs.map((p) => p.split(':'));
|
||||
// return arrPairs.map(
|
||||
// (p: Array<string>): Array<number> => [parseInt(p[0], 10), parseFloat(p[1])]
|
||||
// );
|
||||
// };
|
||||
|
||||
export default {};
|
||||
@@ -5,17 +5,22 @@ import { clearCanvasHistory } from 'features/canvas/store/canvasSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { isStagingSelector } from '../store/canvasSelectors';
|
||||
import { memo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
const ClearCanvasHistoryButtonModal = () => {
|
||||
const isStaging = useAppSelector(isStagingSelector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const acceptCallback = useCallback(
|
||||
() => dispatch(clearCanvasHistory()),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIAlertDialog
|
||||
title={t('unifiedCanvas.clearCanvasHistory')}
|
||||
acceptCallback={() => dispatch(clearCanvasHistory())}
|
||||
acceptCallback={acceptCallback}
|
||||
acceptButtonText={t('unifiedCanvas.clearHistory')}
|
||||
triggerComponent={
|
||||
<IAIButton size="sm" leftIcon={<FaTrash />} isDisabled={isStaging}>
|
||||
|
||||
@@ -20,7 +20,8 @@ import {
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -95,18 +96,35 @@ const IAICanvasMaskOptions = () => {
|
||||
[isMaskEnabled]
|
||||
);
|
||||
|
||||
const handleToggleMaskLayer = () => {
|
||||
const handleToggleMaskLayer = useCallback(() => {
|
||||
dispatch(setLayer(layer === 'mask' ? 'base' : 'mask'));
|
||||
};
|
||||
}, [dispatch, layer]);
|
||||
|
||||
const handleClearMask = () => dispatch(clearMask());
|
||||
const handleClearMask = useCallback(() => {
|
||||
dispatch(clearMask());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleToggleEnableMask = () =>
|
||||
const handleToggleEnableMask = useCallback(() => {
|
||||
dispatch(setIsMaskEnabled(!isMaskEnabled));
|
||||
}, [dispatch, isMaskEnabled]);
|
||||
|
||||
const handleSaveMask = async () => {
|
||||
const handleSaveMask = useCallback(async () => {
|
||||
dispatch(canvasMaskSavedToGallery());
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleChangePreserveMaskedArea = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(setShouldPreserveMaskedArea(e.target.checked));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleChangeMaskColor = useCallback(
|
||||
(newColor: RgbaColor) => {
|
||||
dispatch(setMaskColor(newColor));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
@@ -131,15 +149,10 @@ const IAICanvasMaskOptions = () => {
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.preserveMaskedArea')}
|
||||
isChecked={shouldPreserveMaskedArea}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldPreserveMaskedArea(e.target.checked))
|
||||
}
|
||||
onChange={handleChangePreserveMaskedArea}
|
||||
/>
|
||||
<Box sx={{ paddingTop: 2, paddingBottom: 2 }}>
|
||||
<IAIColorPicker
|
||||
color={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
<IAIColorPicker color={maskColor} onChange={handleChangeMaskColor} />
|
||||
</Box>
|
||||
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
|
||||
Save Mask
|
||||
|
||||
@@ -10,6 +10,7 @@ import { redo } from 'features/canvas/store/canvasSlice';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
const canvasRedoSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
@@ -34,9 +35,9 @@ export default function IAICanvasRedoButton() {
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleRedo = () => {
|
||||
const handleRedo = useCallback(() => {
|
||||
dispatch(redo());
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
useHotkeys(
|
||||
['meta+shift+z', 'ctrl+shift+z', 'control+y', 'meta+y'],
|
||||
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaWrench } from 'react-icons/fa';
|
||||
@@ -86,8 +86,52 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
[shouldSnapToGrid]
|
||||
);
|
||||
|
||||
const handleChangeShouldSnapToGrid = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldSnapToGrid(e.target.checked));
|
||||
const handleChangeShouldSnapToGrid = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldSnapToGrid(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleChangeShouldShowIntermediates = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldShowIntermediates(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldShowGrid = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldShowGrid(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldDarkenOutsideBoundingBox = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldAutoSave = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldAutoSave(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldCropToBoundingBoxOnSave = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldRestrictStrokesToBox = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldRestrictStrokesToBox(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldShowCanvasDebugInfo = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldShowCanvasDebugInfo(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldAntialias = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldAntialias(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
@@ -104,14 +148,12 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.showIntermediates')}
|
||||
isChecked={shouldShowIntermediates}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldShowIntermediates(e.target.checked))
|
||||
}
|
||||
onChange={handleChangeShouldShowIntermediates}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.showGrid')}
|
||||
isChecked={shouldShowGrid}
|
||||
onChange={(e) => dispatch(setShouldShowGrid(e.target.checked))}
|
||||
onChange={handleChangeShouldShowGrid}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.snapToGrid')}
|
||||
@@ -121,41 +163,33 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.darkenOutsideSelection')}
|
||||
isChecked={shouldDarkenOutsideBoundingBox}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked))
|
||||
}
|
||||
onChange={handleChangeShouldDarkenOutsideBoundingBox}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.autoSaveToGallery')}
|
||||
isChecked={shouldAutoSave}
|
||||
onChange={(e) => dispatch(setShouldAutoSave(e.target.checked))}
|
||||
onChange={handleChangeShouldAutoSave}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.saveBoxRegionOnly')}
|
||||
isChecked={shouldCropToBoundingBoxOnSave}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked))
|
||||
}
|
||||
onChange={handleChangeShouldCropToBoundingBoxOnSave}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.limitStrokesToBox')}
|
||||
isChecked={shouldRestrictStrokesToBox}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldRestrictStrokesToBox(e.target.checked))
|
||||
}
|
||||
onChange={handleChangeShouldRestrictStrokesToBox}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.showCanvasDebugInfo')}
|
||||
isChecked={shouldShowCanvasDebugInfo}
|
||||
onChange={(e) =>
|
||||
dispatch(setShouldShowCanvasDebugInfo(e.target.checked))
|
||||
}
|
||||
onChange={handleChangeShouldShowCanvasDebugInfo}
|
||||
/>
|
||||
|
||||
<IAISimpleCheckbox
|
||||
label={t('unifiedCanvas.antialiasing')}
|
||||
isChecked={shouldAntialias}
|
||||
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
||||
onChange={handleChangeShouldAntialias}
|
||||
/>
|
||||
<ClearCanvasHistoryButtonModal />
|
||||
</Flex>
|
||||
|
||||
@@ -15,7 +15,8 @@ import {
|
||||
setTool,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { clamp, isEqual } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -172,11 +173,33 @@ const IAICanvasToolChooserOptions = () => {
|
||||
[brushColor]
|
||||
);
|
||||
|
||||
const handleSelectBrushTool = () => dispatch(setTool('brush'));
|
||||
const handleSelectEraserTool = () => dispatch(setTool('eraser'));
|
||||
const handleSelectColorPickerTool = () => dispatch(setTool('colorPicker'));
|
||||
const handleFillRect = () => dispatch(addFillRect());
|
||||
const handleEraseBoundingBox = () => dispatch(addEraseRect());
|
||||
const handleSelectBrushTool = useCallback(() => {
|
||||
dispatch(setTool('brush'));
|
||||
}, [dispatch]);
|
||||
const handleSelectEraserTool = useCallback(() => {
|
||||
dispatch(setTool('eraser'));
|
||||
}, [dispatch]);
|
||||
const handleSelectColorPickerTool = useCallback(() => {
|
||||
dispatch(setTool('colorPicker'));
|
||||
}, [dispatch]);
|
||||
const handleFillRect = useCallback(() => {
|
||||
dispatch(addFillRect());
|
||||
}, [dispatch]);
|
||||
const handleEraseBoundingBox = useCallback(() => {
|
||||
dispatch(addEraseRect());
|
||||
}, [dispatch]);
|
||||
const handleChangeBrushSize = useCallback(
|
||||
(newSize: number) => {
|
||||
dispatch(setBrushSize(newSize));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeBrushColor = useCallback(
|
||||
(newColor: RgbaColor) => {
|
||||
dispatch(setBrushColor(newColor));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<ButtonGroup isAttached>
|
||||
@@ -233,7 +256,7 @@ const IAICanvasToolChooserOptions = () => {
|
||||
label={t('unifiedCanvas.brushSize')}
|
||||
value={brushSize}
|
||||
withInput
|
||||
onChange={(newSize) => dispatch(setBrushSize(newSize))}
|
||||
onChange={handleChangeBrushSize}
|
||||
sliderNumberInputProps={{ max: 500 }}
|
||||
/>
|
||||
</Flex>
|
||||
@@ -247,7 +270,7 @@ const IAICanvasToolChooserOptions = () => {
|
||||
<IAIColorPicker
|
||||
withNumberInput={true}
|
||||
color={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
onChange={handleChangeBrushColor}
|
||||
/>
|
||||
</Box>
|
||||
</Flex>
|
||||
|
||||
@@ -25,9 +25,9 @@ import {
|
||||
LAYER_NAMES_DICT,
|
||||
} from 'features/canvas/store/canvasTypes';
|
||||
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -151,7 +151,9 @@ const IAICanvasToolbar = () => {
|
||||
[canvasBaseLayer]
|
||||
);
|
||||
|
||||
const handleSelectMoveTool = () => dispatch(setTool('move'));
|
||||
const handleSelectMoveTool = useCallback(() => {
|
||||
dispatch(setTool('move'));
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickResetCanvasView = useSingleAndDoubleClick(
|
||||
() => handleResetCanvasView(false),
|
||||
@@ -174,36 +176,39 @@ const IAICanvasToolbar = () => {
|
||||
);
|
||||
};
|
||||
|
||||
const handleResetCanvas = () => {
|
||||
const handleResetCanvas = useCallback(() => {
|
||||
dispatch(resetCanvas());
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleMergeVisible = () => {
|
||||
const handleMergeVisible = useCallback(() => {
|
||||
dispatch(canvasMerged());
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleSaveToGallery = () => {
|
||||
const handleSaveToGallery = useCallback(() => {
|
||||
dispatch(canvasSavedToGallery());
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleCopyImageToClipboard = () => {
|
||||
const handleCopyImageToClipboard = useCallback(() => {
|
||||
if (!isClipboardAPIAvailable) {
|
||||
return;
|
||||
}
|
||||
dispatch(canvasCopiedToClipboard());
|
||||
};
|
||||
}, [dispatch, isClipboardAPIAvailable]);
|
||||
|
||||
const handleDownloadAsImage = () => {
|
||||
const handleDownloadAsImage = useCallback(() => {
|
||||
dispatch(canvasDownloadedAsImage());
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleChangeLayer = (v: string) => {
|
||||
const newLayer = v as CanvasLayer;
|
||||
dispatch(setLayer(newLayer));
|
||||
if (newLayer === 'mask' && !isMaskEnabled) {
|
||||
dispatch(setIsMaskEnabled(true));
|
||||
}
|
||||
};
|
||||
const handleChangeLayer = useCallback(
|
||||
(v: string) => {
|
||||
const newLayer = v as CanvasLayer;
|
||||
dispatch(setLayer(newLayer));
|
||||
if (newLayer === 'mask' && !isMaskEnabled) {
|
||||
dispatch(setIsMaskEnabled(true));
|
||||
}
|
||||
},
|
||||
[dispatch, isMaskEnabled]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
|
||||
@@ -10,6 +10,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
const canvasUndoSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
@@ -35,9 +36,9 @@ export default function IAICanvasUndoButton() {
|
||||
|
||||
const { canUndo, activeTabName } = useAppSelector(canvasUndoSelector);
|
||||
|
||||
const handleUndo = () => {
|
||||
const handleUndo = useCallback(() => {
|
||||
dispatch(undo());
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
useHotkeys(
|
||||
['meta+z', 'ctrl+z'],
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
import Konva from 'konva';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
|
||||
/**
|
||||
* Converts a Konva node to a dataURL
|
||||
* @param node - The Konva node to convert to a dataURL
|
||||
* @param boundingBox - The bounding box to crop to
|
||||
* @returns A dataURL of the node cropped to the bounding box
|
||||
*/
|
||||
export const konvaNodeToDataURL = (
|
||||
node: Konva.Node,
|
||||
boundingBox: IRect
|
||||
): string => {
|
||||
// get a dataURL of the bbox'd region
|
||||
return node.toDataURL(boundingBox);
|
||||
};
|
||||
@@ -87,6 +87,11 @@ const ChangeBoardModal = () => {
|
||||
selectedBoard,
|
||||
]);
|
||||
|
||||
const handleSetSelectedBoard = useCallback(
|
||||
(v: string | null) => setSelectedBoard(v),
|
||||
[]
|
||||
);
|
||||
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
return (
|
||||
@@ -113,7 +118,7 @@ const ChangeBoardModal = () => {
|
||||
isFetching ? t('boards.loading') : t('boards.selectBoard')
|
||||
}
|
||||
disabled={isFetching}
|
||||
onChange={(v) => setSelectedBoard(v)}
|
||||
onChange={handleSetSelectedBoard}
|
||||
value={selectedBoard}
|
||||
data={data}
|
||||
/>
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
|
||||
import { controlAdapterImageProcessed } from '../store/actions';
|
||||
|
||||
type Props = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
const ControlAdapterPreprocessButton = ({ id }: Props) => {
|
||||
const controlImage = useControlAdapterControlImage(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToEnqueue();
|
||||
|
||||
const handleProcess = useCallback(() => {
|
||||
dispatch(
|
||||
controlAdapterImageProcessed({
|
||||
id,
|
||||
})
|
||||
);
|
||||
}, [id, dispatch]);
|
||||
|
||||
return (
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={handleProcess}
|
||||
isDisabled={Boolean(!controlImage) || !isReady}
|
||||
>
|
||||
Preprocess
|
||||
</IAIButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlAdapterPreprocessButton);
|
||||
@@ -1 +0,0 @@
|
||||
//
|
||||
@@ -14,9 +14,9 @@ import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectI
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { PropsWithChildren, memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = PropsWithChildren & {
|
||||
onSelect: (v: string) => void;
|
||||
@@ -78,6 +78,13 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const filterFunc = useCallback(
|
||||
(value: string, item: SelectItem) =>
|
||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
initialFocusRef={inputRef}
|
||||
@@ -127,12 +134,7 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
disabled={data.length === 0}
|
||||
onDropdownClose={onClose}
|
||||
filter={(value, item: SelectItem) =>
|
||||
item.label
|
||||
?.toLowerCase()
|
||||
.includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
filter={filterFunc}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -60,6 +60,13 @@ const BoardAutoAddSelect = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const filterFunc = useCallback(
|
||||
(value: string, item: SelectItem) =>
|
||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('boards.autoAddBoard')}
|
||||
@@ -71,10 +78,7 @@ const BoardAutoAddSelect = () => {
|
||||
nothingFound={t('boards.noMatching')}
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
disabled={!hasBoards || autoAssignBoardOnClick}
|
||||
filter={(value, item: SelectItem) =>
|
||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
filter={filterFunc}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -90,6 +90,50 @@ const BoardContextMenu = ({
|
||||
e.preventDefault();
|
||||
}, []);
|
||||
|
||||
const renderMenuFunc = useCallback(
|
||||
() => (
|
||||
<MenuList
|
||||
sx={{ visibility: 'visible !important' }}
|
||||
motionProps={menuListMotionProps}
|
||||
onContextMenu={skipEvent}
|
||||
>
|
||||
<MenuGroup title={boardName}>
|
||||
<MenuItem
|
||||
icon={<FaPlus />}
|
||||
isDisabled={isAutoAdd || autoAssignBoardOnClick}
|
||||
onClick={handleSetAutoAdd}
|
||||
>
|
||||
{t('boards.menuItemAutoAdd')}
|
||||
</MenuItem>
|
||||
{isBulkDownloadEnabled && (
|
||||
<MenuItem icon={<FaDownload />} onClickCapture={handleBulkDownload}>
|
||||
{t('boards.downloadBoard')}
|
||||
</MenuItem>
|
||||
)}
|
||||
{!board && <NoBoardContextMenuItems />}
|
||||
{board && (
|
||||
<GalleryBoardContextMenuItems
|
||||
board={board}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
/>
|
||||
)}
|
||||
</MenuGroup>
|
||||
</MenuList>
|
||||
),
|
||||
[
|
||||
autoAssignBoardOnClick,
|
||||
board,
|
||||
boardName,
|
||||
handleBulkDownload,
|
||||
handleSetAutoAdd,
|
||||
isAutoAdd,
|
||||
isBulkDownloadEnabled,
|
||||
setBoardToDelete,
|
||||
skipEvent,
|
||||
t,
|
||||
]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
@@ -97,38 +141,7 @@ const BoardContextMenu = ({
|
||||
bg: 'transparent',
|
||||
_hover: { bg: 'transparent' },
|
||||
}}
|
||||
renderMenu={() => (
|
||||
<MenuList
|
||||
sx={{ visibility: 'visible !important' }}
|
||||
motionProps={menuListMotionProps}
|
||||
onContextMenu={skipEvent}
|
||||
>
|
||||
<MenuGroup title={boardName}>
|
||||
<MenuItem
|
||||
icon={<FaPlus />}
|
||||
isDisabled={isAutoAdd || autoAssignBoardOnClick}
|
||||
onClick={handleSetAutoAdd}
|
||||
>
|
||||
{t('boards.menuItemAutoAdd')}
|
||||
</MenuItem>
|
||||
{isBulkDownloadEnabled && (
|
||||
<MenuItem
|
||||
icon={<FaDownload />}
|
||||
onClickCapture={handleBulkDownload}
|
||||
>
|
||||
{t('boards.downloadBoard')}
|
||||
</MenuItem>
|
||||
)}
|
||||
{!board && <NoBoardContextMenuItems />}
|
||||
{board && (
|
||||
<GalleryBoardContextMenuItems
|
||||
board={board}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
/>
|
||||
)}
|
||||
</MenuGroup>
|
||||
</MenuList>
|
||||
)}
|
||||
renderMenu={renderMenuFunc}
|
||||
>
|
||||
{children}
|
||||
</IAIContextMenu>
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
import { As, Badge, Flex } from '@chakra-ui/react';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { BoardId } from 'features/gallery/store/types';
|
||||
import { ReactNode, memo } from 'react';
|
||||
import BoardContextMenu from '../BoardContextMenu';
|
||||
|
||||
type GenericBoardProps = {
|
||||
board_id: BoardId;
|
||||
droppableData?: TypesafeDroppableData;
|
||||
onClick: () => void;
|
||||
isSelected: boolean;
|
||||
icon: As;
|
||||
label: string;
|
||||
dropLabel?: ReactNode;
|
||||
badgeCount?: number;
|
||||
};
|
||||
|
||||
export const formatBadgeCount = (count: number) =>
|
||||
Intl.NumberFormat('en-US', {
|
||||
notation: 'compact',
|
||||
maximumFractionDigits: 1,
|
||||
}).format(count);
|
||||
|
||||
const GenericBoard = (props: GenericBoardProps) => {
|
||||
const {
|
||||
board_id,
|
||||
droppableData,
|
||||
onClick,
|
||||
isSelected,
|
||||
icon,
|
||||
label,
|
||||
badgeCount,
|
||||
dropLabel,
|
||||
} = props;
|
||||
|
||||
return (
|
||||
<BoardContextMenu board_id={board_id}>
|
||||
{(ref) => (
|
||||
<Flex
|
||||
ref={ref}
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
cursor: 'pointer',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
onClick={onClick}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
aspectRatio: '1/1',
|
||||
overflow: 'hidden',
|
||||
shadow: isSelected ? 'selected.light' : undefined,
|
||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<IAINoContentFallback
|
||||
boxSize={8}
|
||||
icon={icon}
|
||||
sx={{
|
||||
border: '2px solid var(--invokeai-colors-base-200)',
|
||||
_dark: { border: '2px solid var(--invokeai-colors-base-800)' },
|
||||
}}
|
||||
/>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
insetInlineEnd: 0,
|
||||
top: 0,
|
||||
p: 1,
|
||||
}}
|
||||
>
|
||||
{badgeCount !== undefined && (
|
||||
<Badge variant="solid">{formatBadgeCount(badgeCount)}</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
<IAIDroppable data={droppableData} dropLabel={dropLabel} />
|
||||
</Flex>
|
||||
<Flex
|
||||
sx={{
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
fontWeight: isSelected ? 600 : undefined,
|
||||
fontSize: 'sm',
|
||||
color: isSelected ? 'base.900' : 'base.700',
|
||||
_dark: { color: isSelected ? 'base.50' : 'base.200' },
|
||||
}}
|
||||
>
|
||||
{label}
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
</BoardContextMenu>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(GenericBoard);
|
||||
@@ -1,53 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
|
||||
type Props = {
|
||||
board_id: 'images' | 'assets' | 'no_board';
|
||||
};
|
||||
|
||||
const SystemBoardButton = ({ board_id }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
[stateSelector],
|
||||
({ gallery }) => {
|
||||
const { selectedBoardId } = gallery;
|
||||
return { isSelected: selectedBoardId === board_id };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[board_id]
|
||||
);
|
||||
|
||||
const { isSelected } = useAppSelector(selector);
|
||||
|
||||
const boardName = useBoardName(board_id);
|
||||
|
||||
const handleClick = useCallback(() => {
|
||||
dispatch(boardIdSelected({ boardId: board_id }));
|
||||
}, [board_id, dispatch]);
|
||||
|
||||
return (
|
||||
<IAIButton
|
||||
onClick={handleClick}
|
||||
size="sm"
|
||||
isChecked={isSelected}
|
||||
sx={{
|
||||
flexGrow: 1,
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
>
|
||||
{boardName}
|
||||
</IAIButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SystemBoardButton);
|
||||
@@ -1,22 +0,0 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FaEyeSlash } from 'react-icons/fa';
|
||||
|
||||
const CurrentImageHidden = () => {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
position: 'absolute',
|
||||
color: 'base.400',
|
||||
}}
|
||||
>
|
||||
<FaEyeSlash fontSize="25vh" />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CurrentImageHidden);
|
||||
@@ -61,6 +61,12 @@ const GallerySettingsPopover = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleChangeAutoAssignBoardOnClick = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(autoAssignBoardOnClickChanged(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
@@ -91,9 +97,7 @@ const GallerySettingsPopover = () => {
|
||||
<IAISimpleCheckbox
|
||||
label={t('gallery.autoAssignBoardOnClick')}
|
||||
isChecked={autoAssignBoardOnClick}
|
||||
onChange={(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(autoAssignBoardOnClickChanged(e.target.checked))
|
||||
}
|
||||
onChange={handleChangeAutoAssignBoardOnClick}
|
||||
/>
|
||||
<BoardAutoAddSelect />
|
||||
</Flex>
|
||||
|
||||
@@ -35,6 +35,34 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
|
||||
e.preventDefault();
|
||||
}, []);
|
||||
|
||||
const renderMenuFunc = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (selectionCount > 1) {
|
||||
return (
|
||||
<MenuList
|
||||
sx={{ visibility: 'visible !important' }}
|
||||
motionProps={menuListMotionProps}
|
||||
onContextMenu={skipEvent}
|
||||
>
|
||||
<MultipleSelectionMenuItems />
|
||||
</MenuList>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<MenuList
|
||||
sx={{ visibility: 'visible !important' }}
|
||||
motionProps={menuListMotionProps}
|
||||
onContextMenu={skipEvent}
|
||||
>
|
||||
<SingleSelectionMenuItems imageDTO={imageDTO} />
|
||||
</MenuList>
|
||||
);
|
||||
}, [imageDTO, selectionCount, skipEvent]);
|
||||
|
||||
return (
|
||||
<IAIContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
@@ -42,33 +70,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
|
||||
bg: 'transparent',
|
||||
_hover: { bg: 'transparent' },
|
||||
}}
|
||||
renderMenu={() => {
|
||||
if (!imageDTO) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (selectionCount > 1) {
|
||||
return (
|
||||
<MenuList
|
||||
sx={{ visibility: 'visible !important' }}
|
||||
motionProps={menuListMotionProps}
|
||||
onContextMenu={skipEvent}
|
||||
>
|
||||
<MultipleSelectionMenuItems />
|
||||
</MenuList>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<MenuList
|
||||
sx={{ visibility: 'visible !important' }}
|
||||
motionProps={menuListMotionProps}
|
||||
onContextMenu={skipEvent}
|
||||
>
|
||||
<SingleSelectionMenuItems imageDTO={imageDTO} />
|
||||
</MenuList>
|
||||
);
|
||||
}}
|
||||
renderMenu={renderMenuFunc}
|
||||
>
|
||||
{children}
|
||||
</IAIContextMenu>
|
||||
|
||||
@@ -13,7 +13,7 @@ import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
|
||||
type ImageFallbackSpinnerProps = SpinnerProps;
|
||||
|
||||
const ImageFallbackSpinner = (props: ImageFallbackSpinnerProps) => {
|
||||
const { size = 'xl', ...rest } = props;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
position: 'absolute',
|
||||
color: 'base.400',
|
||||
minH: 36,
|
||||
minW: 36,
|
||||
}}
|
||||
>
|
||||
<Spinner size={size} {...rest} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ImageFallbackSpinner);
|
||||
@@ -20,6 +20,7 @@ import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
|
||||
import GalleryImage from './GalleryImage';
|
||||
import ImageGridItemContainer from './ImageGridItemContainer';
|
||||
import ImageGridListContainer from './ImageGridListContainer';
|
||||
import { EntityId } from '@reduxjs/toolkit';
|
||||
|
||||
const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
|
||||
defer: true,
|
||||
@@ -71,6 +72,13 @@ const GalleryImageGrid = () => {
|
||||
});
|
||||
}, [areMoreAvailable, listImages, queryArgs, currentData?.ids.length]);
|
||||
|
||||
const itemContentFunc = useCallback(
|
||||
(index: number, imageName: EntityId) => (
|
||||
<GalleryImage key={imageName} imageName={imageName as string} />
|
||||
),
|
||||
[]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// Initialize the gallery's custom scrollbar
|
||||
const { current: root } = rootRef;
|
||||
@@ -131,9 +139,7 @@ const GalleryImageGrid = () => {
|
||||
List: ImageGridListContainer,
|
||||
}}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={(index, imageName) => (
|
||||
<GalleryImage key={imageName} imageName={imageName as string} />
|
||||
)}
|
||||
itemContent={itemContentFunc}
|
||||
/>
|
||||
</Box>
|
||||
<IAIButton
|
||||
|
||||
@@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
key={index}
|
||||
label="LoRA"
|
||||
value={`${lora.lora.model_name} - ${lora.weight}`}
|
||||
onClick={() => handleRecallLoRA(lora)}
|
||||
onClick={handleRecallLoRA.bind(null, lora)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -289,7 +289,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
key={index}
|
||||
label="ControlNet"
|
||||
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
|
||||
onClick={() => handleRecallControlNet(controlnet)}
|
||||
onClick={handleRecallControlNet.bind(null, controlnet)}
|
||||
/>
|
||||
))}
|
||||
{validIPAdapters.map((ipAdapter, index) => (
|
||||
@@ -297,7 +297,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
key={index}
|
||||
label="IP Adapter"
|
||||
value={`${ipAdapter.ip_adapter_model?.model_name} - ${ipAdapter.weight}`}
|
||||
onClick={() => handleRecallIPAdapter(ipAdapter)}
|
||||
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
|
||||
/>
|
||||
))}
|
||||
{validT2IAdapters.map((t2iAdapter, index) => (
|
||||
@@ -305,7 +305,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
key={index}
|
||||
label="T2I Adapter"
|
||||
value={`${t2iAdapter.t2i_adapter_model?.model_name} - ${t2iAdapter.weight}`}
|
||||
onClick={() => handleRecallT2IAdapter(t2iAdapter)}
|
||||
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||
import { Flex, IconButton, Link, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaCopy } from 'react-icons/fa';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
@@ -27,6 +27,11 @@ const ImageMetadataItem = ({
|
||||
}: MetadataItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleCopy = useCallback(
|
||||
() => navigator.clipboard.writeText(value.toString()),
|
||||
[value]
|
||||
);
|
||||
|
||||
if (!value) {
|
||||
return null;
|
||||
}
|
||||
@@ -53,7 +58,7 @@ const ImageMetadataItem = ({
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(value.toString())}
|
||||
onClick={handleCopy}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
@@ -76,6 +76,13 @@ const ParamLoRASelect = () => {
|
||||
[dispatch, loraModels?.entities]
|
||||
);
|
||||
|
||||
const filterFunc = useCallback(
|
||||
(value: string, item: SelectItem) =>
|
||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
|
||||
[]
|
||||
);
|
||||
|
||||
if (loraModels?.ids.length === 0) {
|
||||
return (
|
||||
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||
@@ -94,10 +101,7 @@ const ParamLoRASelect = () => {
|
||||
nothingFound="No matching LoRAs"
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, item: SelectItem) =>
|
||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
filter={filterFunc}
|
||||
onChange={handleChange}
|
||||
data-testid="add-lora"
|
||||
/>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import AdvancedAddModels from './AdvancedAddModels';
|
||||
import SimpleAddModels from './SimpleAddModels';
|
||||
|
||||
@@ -8,6 +8,11 @@ export default function AddModels() {
|
||||
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>(
|
||||
'simple'
|
||||
);
|
||||
const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []);
|
||||
const handleAddModelAdvanced = useCallback(
|
||||
() => setAddModelMode('advanced'),
|
||||
[]
|
||||
);
|
||||
return (
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
@@ -20,14 +25,14 @@ export default function AddModels() {
|
||||
<IAIButton
|
||||
size="sm"
|
||||
isChecked={addModelMode == 'simple'}
|
||||
onClick={() => setAddModelMode('simple')}
|
||||
onClick={handleAddModelSimple}
|
||||
>
|
||||
Simple
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
isChecked={addModelMode == 'advanced'}
|
||||
onClick={() => setAddModelMode('advanced')}
|
||||
onClick={handleAddModelAdvanced}
|
||||
>
|
||||
Advanced
|
||||
</IAIButton>
|
||||
@@ -6,7 +6,7 @@ import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useState } from 'react';
|
||||
import { FocusEventHandler, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { CheckpointModelConfig } from 'services/api/types';
|
||||
@@ -83,6 +83,27 @@ export default function AdvancedAddCheckpoint(
|
||||
});
|
||||
};
|
||||
|
||||
const handleBlurModelLocation: FocusEventHandler<HTMLInputElement> =
|
||||
useCallback(
|
||||
(e) => {
|
||||
if (advancedAddCheckpointForm.values['model_name'] === '') {
|
||||
const modelName = getModelName(e.currentTarget.value);
|
||||
if (modelName) {
|
||||
advancedAddCheckpointForm.setFieldValue(
|
||||
'model_name',
|
||||
modelName as string
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
[advancedAddCheckpointForm]
|
||||
);
|
||||
|
||||
const handleChangeUseCustomConfig = useCallback(
|
||||
() => setUseCustomConfig((prev) => !prev),
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
<form
|
||||
onSubmit={advancedAddCheckpointForm.onSubmit((v) =>
|
||||
@@ -104,17 +125,7 @@ export default function AdvancedAddCheckpoint(
|
||||
label={t('modelManager.modelLocation')}
|
||||
required
|
||||
{...advancedAddCheckpointForm.getInputProps('path')}
|
||||
onBlur={(e) => {
|
||||
if (advancedAddCheckpointForm.values['model_name'] === '') {
|
||||
const modelName = getModelName(e.currentTarget.value);
|
||||
if (modelName) {
|
||||
advancedAddCheckpointForm.setFieldValue(
|
||||
'model_name',
|
||||
modelName as string
|
||||
);
|
||||
}
|
||||
}
|
||||
}}
|
||||
onBlur={handleBlurModelLocation}
|
||||
/>
|
||||
<IAIMantineTextInput
|
||||
label={t('modelManager.description')}
|
||||
@@ -144,7 +155,7 @@ export default function AdvancedAddCheckpoint(
|
||||
)}
|
||||
<IAISimpleCheckbox
|
||||
isChecked={useCustomConfig}
|
||||
onChange={() => setUseCustomConfig(!useCustomConfig)}
|
||||
onChange={handleChangeUseCustomConfig}
|
||||
label={t('modelManager.useCustomConfig')}
|
||||
/>
|
||||
<IAIButton mt={2} type="submit">
|
||||
@@ -12,6 +12,7 @@ import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||
import BaseModelSelect from '../shared/BaseModelSelect';
|
||||
import ModelVariantSelect from '../shared/ModelVariantSelect';
|
||||
import { getModelName } from './util';
|
||||
import { FocusEventHandler, useCallback } from 'react';
|
||||
|
||||
type AdvancedAddDiffusersProps = {
|
||||
model_path?: string;
|
||||
@@ -74,6 +75,22 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
|
||||
});
|
||||
};
|
||||
|
||||
const handleBlurModelLocation: FocusEventHandler<HTMLInputElement> =
|
||||
useCallback(
|
||||
(e) => {
|
||||
if (advancedAddDiffusersForm.values['model_name'] === '') {
|
||||
const modelName = getModelName(e.currentTarget.value, false);
|
||||
if (modelName) {
|
||||
advancedAddDiffusersForm.setFieldValue(
|
||||
'model_name',
|
||||
modelName as string
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
[advancedAddDiffusersForm]
|
||||
);
|
||||
|
||||
return (
|
||||
<form
|
||||
onSubmit={advancedAddDiffusersForm.onSubmit((v) =>
|
||||
@@ -96,17 +113,7 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
|
||||
label={t('modelManager.modelLocation')}
|
||||
placeholder={t('modelManager.modelLocationValidationMsg')}
|
||||
{...advancedAddDiffusersForm.getInputProps('path')}
|
||||
onBlur={(e) => {
|
||||
if (advancedAddDiffusersForm.values['model_name'] === '') {
|
||||
const modelName = getModelName(e.currentTarget.value, false);
|
||||
if (modelName) {
|
||||
advancedAddDiffusersForm.setFieldValue(
|
||||
'model_name',
|
||||
modelName as string
|
||||
);
|
||||
}
|
||||
}
|
||||
}}
|
||||
onBlur={handleBlurModelLocation}
|
||||
/>
|
||||
<IAIMantineTextInput
|
||||
label={t('modelManager.description')}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
|
||||
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -18,6 +18,12 @@ export default function AdvancedAddModels() {
|
||||
useState<ManualAddMode>('diffusers');
|
||||
|
||||
const { t } = useTranslation();
|
||||
const handleChange = useCallback((v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
setAdvancedAddMode(v as ManualAddMode);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" gap={4} width="100%">
|
||||
@@ -25,12 +31,7 @@ export default function AdvancedAddModels() {
|
||||
label={t('modelManager.modelType')}
|
||||
value={advancedAddMode}
|
||||
data={advancedAddModeData}
|
||||
onChange={(v) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
setAdvancedAddMode(v as ManualAddMode);
|
||||
}}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
|
||||
<Flex
|
||||
@@ -92,6 +92,11 @@ export default function FoundModelsList() {
|
||||
setNameFilter(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleClickSetAdvanced = useCallback(
|
||||
(model: string) => dispatch(setAdvancedAddScanModel(model)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const renderModels = ({
|
||||
models,
|
||||
showActions = true,
|
||||
@@ -140,7 +145,7 @@ export default function FoundModelsList() {
|
||||
{t('modelManager.quickAdd')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
onClick={() => dispatch(setAdvancedAddScanModel(model))}
|
||||
onClick={handleClickSetAdvanced.bind(null, model)}
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('modelManager.advanced')}
|
||||
@@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { motion } from 'framer-motion';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { FaTimes } from 'react-icons/fa';
|
||||
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
|
||||
@@ -35,6 +35,23 @@ export default function ScanAdvancedAddModels() {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickSetAdvanced = useCallback(
|
||||
() => dispatch(setAdvancedAddScanModel(null)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleChangeAddMode = useCallback((v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
setAdvancedAddMode(v as ManualAddMode);
|
||||
if (v === 'checkpoint') {
|
||||
setIsCheckpoint(true);
|
||||
} else {
|
||||
setIsCheckpoint(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
if (!advancedAddScanModel) {
|
||||
return null;
|
||||
}
|
||||
@@ -68,7 +85,7 @@ export default function ScanAdvancedAddModels() {
|
||||
<IAIIconButton
|
||||
icon={<FaTimes />}
|
||||
aria-label={t('modelManager.closeAdvanced')}
|
||||
onClick={() => dispatch(setAdvancedAddScanModel(null))}
|
||||
onClick={handleClickSetAdvanced}
|
||||
size="sm"
|
||||
/>
|
||||
</Flex>
|
||||
@@ -76,17 +93,7 @@ export default function ScanAdvancedAddModels() {
|
||||
label={t('modelManager.modelType')}
|
||||
value={advancedAddMode}
|
||||
data={advancedAddModeData}
|
||||
onChange={(v) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
setAdvancedAddMode(v as ManualAddMode);
|
||||
if (v === 'checkpoint') {
|
||||
setIsCheckpoint(true);
|
||||
} else {
|
||||
setIsCheckpoint(false);
|
||||
}
|
||||
}}
|
||||
onChange={handleChangeAddMode}
|
||||
/>
|
||||
{isCheckpoint ? (
|
||||
<AdvancedAddCheckpoint
|
||||
@@ -42,9 +42,14 @@ function SearchFolderForm() {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const scanAgainHandler = () => {
|
||||
const scanAgainHandler = useCallback(() => {
|
||||
refetchFoundModels();
|
||||
};
|
||||
}, [refetchFoundModels]);
|
||||
|
||||
const handleClickClearCheckpointFolder = useCallback(() => {
|
||||
dispatch(setSearchFolder(null));
|
||||
dispatch(setAdvancedAddScanModel(null));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<form
|
||||
@@ -123,10 +128,7 @@ function SearchFolderForm() {
|
||||
tooltip={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<FaTrash />}
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
dispatch(setSearchFolder(null));
|
||||
dispatch(setAdvancedAddScanModel(null));
|
||||
}}
|
||||
onClick={handleClickClearCheckpointFolder}
|
||||
isDisabled={!searchFolder}
|
||||
colorScheme="red"
|
||||
/>
|
||||
@@ -1,6 +1,6 @@
|
||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import AddModels from './AddModelsPanel/AddModels';
|
||||
import ScanModels from './AddModelsPanel/ScanModels';
|
||||
@@ -11,11 +11,14 @@ export default function ImportModelsPanel() {
|
||||
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleClickAddTab = useCallback(() => setAddModelTab('add'), []);
|
||||
const handleClickScanTab = useCallback(() => setAddModelTab('scan'), []);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" gap={4}>
|
||||
<ButtonGroup isAttached>
|
||||
<IAIButton
|
||||
onClick={() => setAddModelTab('add')}
|
||||
onClick={handleClickAddTab}
|
||||
isChecked={addModelTab == 'add'}
|
||||
size="sm"
|
||||
width="100%"
|
||||
@@ -23,7 +26,7 @@ export default function ImportModelsPanel() {
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
onClick={() => setAddModelTab('scan')}
|
||||
onClick={handleClickScanTab}
|
||||
isChecked={addModelTab == 'scan'}
|
||||
size="sm"
|
||||
width="100%"
|
||||
@@ -9,7 +9,7 @@ import IAISlider from 'common/components/IAISlider';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { pickBy } from 'lodash-es';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { ChangeEvent, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
@@ -94,13 +94,58 @@ export default function MergeModelsPanel() {
|
||||
modelsMap[baseModel as keyof typeof modelsMap]
|
||||
).filter((model) => model !== modelOne && model !== modelTwo);
|
||||
|
||||
const handleBaseModelChange = (v: string) => {
|
||||
const handleBaseModelChange = useCallback((v: string) => {
|
||||
setBaseModel(v as BaseModelType);
|
||||
setModelOne(null);
|
||||
setModelTwo(null);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const mergeModelsHandler = () => {
|
||||
const handleChangeModelOne = useCallback((v: string) => {
|
||||
setModelOne(v);
|
||||
}, []);
|
||||
const handleChangeModelTwo = useCallback((v: string) => {
|
||||
setModelTwo(v);
|
||||
}, []);
|
||||
const handleChangeModelThree = useCallback((v: string) => {
|
||||
if (!v) {
|
||||
setModelThree(null);
|
||||
setModelMergeInterp('add_difference');
|
||||
} else {
|
||||
setModelThree(v);
|
||||
setModelMergeInterp('weighted_sum');
|
||||
}
|
||||
}, []);
|
||||
const handleChangeMergedModelName = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => setMergedModelName(e.target.value),
|
||||
[]
|
||||
);
|
||||
const handleChangeModelMergeAlpha = useCallback(
|
||||
(v: number) => setModelMergeAlpha(v),
|
||||
[]
|
||||
);
|
||||
const handleResetModelMergeAlpha = useCallback(
|
||||
() => setModelMergeAlpha(0.5),
|
||||
[]
|
||||
);
|
||||
const handleChangeMergeInterp = useCallback(
|
||||
(v: MergeInterpolationMethods) => setModelMergeInterp(v),
|
||||
[]
|
||||
);
|
||||
const handleChangeMergeSaveLocType = useCallback(
|
||||
(v: 'root' | 'custom') => setModelMergeSaveLocType(v),
|
||||
[]
|
||||
);
|
||||
const handleChangeMergeCustomSaveLoc = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
setModelMergeCustomSaveLoc(e.target.value),
|
||||
[]
|
||||
);
|
||||
const handleChangeModelMergeForce = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => setModelMergeForce(e.target.checked),
|
||||
[]
|
||||
);
|
||||
|
||||
const mergeModelsHandler = useCallback(() => {
|
||||
const models_names: string[] = [];
|
||||
|
||||
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
|
||||
@@ -150,7 +195,21 @@ export default function MergeModelsPanel() {
|
||||
);
|
||||
}
|
||||
});
|
||||
};
|
||||
}, [
|
||||
baseModel,
|
||||
dispatch,
|
||||
mergeModels,
|
||||
mergedModelName,
|
||||
modelMergeAlpha,
|
||||
modelMergeCustomSaveLoc,
|
||||
modelMergeForce,
|
||||
modelMergeInterp,
|
||||
modelMergeSaveLocType,
|
||||
modelOne,
|
||||
modelThree,
|
||||
modelTwo,
|
||||
t,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4}>
|
||||
@@ -180,7 +239,7 @@ export default function MergeModelsPanel() {
|
||||
value={modelOne}
|
||||
placeholder={t('modelManager.selectModel')}
|
||||
data={modelOneList}
|
||||
onChange={(v) => setModelOne(v)}
|
||||
onChange={handleChangeModelOne}
|
||||
/>
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('modelManager.modelTwo')}
|
||||
@@ -188,7 +247,7 @@ export default function MergeModelsPanel() {
|
||||
placeholder={t('modelManager.selectModel')}
|
||||
value={modelTwo}
|
||||
data={modelTwoList}
|
||||
onChange={(v) => setModelTwo(v)}
|
||||
onChange={handleChangeModelTwo}
|
||||
/>
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('modelManager.modelThree')}
|
||||
@@ -196,22 +255,14 @@ export default function MergeModelsPanel() {
|
||||
w="100%"
|
||||
placeholder={t('modelManager.selectModel')}
|
||||
clearable
|
||||
onChange={(v) => {
|
||||
if (!v) {
|
||||
setModelThree(null);
|
||||
setModelMergeInterp('add_difference');
|
||||
} else {
|
||||
setModelThree(v);
|
||||
setModelMergeInterp('weighted_sum');
|
||||
}
|
||||
}}
|
||||
onChange={handleChangeModelThree}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<IAIInput
|
||||
label={t('modelManager.mergedModelName')}
|
||||
value={mergedModelName}
|
||||
onChange={(e) => setMergedModelName(e.target.value)}
|
||||
onChange={handleChangeMergedModelName}
|
||||
/>
|
||||
|
||||
<Flex
|
||||
@@ -232,10 +283,10 @@ export default function MergeModelsPanel() {
|
||||
max={0.99}
|
||||
step={0.01}
|
||||
value={modelMergeAlpha}
|
||||
onChange={(v) => setModelMergeAlpha(v)}
|
||||
onChange={handleChangeModelMergeAlpha}
|
||||
withInput
|
||||
withReset
|
||||
handleReset={() => setModelMergeAlpha(0.5)}
|
||||
handleReset={handleResetModelMergeAlpha}
|
||||
withSliderMarks
|
||||
/>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
@@ -257,10 +308,7 @@ export default function MergeModelsPanel() {
|
||||
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
||||
{t('modelManager.interpolationType')}
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={modelMergeInterp}
|
||||
onChange={(v: MergeInterpolationMethods) => setModelMergeInterp(v)}
|
||||
>
|
||||
<RadioGroup value={modelMergeInterp} onChange={handleChangeMergeInterp}>
|
||||
<Flex columnGap={4}>
|
||||
{modelThree === null ? (
|
||||
<>
|
||||
@@ -305,7 +353,7 @@ export default function MergeModelsPanel() {
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={modelMergeSaveLocType}
|
||||
onChange={(v: 'root' | 'custom') => setModelMergeSaveLocType(v)}
|
||||
onChange={handleChangeMergeSaveLocType}
|
||||
>
|
||||
<Flex columnGap={4}>
|
||||
<Radio value="root">
|
||||
@@ -323,7 +371,7 @@ export default function MergeModelsPanel() {
|
||||
<IAIInput
|
||||
label={t('modelManager.mergedModelCustomSaveLocation')}
|
||||
value={modelMergeCustomSaveLoc}
|
||||
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
||||
onChange={handleChangeMergeCustomSaveLoc}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
@@ -331,7 +379,7 @@ export default function MergeModelsPanel() {
|
||||
<IAISimpleCheckbox
|
||||
label={t('modelManager.ignoreMismatch')}
|
||||
isChecked={modelMergeForce}
|
||||
onChange={(e) => setModelMergeForce(e.target.checked)}
|
||||
onChange={handleChangeModelMergeForce}
|
||||
fontWeight="500"
|
||||
/>
|
||||
|
||||
@@ -59,6 +59,11 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
},
|
||||
});
|
||||
|
||||
const handleChangeUseCustomConfig = useCallback(
|
||||
() => setUseCustomConfig((prev) => !prev),
|
||||
[]
|
||||
);
|
||||
|
||||
const editModelFormSubmitHandler = useCallback(
|
||||
(values: CheckpointModelConfig) => {
|
||||
const responseBody = {
|
||||
@@ -181,7 +186,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
)}
|
||||
<IAISimpleCheckbox
|
||||
isChecked={useCustomConfig}
|
||||
onChange={() => setUseCustomConfig(!useCustomConfig)}
|
||||
onChange={handleChangeUseCustomConfig}
|
||||
label="Use Custom Config"
|
||||
/>
|
||||
</Flex>
|
||||
@@ -14,7 +14,7 @@ import IAIAlertDialog from 'common/components/IAIAlertDialog';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { ChangeEvent, useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
|
||||
@@ -42,11 +42,21 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
setSaveLocation('InvokeAIRoot');
|
||||
}, [model]);
|
||||
|
||||
const modelConvertCancelHandler = () => {
|
||||
const modelConvertCancelHandler = useCallback(() => {
|
||||
setSaveLocation('InvokeAIRoot');
|
||||
};
|
||||
}, []);
|
||||
|
||||
const modelConvertHandler = () => {
|
||||
const handleChangeSaveLocation = useCallback((v: string) => {
|
||||
setSaveLocation(v as SaveLocation);
|
||||
}, []);
|
||||
const handleChangeCustomSaveLocation = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
setCustomSaveLocation(e.target.value);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const modelConvertHandler = useCallback(() => {
|
||||
const queryArg = {
|
||||
base_model: model.base_model,
|
||||
model_name: model.model_name,
|
||||
@@ -101,7 +111,15 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
)
|
||||
);
|
||||
});
|
||||
};
|
||||
}, [
|
||||
convertModel,
|
||||
customSaveLocation,
|
||||
dispatch,
|
||||
model.base_model,
|
||||
model.model_name,
|
||||
saveLocation,
|
||||
t,
|
||||
]);
|
||||
|
||||
return (
|
||||
<IAIAlertDialog
|
||||
@@ -137,10 +155,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
<Text fontWeight="600">
|
||||
{t('modelManager.convertToDiffusersSaveLocation')}
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={saveLocation}
|
||||
onChange={(v) => setSaveLocation(v as SaveLocation)}
|
||||
>
|
||||
<RadioGroup value={saveLocation} onChange={handleChangeSaveLocation}>
|
||||
<Flex gap={4}>
|
||||
<Radio value="InvokeAIRoot">
|
||||
<Tooltip label="Save converted model in the InvokeAI root folder">
|
||||
@@ -162,9 +177,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
</Text>
|
||||
<IAIInput
|
||||
value={customSaveLocation}
|
||||
onChange={(e) => {
|
||||
setCustomSaveLocation(e.target.value);
|
||||
}}
|
||||
onChange={handleChangeCustomSaveLocation}
|
||||
width="full"
|
||||
/>
|
||||
</Flex>
|
||||
@@ -100,7 +100,7 @@ const ModelList = (props: ModelListProps) => {
|
||||
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
||||
<ButtonGroup isAttached>
|
||||
<IAIButton
|
||||
onClick={() => setModelFormatFilter('all')}
|
||||
onClick={setModelFormatFilter.bind(null, 'all')}
|
||||
isChecked={modelFormatFilter === 'all'}
|
||||
size="sm"
|
||||
>
|
||||
@@ -108,35 +108,35 @@ const ModelList = (props: ModelListProps) => {
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('diffusers')}
|
||||
onClick={setModelFormatFilter.bind(null, 'diffusers')}
|
||||
isChecked={modelFormatFilter === 'diffusers'}
|
||||
>
|
||||
{t('modelManager.diffusersModels')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('checkpoint')}
|
||||
onClick={setModelFormatFilter.bind(null, 'checkpoint')}
|
||||
isChecked={modelFormatFilter === 'checkpoint'}
|
||||
>
|
||||
{t('modelManager.checkpointModels')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('onnx')}
|
||||
onClick={setModelFormatFilter.bind(null, 'onnx')}
|
||||
isChecked={modelFormatFilter === 'onnx'}
|
||||
>
|
||||
{t('modelManager.onnxModels')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('olive')}
|
||||
onClick={setModelFormatFilter.bind(null, 'olive')}
|
||||
isChecked={modelFormatFilter === 'olive'}
|
||||
>
|
||||
{t('modelManager.oliveModels')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('lora')}
|
||||
onClick={setModelFormatFilter.bind(null, 'lora')}
|
||||
isChecked={modelFormatFilter === 'lora'}
|
||||
>
|
||||
{t('modelManager.loraModels')}
|
||||
@@ -4,6 +4,7 @@ import IAIButton from 'common/components/IAIButton';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSync } from 'react-icons/fa';
|
||||
import { useSyncModelsMutation } from 'services/api/endpoints/models';
|
||||
@@ -19,7 +20,7 @@ export default function SyncModelsButton(props: SyncModelsButtonProps) {
|
||||
|
||||
const [syncModels, { isLoading }] = useSyncModelsMutation();
|
||||
|
||||
const syncModelsHandler = () => {
|
||||
const syncModelsHandler = useCallback(() => {
|
||||
syncModels()
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
@@ -44,7 +45,7 @@ export default function SyncModelsButton(props: SyncModelsButtonProps) {
|
||||
);
|
||||
}
|
||||
});
|
||||
};
|
||||
}, [dispatch, syncModels, t]);
|
||||
|
||||
return !iconMode ? (
|
||||
<IAIButton
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState, PropsWithChildren, memo } from 'react';
|
||||
import { useState, PropsWithChildren, memo, useCallback } from 'react';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { Flex, Image, Text } from '@chakra-ui/react';
|
||||
@@ -59,13 +59,13 @@ export default memo(CurrentImageNode);
|
||||
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => {
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
|
||||
const handleMouseEnter = () => {
|
||||
const handleMouseEnter = useCallback(() => {
|
||||
setIsHovering(true);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleMouseLeave = () => {
|
||||
const handleMouseLeave = useCallback(() => {
|
||||
setIsHovering(false);
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<NodeWrapper
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user