Compare commits

..

4 Commits

Author SHA1 Message Date
psychedelicious
cd528eda32 test: fixt lint check 2023-11-13 11:03:56 +11:00
psychedelicious
4a27daa149 test: violate lint check 2023-11-13 11:03:09 +11:00
psychedelicious
9eafec720d test: fix format 2023-11-13 11:02:55 +11:00
psychedelicious
3d3775c962 test: violate style check 2023-11-13 11:01:32 +11:00
200 changed files with 3364 additions and 4331 deletions

View File

@@ -6,7 +6,7 @@ on:
branches: main
jobs:
ruff:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

View File

@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
_For Linux with an AMD GPU:_
@@ -175,7 +175,7 @@ the command `npm install -g yarn` if needed)
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
_For Macintoshes, either Intel or M1/M2/M3:_
_For Macintoshes, either Intel or M1/M2:_
```sh
pip install InvokeAI --use-pep517

File diff suppressed because it is too large Load Diff

View File

@@ -179,7 +179,7 @@ experimental versions later.
you will have the choice of CUDA (NVidia cards), ROCm (AMD cards),
or CPU (no graphics acceleration). On Windows, you'll have the
choice of CUDA vs CPU, and on Macs you'll be offered CPU only. When
you select CPU on M1/M2/M3 Macintoshes, you will get MPS-based
you select CPU on M1 or M2 Macintoshes, you will get MPS-based
graphics acceleration without installing additional drivers. If you
are unsure what GPU you are using, you can ask the installer to
guess.
@@ -471,7 +471,7 @@ Then type the following commands:
=== "NVIDIA System"
```bash
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
pip install xformers
```

View File

@@ -148,7 +148,7 @@ manager, please follow these steps:
=== "CUDA (NVidia)"
```bash
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
=== "ROCm (AMD)"
@@ -327,7 +327,7 @@ installation protocol (important!)
=== "CUDA (NVidia)"
```bash
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
=== "ROCm (AMD)"
@@ -375,7 +375,7 @@ you can do so using this unsupported recipe:
mkdir ~/invokeai
conda create -n invokeai python=3.10
conda activate invokeai
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
invokeai-configure --root ~/invokeai
invokeai --root ~/invokeai --web
```

View File

@@ -85,7 +85,7 @@ You can find which version you should download from [this link](https://docs.nvi
When installing torch and torchvision manually with `pip`, remember to provide
the argument `--extra-index-url
https://download.pytorch.org/whl/cu121` as described in the [Manual
https://download.pytorch.org/whl/cu118` as described in the [Manual
Installation Guide](020_INSTALL_MANUAL.md).
## :simple-amd: ROCm

View File

@@ -30,7 +30,7 @@ methodology for details on why running applications in such a stateless fashion
The container is configured for CUDA by default, but can be built to support AMD GPUs
by setting the `GPU_DRIVER=rocm` environment variable at Docker image build time.
Developers on Apple silicon (M1/M2/M3): You
Developers on Apple silicon (M1/M2): You
[can't access your GPU cores from Docker containers](https://github.com/pytorch/pytorch/issues/81224)
and performance is reduced compared with running it directly on macOS but for
development purposes it's fine. Once you're done with development tasks on your

View File

@@ -28,7 +28,7 @@ command line, then just be sure to activate it's virtual environment.
Then run the following three commands:
```sh
pip install xformers~=0.0.22
pip install xformers~=0.0.19
pip install triton # WON'T WORK ON WINDOWS
python -m xformers.info output
```
@@ -42,7 +42,7 @@ If all goes well, you'll see a report like the
following:
```sh
xFormers 0.0.22
xFormers 0.0.20
memory_efficient_attention.cutlassF: available
memory_efficient_attention.cutlassB: available
memory_efficient_attention.flshattF: available
@@ -59,14 +59,14 @@ swiglu.gemm_fused_operand_sum: available
swiglu.fused.p.cpp: available
is_triton_available: True
is_functorch_available: False
pytorch.version: 2.1.0+cu121
pytorch.version: 2.0.1+cu118
pytorch.cuda: available
gpu.compute_capability: 8.9
gpu.name: NVIDIA GeForce RTX 4070
build.info: available
build.cuda_version: 1108
build.python_version: 3.10.11
build.torch_version: 2.1.0+cu121
build.torch_version: 2.0.1+cu118
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
build.env.XFORMERS_BUILD_TYPE: Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
@@ -92,22 +92,33 @@ installed from source. These instructions were written for a system
running Ubuntu 22.04, but other Linux distributions should be able to
adapt this recipe.
#### 1. Install CUDA Toolkit 12.1
#### 1. Install CUDA Toolkit 11.8
You will need the CUDA developer's toolkit in order to compile and
install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
package.** It is out of date and will cause conflicts among the NVIDIA
driver and binaries. Instead install the CUDA Toolkit package provided
by NVIDIA itself. Go to [CUDA Toolkit 12.1
Downloads](https://developer.nvidia.com/cuda-12-1-0-download-archive)
by NVIDIA itself. Go to [CUDA Toolkit 11.8
Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
and use the target selection wizard to choose your platform and Linux
distribution. Select an installer type of "runfile (local)" at the
last step.
This will provide you with a recipe for downloading and running a
install shell script that will install the toolkit and drivers.
install shell script that will install the toolkit and drivers. For
example, the install script recipe for Ubuntu 22.04 running on a
x86_64 system is:
#### 2. Confirm/Install pyTorch 2.1.0 with CUDA 12.1 support
```
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run
```
Rather than cut-and-paste this example, We recommend that you walk
through the toolkit wizard in order to get the most up to date
installer for your system.
#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
If you are using InvokeAI 3.0.2 or higher, these will already be
installed. If not, you can check whether you have the needed libraries
@@ -122,7 +133,7 @@ Then run the command:
python -c 'exec("import torch\nprint(torch.__version__)")'
```
If it prints __2.1.0+cu121__ you're good. If not, you can install the
If it prints __1.13.1+cu118__ you're good. If not, you can install the
most up to date libraries with this command:
```sh

View File

@@ -244,7 +244,7 @@ class InvokeAiInstance:
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch~=2.1.0",
"torch~=2.0.0",
"torchmetrics==0.11.4",
"torchvision>=0.14.1",
"--force-reinstall",
@@ -460,10 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
url = "https://download.pytorch.org/whl/cpu"
if device == "cuda":
url = "https://download.pytorch.org/whl/cu121"
url = "https://download.pytorch.org/whl/cu118"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu121"
url = "https://download.pytorch.org/whl/cu118"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@@ -24,7 +24,6 @@ from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -86,7 +85,6 @@ class ApiDependencies:
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@@ -113,7 +111,6 @@ class ApiDependencies:
latents=latents,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

View File

@@ -1,164 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
from hashlib import sha1
from random import randbytes
from typing import List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
class ModelsList(BaseModel):
"""Return list of configs."""
models: list[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
@model_records_router.get(
"/",
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
else:
found_models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=found_models)
@model_records_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
try:
return record_store.get_model(key)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=AnyModelConfig,
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
try:
model_response = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response
@model_records_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
record_store = ApiDependencies.invoker.services.model_records
record_store.del_model(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.
"""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)

View File

@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Annotated, List, Literal, Optional, Union

View File

@@ -43,7 +43,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
board_images,
boards,
images,
model_records,
models,
session_queue,
sessions,
@@ -107,7 +106,6 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")

View File

@@ -112,11 +112,10 @@ class CompelInvocation(BaseInvocation):
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,
@@ -235,11 +234,10 @@ class SDXLPromptInvocationBase:
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,

View File

@@ -22,7 +22,6 @@ if TYPE_CHECKING:
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@@ -50,7 +49,6 @@ class InvocationServices:
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@@ -78,7 +76,6 @@ class InvocationServices:
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@@ -104,7 +101,6 @@ class InvocationServices:
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@@ -1,8 +0,0 @@
"""Init file for model record services."""
from .model_records_base import ( # noqa F401
DuplicateModelException,
InvalidModelException,
ModelRecordServiceBase,
UnknownModelException,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401

View File

@@ -1,169 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for storing and retrieving model configuration records.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2.0"
class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
class InvalidModelException(Exception):
"""Raised when an invalid model is detected."""
class UnknownModelException(Exception):
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
class ConfigFileVersionMismatchException(Exception):
"""Raised on an attempt to open a config with an incompatible version."""
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
pass
@abstractmethod
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
pass
@abstractmethod
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
pass
@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
pass
@abstractmethod
def search_by_path(
self,
path: Union[str, Path],
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated path."""
pass
@abstractmethod
def search_by_hash(
self,
hash: str,
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated original hash."""
pass
@abstractmethod
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
pass
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_attr()
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
"""
Return information about a single model using its name, base type and model type.
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
if len(model_configs) == 0:
raise UnknownModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@@ -1,397 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Implementation of the ModelRecordServiceBase API
Typical usage:
from invokeai.backend.model_manager import ModelConfigStoreSQL
store = ModelConfigStoreSQL(sqlite_db)
config = dict(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
type='embedding',
format='embedding_file',
)
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
config.name='new name'
store.update_model('key1', config)
# checking for existence
if store.exists('key1'):
print("yes")
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base)
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_path(path='/tmp/pokemon.bin')
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
configs = store.search_by_attr(base_model='sd-2', model_type='main')
"""
import json
import sqlite3
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelConfigFactory,
ModelType,
)
from ..shared.sqlite import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,
ModelRecordServiceBase,
UnknownModelException,
)
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
_db: SqliteDatabase
_cursor: sqlite3.Cursor
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
with self._db.lock:
# Enable foreign keys
self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._db.conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
id,
base,
type,
name,
path,
original_hash,
config
)
VALUES (?,?,?,?,?,?,?);
""",
(
key,
record.base,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
with self._db.lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_config
SET base=?,
type=?,
name=?,
path=?,
config=?
WHERE id=?;
""",
(record.base, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]))
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
count = 0
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
return count > 0
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
where_clause = []
bindings = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
self._cursor.execute(
f"""--sql
select config FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
"""Return models with the indicated path."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
"""Return models with the indicated original_hash."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results

View File

@@ -1,323 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str
name: str
base: BaseModelType
type: ModelType
format: ModelFormat
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
) # this is assigned at install time and will not change
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
)
def update(self, attributes: dict):
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
format: Literal[ModelFormat.InvokeAI]
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
class T2IConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,
_ONNXConfig,
_VaeConfig,
_ControlNetConfig,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@classmethod
def make_config(
cls,
model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.validate_python(model_data)
else:
model = AnyModelConfigValidator.validate_python(model_data)
if key:
model.key = key
return model

View File

@@ -1,66 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
from pathlib import Path
from typing import Dict, Union
from imohash import hashfile
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
:param model_location: Path to the model file
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()

View File

@@ -1,93 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
from hashlib import sha1
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
class MigrateModelYamlToDb:
"""
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
Use this way:
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
MigrateModelYamlToDb().migrate()
"""
config: InvokeAIAppConfig
logger: InvokeAILogger
def __init__(self):
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation."""
db = SqliteDatabase(self.config, self.logger)
return ModelRecordServiceSQL(db)
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path)
def migrate(self):
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
for model_key, stanza in yaml.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_config = ModelsValidator.validate_python(stanza)
self.logger.info(f"Adding model {model_name} with key {model_key}")
try:
db.add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
def main():
MigrateModelYamlToDb().migrate()
if __name__ == "__main__":
main()

View File

@@ -748,7 +748,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
scales = scales * conditioning_scale
down_block_res_samples = [
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=False)
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)
]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:

View File

@@ -5,7 +5,6 @@ import math
import multiprocessing as mp
import os
import re
import warnings
from collections import abc
from inspect import isfunction
from pathlib import Path
@@ -15,10 +14,8 @@ from threading import Thread
import numpy as np
import requests
import torch
from diffusers import logging as diffusers_logging
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from transformers import logging as transformers_logging
import invokeai.backend.util.logging as logger
@@ -382,21 +379,3 @@ class Chdir(object):
def __exit__(self, *args):
os.chdir(self.original)
class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __enter__(self):
"""Set verbosity to error."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
"""Restore logger verbosity to state before context was entered."""
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")

View File

@@ -90,14 +90,6 @@ def get_extras():
pass
return extras
def get_extra_index() -> str:
# parsed_version.local for torch is the platform + version, eg 'cu121' or 'rocm5.6'
local = pkg_resources.get_distribution("torch").parsed_version.local
if local and 'cu' in local:
return "--extra-index-url https://download.pytorch.org/whl/cu121"
if local and 'rocm' in local:
return "--extra-index-url https://download.pytorch.org/whl/rocm5.6"
return ""
def main():
versions = get_versions()
@@ -130,15 +122,14 @@ def main():
branch = Prompt.ask("Enter an InvokeAI branch name")
extras = get_extras()
extra_index_url = get_extra_index()
print(f":crossed_fingers: Upgrading to [yellow]{tag or release or branch}[/yellow]")
if release:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade {extra_index_url}'
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
elif tag:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade {extra_index_url}'
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
else:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade {extra_index_url}'
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
print("")
print("")
if os.system(cmd) == 0:

View File

@@ -24,7 +24,6 @@ module.exports = {
root: true,
rules: {
curly: 'error',
'react/jsx-no-bind': ['error', { allowBind: true }],
'react/jsx-curly-brace-presence': [
'error',
{ props: 'never', children: 'never' },

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,280 +0,0 @@
import{w as s,ia as T,v as l,a2 as I,ib as R,ae as V,ic as z,id as j,ie as D,ig as F,ih as G,ii as W,ij as K,aG as H,ik as U,il as Y}from"./index-54a1ea80.js";import{M as Z}from"./MantineProvider-17a58e64.js";var P=String.raw,E=P`
:root,
:host {
--chakra-vh: 100vh;
}
@supports (height: -webkit-fill-available) {
:root,
:host {
--chakra-vh: -webkit-fill-available;
}
}
@supports (height: -moz-fill-available) {
:root,
:host {
--chakra-vh: -moz-fill-available;
}
}
@supports (height: 100dvh) {
:root,
:host {
--chakra-vh: 100dvh;
}
}
`,B=()=>s.jsx(T,{styles:E}),J=({scope:e=""})=>s.jsx(T,{styles:P`
html {
line-height: 1.5;
-webkit-text-size-adjust: 100%;
font-family: system-ui, sans-serif;
-webkit-font-smoothing: antialiased;
text-rendering: optimizeLegibility;
-moz-osx-font-smoothing: grayscale;
touch-action: manipulation;
}
body {
position: relative;
min-height: 100%;
margin: 0;
font-feature-settings: "kern";
}
${e} :where(*, *::before, *::after) {
border-width: 0;
border-style: solid;
box-sizing: border-box;
word-wrap: break-word;
}
main {
display: block;
}
${e} hr {
border-top-width: 1px;
box-sizing: content-box;
height: 0;
overflow: visible;
}
${e} :where(pre, code, kbd,samp) {
font-family: SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 1em;
}
${e} a {
background-color: transparent;
color: inherit;
text-decoration: inherit;
}
${e} abbr[title] {
border-bottom: none;
text-decoration: underline;
-webkit-text-decoration: underline dotted;
text-decoration: underline dotted;
}
${e} :where(b, strong) {
font-weight: bold;
}
${e} small {
font-size: 80%;
}
${e} :where(sub,sup) {
font-size: 75%;
line-height: 0;
position: relative;
vertical-align: baseline;
}
${e} sub {
bottom: -0.25em;
}
${e} sup {
top: -0.5em;
}
${e} img {
border-style: none;
}
${e} :where(button, input, optgroup, select, textarea) {
font-family: inherit;
font-size: 100%;
line-height: 1.15;
margin: 0;
}
${e} :where(button, input) {
overflow: visible;
}
${e} :where(button, select) {
text-transform: none;
}
${e} :where(
button::-moz-focus-inner,
[type="button"]::-moz-focus-inner,
[type="reset"]::-moz-focus-inner,
[type="submit"]::-moz-focus-inner
) {
border-style: none;
padding: 0;
}
${e} fieldset {
padding: 0.35em 0.75em 0.625em;
}
${e} legend {
box-sizing: border-box;
color: inherit;
display: table;
max-width: 100%;
padding: 0;
white-space: normal;
}
${e} progress {
vertical-align: baseline;
}
${e} textarea {
overflow: auto;
}
${e} :where([type="checkbox"], [type="radio"]) {
box-sizing: border-box;
padding: 0;
}
${e} input[type="number"]::-webkit-inner-spin-button,
${e} input[type="number"]::-webkit-outer-spin-button {
-webkit-appearance: none !important;
}
${e} input[type="number"] {
-moz-appearance: textfield;
}
${e} input[type="search"] {
-webkit-appearance: textfield;
outline-offset: -2px;
}
${e} input[type="search"]::-webkit-search-decoration {
-webkit-appearance: none !important;
}
${e} ::-webkit-file-upload-button {
-webkit-appearance: button;
font: inherit;
}
${e} details {
display: block;
}
${e} summary {
display: list-item;
}
template {
display: none;
}
[hidden] {
display: none !important;
}
${e} :where(
blockquote,
dl,
dd,
h1,
h2,
h3,
h4,
h5,
h6,
hr,
figure,
p,
pre
) {
margin: 0;
}
${e} button {
background: transparent;
padding: 0;
}
${e} fieldset {
margin: 0;
padding: 0;
}
${e} :where(ol, ul) {
margin: 0;
padding: 0;
}
${e} textarea {
resize: vertical;
}
${e} :where(button, [role="button"]) {
cursor: pointer;
}
${e} button::-moz-focus-inner {
border: 0 !important;
}
${e} table {
border-collapse: collapse;
}
${e} :where(h1, h2, h3, h4, h5, h6) {
font-size: inherit;
font-weight: inherit;
}
${e} :where(button, input, optgroup, select, textarea) {
padding: 0;
line-height: inherit;
color: inherit;
}
${e} :where(img, svg, video, canvas, audio, iframe, embed, object) {
display: block;
}
${e} :where(img, video) {
max-width: 100%;
height: auto;
}
[data-js-focus-visible]
:focus:not([data-focus-visible-added]):not(
[data-focus-visible-disabled]
) {
outline: none;
box-shadow: none;
}
${e} select::-ms-expand {
display: none;
}
${E}
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=H(),n=o.dir(),r=l.useMemo(()=>ie({...U,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(Z,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:Y,children:e})})}const ve=l.memo(he);export{ve as default};

File diff suppressed because one or more lines are too long

View File

@@ -19,7 +19,7 @@ import sdxlReducer from 'features/sdxl/store/sdxlSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import queueReducer from 'features/queue/store/queueSlice';
import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import dynamicMiddlewares from 'redux-dynamic-middlewares';

View File

@@ -8,14 +8,7 @@ import {
forwardRef,
useDisclosure,
} from '@chakra-ui/react';
import {
cloneElement,
memo,
ReactElement,
ReactNode,
useCallback,
useRef,
} from 'react';
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import IAIButton from './IAIButton';
@@ -45,15 +38,15 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement | null>(null);
const handleAccept = useCallback(() => {
const handleAccept = () => {
acceptCallback();
onClose();
}, [acceptCallback, onClose]);
};
const handleCancel = useCallback(() => {
const handleCancel = () => {
cancelCallback && cancelCallback();
onClose();
}, [cancelCallback, onClose]);
};
return (
<>

View File

@@ -0,0 +1,43 @@
import { Box, Flex, Icon } from '@chakra-ui/react';
import { memo } from 'react';
import { FaExclamation } from 'react-icons/fa';
const IAIErrorLoadingImageFallback = () => {
return (
<Box
sx={{
position: 'relative',
height: 'full',
width: 'full',
'::before': {
content: "''",
display: 'block',
pt: '100%',
},
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
height: 'full',
width: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
bg: 'base.100',
color: 'base.500',
_dark: {
color: 'base.700',
bg: 'base.850',
},
}}
>
<Icon as={FaExclamation} boxSize={16} opacity={0.7} />
</Flex>
</Box>
);
};
export default memo(IAIErrorLoadingImageFallback);

View File

@@ -0,0 +1,8 @@
import { chakra } from '@chakra-ui/react';
/**
* Chakra-enabled <form />
*/
const IAIForm = chakra.form;
export default IAIForm;

View File

@@ -0,0 +1,15 @@
import { FormErrorMessage, FormErrorMessageProps } from '@chakra-ui/react';
import { ReactNode } from 'react';
type IAIFormErrorMessageProps = FormErrorMessageProps & {
children: ReactNode | string;
};
export default function IAIFormErrorMessage(props: IAIFormErrorMessageProps) {
const { children, ...rest } = props;
return (
<FormErrorMessage color="error.400" {...rest}>
{children}
</FormErrorMessage>
);
}

View File

@@ -0,0 +1,15 @@
import { FormHelperText, FormHelperTextProps } from '@chakra-ui/react';
import { ReactNode } from 'react';
type IAIFormHelperTextProps = FormHelperTextProps & {
children: ReactNode | string;
};
export default function IAIFormHelperText(props: IAIFormHelperTextProps) {
const { children, ...rest } = props;
return (
<FormHelperText margin={0} color="base.400" {...rest}>
{children}
</FormHelperText>
);
}

View File

@@ -0,0 +1,25 @@
import { Flex, useColorMode } from '@chakra-ui/react';
import { ReactElement } from 'react';
import { mode } from 'theme/util/mode';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
const { colorMode } = useColorMode();
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: mode('base.100', 'base.900')(colorMode),
}}
>
{children}
</Flex>
);
}

View File

@@ -0,0 +1,25 @@
import {
Checkbox,
CheckboxProps,
FormControl,
FormControlProps,
FormLabel,
} from '@chakra-ui/react';
import { memo, ReactNode } from 'react';
type IAIFullCheckboxProps = CheckboxProps & {
label: string | ReactNode;
formControlProps?: FormControlProps;
};
const IAIFullCheckbox = (props: IAIFullCheckboxProps) => {
const { label, formControlProps, ...rest } = props;
return (
<FormControl {...formControlProps}>
<FormLabel>{label}</FormLabel>
<Checkbox colorScheme="accent" {...rest} />
</FormControl>
);
};
export default memo(IAIFullCheckbox);

View File

@@ -1,7 +1,6 @@
import { useColorMode } from '@chakra-ui/react';
import { TextInput, TextInputProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
type IAIMantineTextInputProps = TextInputProps;
@@ -21,37 +20,26 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const stylesFunc = useCallback(
() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
return (
<TextInput
styles={() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
},
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal' as const,
marginBottom: 4,
},
}),
[
accent300,
accent500,
base100,
base200,
base300,
base50,
base700,
base800,
base900,
colorMode,
]
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
marginBottom: 4,
},
})}
{...rest}
/>
);
return <TextInput styles={stylesFunc} {...rest} />;
}

View File

@@ -98,34 +98,28 @@ const IAINumberInput = forwardRef((props: Props, ref) => {
}
}, [value, valueAsString]);
const handleOnChange = useCallback(
(v: string) => {
setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
}
},
[isInteger, onChange]
);
const handleOnChange = (v: string) => {
setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
}
};
/**
* Clicking the steppers allows the value to go outside bounds; we need to
* clamp it on blur and floor it if needed.
*/
const handleBlur = useCallback(
(e: FocusEvent<HTMLInputElement>) => {
const clamped = clamp(
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
min,
max
);
setValueAsString(String(clamped));
onChange(clamped);
},
[isInteger, max, min, onChange]
);
const handleBlur = (e: FocusEvent<HTMLInputElement>) => {
const clamped = clamp(
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
min,
max
);
setValueAsString(String(clamped));
onChange(clamped);
};
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {

View File

@@ -6,7 +6,7 @@ import {
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { memo, MouseEvent, useCallback } from 'react';
import { memo, MouseEvent } from 'react';
import IAIOption from './IAIOption';
type IAISelectProps = SelectProps & {
@@ -33,16 +33,15 @@ const IAISelect = (props: IAISelectProps) => {
spaceEvenly,
...rest
} = props;
const handleClick = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.stopPropagation();
e.nativeEvent.stopImmediatePropagation();
e.nativeEvent.stopPropagation();
e.nativeEvent.cancelBubble = true;
}, []);
return (
<FormControl
isDisabled={isDisabled}
onClick={handleClick}
onClick={(e: MouseEvent<HTMLDivElement>) => {
e.stopPropagation();
e.nativeEvent.stopImmediatePropagation();
e.nativeEvent.stopPropagation();
e.nativeEvent.cancelBubble = true;
}}
sx={
horizontal
? {

View File

@@ -186,13 +186,6 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
[dispatch]
);
const handleMouseEnter = useCallback(() => setShowTooltip(true), []);
const handleMouseLeave = useCallback(() => setShowTooltip(false), []);
const handleStepperClick = useCallback(
() => onChange(Number(localInputValue)),
[localInputValue, onChange]
);
return (
<FormControl
ref={ref}
@@ -226,8 +219,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
max={max}
step={step}
onChange={handleSliderChange}
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
onMouseEnter={() => setShowTooltip(true)}
onMouseLeave={() => setShowTooltip(false)}
focusThumbOnChange={false}
isDisabled={isDisabled}
{...rest}
@@ -339,8 +332,12 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
{...sliderNumberInputFieldProps}
/>
<NumberInputStepper {...sliderNumberInputStepperProps}>
<NumberIncrementStepper onClick={handleStepperClick} />
<NumberDecrementStepper onClick={handleStepperClick} />
<NumberIncrementStepper
onClick={() => onChange(Number(localInputValue))}
/>
<NumberDecrementStepper
onClick={() => onChange(Number(localInputValue))}
/>
</NumberInputStepper>
</NumberInput>
)}

View File

@@ -146,15 +146,16 @@ const ImageUploader = (props: ImageUploaderProps) => {
};
}, [inputRef]);
const handleKeyDown = useCallback((e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') {
return;
}
}, []);
return (
<Box {...getRootProps({ style: {} })} onKeyDown={handleKeyDown}>
<Box
{...getRootProps({ style: {} })}
onKeyDown={(e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') {
return;
}
}}
>
<input {...getInputProps()} />
{children}
<AnimatePresence>

View File

@@ -0,0 +1,23 @@
import { Flex, Icon } from '@chakra-ui/react';
import { memo } from 'react';
import { FaImage } from 'react-icons/fa';
const SelectImagePlaceholder = () => {
return (
<Flex
sx={{
w: 'full',
h: 'full',
// bg: 'base.800',
borderRadius: 'base',
alignItems: 'center',
justifyContent: 'center',
aspectRatio: '1/1',
}}
>
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
</Flex>
);
};
export default memo(SelectImagePlaceholder);

View File

@@ -0,0 +1,24 @@
import { useBreakpoint } from '@chakra-ui/react';
export default function useResolution():
| 'mobile'
| 'tablet'
| 'desktop'
| 'unknown' {
const breakpointValue = useBreakpoint();
const mobileResolutions = ['base', 'sm'];
const tabletResolutions = ['md', 'lg'];
const desktopResolutions = ['xl', '2xl'];
if (mobileResolutions.includes(breakpointValue)) {
return 'mobile';
}
if (tabletResolutions.includes(breakpointValue)) {
return 'tablet';
}
if (desktopResolutions.includes(breakpointValue)) {
return 'desktop';
}
return 'unknown';
}

View File

@@ -0,0 +1,7 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () =>
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

View File

@@ -0,0 +1,71 @@
// TODO: Restore variations
// Support code from v2.3 in here.
// export const stringToSeedWeights = (
// string: string
// ): InvokeAI.SeedWeights | boolean => {
// const stringPairs = string.split(',');
// const arrPairs = stringPairs.map((p) => p.split(':'));
// const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
// return { seed: Number(p[0]), weight: Number(p[1]) };
// });
// if (!validateSeedWeights(pairs)) {
// return false;
// }
// return pairs;
// };
// export const validateSeedWeights = (
// seedWeights: InvokeAI.SeedWeights | string
// ): boolean => {
// return typeof seedWeights === 'string'
// ? Boolean(stringToSeedWeights(seedWeights))
// : Boolean(
// seedWeights.length &&
// !seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
// const { seed, weight } = pair;
// const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
// const isWeightValid =
// !isNaN(parseInt(weight.toString(), 10)) &&
// weight >= 0 &&
// weight <= 1;
// return !(isSeedValid && isWeightValid);
// })
// );
// };
// export const seedWeightsToString = (
// seedWeights: InvokeAI.SeedWeights
// ): string => {
// return seedWeights.reduce((acc, pair, i, arr) => {
// const { seed, weight } = pair;
// acc += `${seed}:${weight}`;
// if (i !== arr.length - 1) {
// acc += ',';
// }
// return acc;
// }, '');
// };
// export const seedWeightsToArray = (
// seedWeights: InvokeAI.SeedWeights
// ): Array<Array<number>> => {
// return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
// pair.seed,
// pair.weight,
// ]);
// };
// export const stringToSeedWeightsArray = (
// string: string
// ): Array<Array<number>> => {
// const stringPairs = string.split(',');
// const arrPairs = stringPairs.map((p) => p.split(':'));
// return arrPairs.map(
// (p: Array<string>): Array<number> => [parseInt(p[0], 10), parseFloat(p[1])]
// );
// };
export default {};

View File

@@ -5,22 +5,17 @@ import { clearCanvasHistory } from 'features/canvas/store/canvasSlice';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { isStagingSelector } from '../store/canvasSelectors';
import { memo, useCallback } from 'react';
import { memo } from 'react';
const ClearCanvasHistoryButtonModal = () => {
const isStaging = useAppSelector(isStagingSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const acceptCallback = useCallback(
() => dispatch(clearCanvasHistory()),
[dispatch]
);
return (
<IAIAlertDialog
title={t('unifiedCanvas.clearCanvasHistory')}
acceptCallback={acceptCallback}
acceptCallback={() => dispatch(clearCanvasHistory())}
acceptButtonText={t('unifiedCanvas.clearHistory')}
triggerComponent={
<IAIButton size="sm" leftIcon={<FaTrash />} isDisabled={isStaging}>

View File

@@ -20,8 +20,7 @@ import {
} from 'features/canvas/store/canvasSlice';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash-es';
import { ChangeEvent, memo, useCallback } from 'react';
import { RgbaColor } from 'react-colorful';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -96,35 +95,18 @@ const IAICanvasMaskOptions = () => {
[isMaskEnabled]
);
const handleToggleMaskLayer = useCallback(() => {
const handleToggleMaskLayer = () => {
dispatch(setLayer(layer === 'mask' ? 'base' : 'mask'));
}, [dispatch, layer]);
};
const handleClearMask = useCallback(() => {
dispatch(clearMask());
}, [dispatch]);
const handleClearMask = () => dispatch(clearMask());
const handleToggleEnableMask = useCallback(() => {
const handleToggleEnableMask = () =>
dispatch(setIsMaskEnabled(!isMaskEnabled));
}, [dispatch, isMaskEnabled]);
const handleSaveMask = useCallback(async () => {
const handleSaveMask = async () => {
dispatch(canvasMaskSavedToGallery());
}, [dispatch]);
const handleChangePreserveMaskedArea = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldPreserveMaskedArea(e.target.checked));
},
[dispatch]
);
const handleChangeMaskColor = useCallback(
(newColor: RgbaColor) => {
dispatch(setMaskColor(newColor));
},
[dispatch]
);
};
return (
<IAIPopover
@@ -149,10 +131,15 @@ const IAICanvasMaskOptions = () => {
<IAISimpleCheckbox
label={t('unifiedCanvas.preserveMaskedArea')}
isChecked={shouldPreserveMaskedArea}
onChange={handleChangePreserveMaskedArea}
onChange={(e) =>
dispatch(setShouldPreserveMaskedArea(e.target.checked))
}
/>
<Box sx={{ paddingTop: 2, paddingBottom: 2 }}>
<IAIColorPicker color={maskColor} onChange={handleChangeMaskColor} />
<IAIColorPicker
color={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))}
/>
</Box>
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
Save Mask

View File

@@ -10,7 +10,6 @@ import { redo } from 'features/canvas/store/canvasSlice';
import { stateSelector } from 'app/store/store';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import { useCallback } from 'react';
const canvasRedoSelector = createSelector(
[stateSelector, activeTabNameSelector],
@@ -35,9 +34,9 @@ export default function IAICanvasRedoButton() {
const { t } = useTranslation();
const handleRedo = useCallback(() => {
const handleRedo = () => {
dispatch(redo());
}, [dispatch]);
};
useHotkeys(
['meta+shift+z', 'ctrl+shift+z', 'control+y', 'meta+y'],

View File

@@ -18,7 +18,7 @@ import {
} from 'features/canvas/store/canvasSlice';
import { isEqual } from 'lodash-es';
import { ChangeEvent, memo, useCallback } from 'react';
import { ChangeEvent, memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaWrench } from 'react-icons/fa';
@@ -86,52 +86,8 @@ const IAICanvasSettingsButtonPopover = () => {
[shouldSnapToGrid]
);
const handleChangeShouldSnapToGrid = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldSnapToGrid(e.target.checked)),
[dispatch]
);
const handleChangeShouldShowIntermediates = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldShowIntermediates(e.target.checked)),
[dispatch]
);
const handleChangeShouldShowGrid = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldShowGrid(e.target.checked)),
[dispatch]
);
const handleChangeShouldDarkenOutsideBoundingBox = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
[dispatch]
);
const handleChangeShouldAutoSave = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldAutoSave(e.target.checked)),
[dispatch]
);
const handleChangeShouldCropToBoundingBoxOnSave = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked)),
[dispatch]
);
const handleChangeShouldRestrictStrokesToBox = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRestrictStrokesToBox(e.target.checked)),
[dispatch]
);
const handleChangeShouldShowCanvasDebugInfo = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldShowCanvasDebugInfo(e.target.checked)),
[dispatch]
);
const handleChangeShouldAntialias = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldAntialias(e.target.checked)),
[dispatch]
);
const handleChangeShouldSnapToGrid = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldSnapToGrid(e.target.checked));
return (
<IAIPopover
@@ -148,12 +104,14 @@ const IAICanvasSettingsButtonPopover = () => {
<IAISimpleCheckbox
label={t('unifiedCanvas.showIntermediates')}
isChecked={shouldShowIntermediates}
onChange={handleChangeShouldShowIntermediates}
onChange={(e) =>
dispatch(setShouldShowIntermediates(e.target.checked))
}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.showGrid')}
isChecked={shouldShowGrid}
onChange={handleChangeShouldShowGrid}
onChange={(e) => dispatch(setShouldShowGrid(e.target.checked))}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.snapToGrid')}
@@ -163,33 +121,41 @@ const IAICanvasSettingsButtonPopover = () => {
<IAISimpleCheckbox
label={t('unifiedCanvas.darkenOutsideSelection')}
isChecked={shouldDarkenOutsideBoundingBox}
onChange={handleChangeShouldDarkenOutsideBoundingBox}
onChange={(e) =>
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked))
}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.autoSaveToGallery')}
isChecked={shouldAutoSave}
onChange={handleChangeShouldAutoSave}
onChange={(e) => dispatch(setShouldAutoSave(e.target.checked))}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.saveBoxRegionOnly')}
isChecked={shouldCropToBoundingBoxOnSave}
onChange={handleChangeShouldCropToBoundingBoxOnSave}
onChange={(e) =>
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked))
}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.limitStrokesToBox')}
isChecked={shouldRestrictStrokesToBox}
onChange={handleChangeShouldRestrictStrokesToBox}
onChange={(e) =>
dispatch(setShouldRestrictStrokesToBox(e.target.checked))
}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.showCanvasDebugInfo')}
isChecked={shouldShowCanvasDebugInfo}
onChange={handleChangeShouldShowCanvasDebugInfo}
onChange={(e) =>
dispatch(setShouldShowCanvasDebugInfo(e.target.checked))
}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.antialiasing')}
isChecked={shouldAntialias}
onChange={handleChangeShouldAntialias}
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/>
<ClearCanvasHistoryButtonModal />
</Flex>

View File

@@ -15,8 +15,7 @@ import {
setTool,
} from 'features/canvas/store/canvasSlice';
import { clamp, isEqual } from 'lodash-es';
import { memo, useCallback } from 'react';
import { RgbaColor } from 'react-colorful';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -173,33 +172,11 @@ const IAICanvasToolChooserOptions = () => {
[brushColor]
);
const handleSelectBrushTool = useCallback(() => {
dispatch(setTool('brush'));
}, [dispatch]);
const handleSelectEraserTool = useCallback(() => {
dispatch(setTool('eraser'));
}, [dispatch]);
const handleSelectColorPickerTool = useCallback(() => {
dispatch(setTool('colorPicker'));
}, [dispatch]);
const handleFillRect = useCallback(() => {
dispatch(addFillRect());
}, [dispatch]);
const handleEraseBoundingBox = useCallback(() => {
dispatch(addEraseRect());
}, [dispatch]);
const handleChangeBrushSize = useCallback(
(newSize: number) => {
dispatch(setBrushSize(newSize));
},
[dispatch]
);
const handleChangeBrushColor = useCallback(
(newColor: RgbaColor) => {
dispatch(setBrushColor(newColor));
},
[dispatch]
);
const handleSelectBrushTool = () => dispatch(setTool('brush'));
const handleSelectEraserTool = () => dispatch(setTool('eraser'));
const handleSelectColorPickerTool = () => dispatch(setTool('colorPicker'));
const handleFillRect = () => dispatch(addFillRect());
const handleEraseBoundingBox = () => dispatch(addEraseRect());
return (
<ButtonGroup isAttached>
@@ -256,7 +233,7 @@ const IAICanvasToolChooserOptions = () => {
label={t('unifiedCanvas.brushSize')}
value={brushSize}
withInput
onChange={handleChangeBrushSize}
onChange={(newSize) => dispatch(setBrushSize(newSize))}
sliderNumberInputProps={{ max: 500 }}
/>
</Flex>
@@ -270,7 +247,7 @@ const IAICanvasToolChooserOptions = () => {
<IAIColorPicker
withNumberInput={true}
color={brushColor}
onChange={handleChangeBrushColor}
onChange={(newColor) => dispatch(setBrushColor(newColor))}
/>
</Box>
</Flex>

View File

@@ -25,9 +25,9 @@ import {
LAYER_NAMES_DICT,
} from 'features/canvas/store/canvasTypes';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { isEqual } from 'lodash-es';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@@ -151,9 +151,7 @@ const IAICanvasToolbar = () => {
[canvasBaseLayer]
);
const handleSelectMoveTool = useCallback(() => {
dispatch(setTool('move'));
}, [dispatch]);
const handleSelectMoveTool = () => dispatch(setTool('move'));
const handleClickResetCanvasView = useSingleAndDoubleClick(
() => handleResetCanvasView(false),
@@ -176,39 +174,36 @@ const IAICanvasToolbar = () => {
);
};
const handleResetCanvas = useCallback(() => {
const handleResetCanvas = () => {
dispatch(resetCanvas());
}, [dispatch]);
};
const handleMergeVisible = useCallback(() => {
const handleMergeVisible = () => {
dispatch(canvasMerged());
}, [dispatch]);
};
const handleSaveToGallery = useCallback(() => {
const handleSaveToGallery = () => {
dispatch(canvasSavedToGallery());
}, [dispatch]);
};
const handleCopyImageToClipboard = useCallback(() => {
const handleCopyImageToClipboard = () => {
if (!isClipboardAPIAvailable) {
return;
}
dispatch(canvasCopiedToClipboard());
}, [dispatch, isClipboardAPIAvailable]);
};
const handleDownloadAsImage = useCallback(() => {
const handleDownloadAsImage = () => {
dispatch(canvasDownloadedAsImage());
}, [dispatch]);
};
const handleChangeLayer = useCallback(
(v: string) => {
const newLayer = v as CanvasLayer;
dispatch(setLayer(newLayer));
if (newLayer === 'mask' && !isMaskEnabled) {
dispatch(setIsMaskEnabled(true));
}
},
[dispatch, isMaskEnabled]
);
const handleChangeLayer = (v: string) => {
const newLayer = v as CanvasLayer;
dispatch(setLayer(newLayer));
if (newLayer === 'mask' && !isMaskEnabled) {
dispatch(setIsMaskEnabled(true));
}
};
return (
<Flex

View File

@@ -10,7 +10,6 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import { stateSelector } from 'app/store/store';
import { useCallback } from 'react';
const canvasUndoSelector = createSelector(
[stateSelector, activeTabNameSelector],
@@ -36,9 +35,9 @@ export default function IAICanvasUndoButton() {
const { canUndo, activeTabName } = useAppSelector(canvasUndoSelector);
const handleUndo = useCallback(() => {
const handleUndo = () => {
dispatch(undo());
}, [dispatch]);
};
useHotkeys(
['meta+z', 'ctrl+z'],

View File

@@ -0,0 +1,16 @@
import Konva from 'konva';
import { IRect } from 'konva/lib/types';
/**
* Converts a Konva node to a dataURL
* @param node - The Konva node to convert to a dataURL
* @param boundingBox - The bounding box to crop to
* @returns A dataURL of the node cropped to the bounding box
*/
export const konvaNodeToDataURL = (
node: Konva.Node,
boundingBox: IRect
): string => {
// get a dataURL of the bbox'd region
return node.toDataURL(boundingBox);
};

View File

@@ -87,11 +87,6 @@ const ChangeBoardModal = () => {
selectedBoard,
]);
const handleSetSelectedBoard = useCallback(
(v: string | null) => setSelectedBoard(v),
[]
);
const cancelRef = useRef<HTMLButtonElement>(null);
return (
@@ -118,7 +113,7 @@ const ChangeBoardModal = () => {
isFetching ? t('boards.loading') : t('boards.selectBoard')
}
disabled={isFetching}
onChange={handleSetSelectedBoard}
onChange={(v) => setSelectedBoard(v)}
value={selectedBoard}
data={data}
/>

View File

@@ -0,0 +1,36 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
import { memo, useCallback } from 'react';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { controlAdapterImageProcessed } from '../store/actions';
type Props = {
id: string;
};
const ControlAdapterPreprocessButton = ({ id }: Props) => {
const controlImage = useControlAdapterControlImage(id);
const dispatch = useAppDispatch();
const isReady = useIsReadyToEnqueue();
const handleProcess = useCallback(() => {
dispatch(
controlAdapterImageProcessed({
id,
})
);
}, [id, dispatch]);
return (
<IAIButton
size="sm"
onClick={handleProcess}
isDisabled={Boolean(!controlImage) || !isReady}
>
Preprocess
</IAIButton>
);
};
export default memo(ControlAdapterPreprocessButton);

View File

@@ -14,9 +14,9 @@ import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectI
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
import { PropsWithChildren, memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import { useTranslation } from 'react-i18next';
type Props = PropsWithChildren & {
onSelect: (v: string) => void;
@@ -78,13 +78,6 @@ const ParamEmbeddingPopover = (props: Props) => {
[onSelect]
);
const filterFunc = useCallback(
(value: string, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
[]
);
return (
<Popover
initialFocusRef={inputRef}
@@ -134,7 +127,12 @@ const ParamEmbeddingPopover = (props: Props) => {
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
onDropdownClose={onClose}
filter={filterFunc}
filter={(value, item: SelectItem) =>
item.label
?.toLowerCase()
.includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
)}

View File

@@ -60,13 +60,6 @@ const BoardAutoAddSelect = () => {
[dispatch]
);
const filterFunc = useCallback(
(value: string, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
[]
);
return (
<IAIMantineSearchableSelect
label={t('boards.autoAddBoard')}
@@ -78,7 +71,10 @@ const BoardAutoAddSelect = () => {
nothingFound={t('boards.noMatching')}
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={!hasBoards || autoAssignBoardOnClick}
filter={filterFunc}
filter={(value, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
);

View File

@@ -90,50 +90,6 @@ const BoardContextMenu = ({
e.preventDefault();
}, []);
const renderMenuFunc = useCallback(
() => (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MenuGroup title={boardName}>
<MenuItem
icon={<FaPlus />}
isDisabled={isAutoAdd || autoAssignBoardOnClick}
onClick={handleSetAutoAdd}
>
{t('boards.menuItemAutoAdd')}
</MenuItem>
{isBulkDownloadEnabled && (
<MenuItem icon={<FaDownload />} onClickCapture={handleBulkDownload}>
{t('boards.downloadBoard')}
</MenuItem>
)}
{!board && <NoBoardContextMenuItems />}
{board && (
<GalleryBoardContextMenuItems
board={board}
setBoardToDelete={setBoardToDelete}
/>
)}
</MenuGroup>
</MenuList>
),
[
autoAssignBoardOnClick,
board,
boardName,
handleBulkDownload,
handleSetAutoAdd,
isAutoAdd,
isBulkDownloadEnabled,
setBoardToDelete,
skipEvent,
t,
]
);
return (
<IAIContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
@@ -141,7 +97,38 @@ const BoardContextMenu = ({
bg: 'transparent',
_hover: { bg: 'transparent' },
}}
renderMenu={renderMenuFunc}
renderMenu={() => (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MenuGroup title={boardName}>
<MenuItem
icon={<FaPlus />}
isDisabled={isAutoAdd || autoAssignBoardOnClick}
onClick={handleSetAutoAdd}
>
{t('boards.menuItemAutoAdd')}
</MenuItem>
{isBulkDownloadEnabled && (
<MenuItem
icon={<FaDownload />}
onClickCapture={handleBulkDownload}
>
{t('boards.downloadBoard')}
</MenuItem>
)}
{!board && <NoBoardContextMenuItems />}
{board && (
<GalleryBoardContextMenuItems
board={board}
setBoardToDelete={setBoardToDelete}
/>
)}
</MenuGroup>
</MenuList>
)}
>
{children}
</IAIContextMenu>

View File

@@ -0,0 +1,108 @@
import { As, Badge, Flex } from '@chakra-ui/react';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { TypesafeDroppableData } from 'features/dnd/types';
import { BoardId } from 'features/gallery/store/types';
import { ReactNode, memo } from 'react';
import BoardContextMenu from '../BoardContextMenu';
type GenericBoardProps = {
board_id: BoardId;
droppableData?: TypesafeDroppableData;
onClick: () => void;
isSelected: boolean;
icon: As;
label: string;
dropLabel?: ReactNode;
badgeCount?: number;
};
export const formatBadgeCount = (count: number) =>
Intl.NumberFormat('en-US', {
notation: 'compact',
maximumFractionDigits: 1,
}).format(count);
const GenericBoard = (props: GenericBoardProps) => {
const {
board_id,
droppableData,
onClick,
isSelected,
icon,
label,
badgeCount,
dropLabel,
} = props;
return (
<BoardContextMenu board_id={board_id}>
{(ref) => (
<Flex
ref={ref}
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
borderRadius: 'base',
}}
>
<Flex
onClick={onClick}
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}}
>
<IAINoContentFallback
boxSize={8}
icon={icon}
sx={{
border: '2px solid var(--invokeai-colors-base-200)',
_dark: { border: '2px solid var(--invokeai-colors-base-800)' },
}}
/>
<Flex
sx={{
position: 'absolute',
insetInlineEnd: 0,
top: 0,
p: 1,
}}
>
{badgeCount !== undefined && (
<Badge variant="solid">{formatBadgeCount(badgeCount)}</Badge>
)}
</Flex>
<IAIDroppable data={droppableData} dropLabel={dropLabel} />
</Flex>
<Flex
sx={{
h: 'full',
alignItems: 'center',
fontWeight: isSelected ? 600 : undefined,
fontSize: 'sm',
color: isSelected ? 'base.900' : 'base.700',
_dark: { color: isSelected ? 'base.50' : 'base.200' },
}}
>
{label}
</Flex>
</Flex>
)}
</BoardContextMenu>
);
};
export default memo(GenericBoard);

View File

@@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
type Props = {
board_id: 'images' | 'assets' | 'no_board';
};
const SystemBoardButton = ({ board_id }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ gallery }) => {
const { selectedBoardId } = gallery;
return { isSelected: selectedBoardId === board_id };
},
defaultSelectorOptions
),
[board_id]
);
const { isSelected } = useAppSelector(selector);
const boardName = useBoardName(board_id);
const handleClick = useCallback(() => {
dispatch(boardIdSelected({ boardId: board_id }));
}, [board_id, dispatch]);
return (
<IAIButton
onClick={handleClick}
size="sm"
isChecked={isSelected}
sx={{
flexGrow: 1,
borderRadius: 'base',
}}
>
{boardName}
</IAIButton>
);
};
export default memo(SystemBoardButton);

View File

@@ -0,0 +1,22 @@
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import { FaEyeSlash } from 'react-icons/fa';
const CurrentImageHidden = () => {
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
}}
>
<FaEyeSlash fontSize="25vh" />
</Flex>
);
};
export default memo(CurrentImageHidden);

View File

@@ -61,12 +61,6 @@ const GallerySettingsPopover = () => {
[dispatch]
);
const handleChangeAutoAssignBoardOnClick = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(autoAssignBoardOnClickChanged(e.target.checked)),
[dispatch]
);
return (
<IAIPopover
triggerComponent={
@@ -97,7 +91,9 @@ const GallerySettingsPopover = () => {
<IAISimpleCheckbox
label={t('gallery.autoAssignBoardOnClick')}
isChecked={autoAssignBoardOnClick}
onChange={handleChangeAutoAssignBoardOnClick}
onChange={(e: ChangeEvent<HTMLInputElement>) =>
dispatch(autoAssignBoardOnClickChanged(e.target.checked))
}
/>
<BoardAutoAddSelect />
</Flex>

View File

@@ -35,34 +35,6 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
e.preventDefault();
}, []);
const renderMenuFunc = useCallback(() => {
if (!imageDTO) {
return null;
}
if (selectionCount > 1) {
return (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MultipleSelectionMenuItems />
</MenuList>
);
}
return (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<SingleSelectionMenuItems imageDTO={imageDTO} />
</MenuList>
);
}, [imageDTO, selectionCount, skipEvent]);
return (
<IAIContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
@@ -70,7 +42,33 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
bg: 'transparent',
_hover: { bg: 'transparent' },
}}
renderMenu={renderMenuFunc}
renderMenu={() => {
if (!imageDTO) {
return null;
}
if (selectionCount > 1) {
return (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MultipleSelectionMenuItems />
</MenuList>
);
}
return (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<SingleSelectionMenuItems imageDTO={imageDTO} />
</MenuList>
);
}}
>
{children}
</IAIContextMenu>

View File

@@ -13,7 +13,7 @@ import { workflowLoadRequested } from 'features/nodes/store/actions';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { flushSync } from 'react-dom';

View File

@@ -0,0 +1,27 @@
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
import { memo } from 'react';
type ImageFallbackSpinnerProps = SpinnerProps;
const ImageFallbackSpinner = (props: ImageFallbackSpinnerProps) => {
const { size = 'xl', ...rest } = props;
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
minH: 36,
minW: 36,
}}
>
<Spinner size={size} {...rest} />
</Flex>
);
};
export default memo(ImageFallbackSpinner);

View File

@@ -20,7 +20,6 @@ import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
import GalleryImage from './GalleryImage';
import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer';
import { EntityId } from '@reduxjs/toolkit';
const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
defer: true,
@@ -72,13 +71,6 @@ const GalleryImageGrid = () => {
});
}, [areMoreAvailable, listImages, queryArgs, currentData?.ids.length]);
const itemContentFunc = useCallback(
(index: number, imageName: EntityId) => (
<GalleryImage key={imageName} imageName={imageName as string} />
),
[]
);
useEffect(() => {
// Initialize the gallery's custom scrollbar
const { current: root } = rootRef;
@@ -139,7 +131,9 @@ const GalleryImageGrid = () => {
List: ImageGridListContainer,
}}
scrollerRef={setScroller}
itemContent={itemContentFunc}
itemContent={(index, imageName) => (
<GalleryImage key={imageName} imageName={imageName as string} />
)}
/>
</Box>
<IAIButton

View File

@@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => {
key={index}
label="LoRA"
value={`${lora.lora.model_name} - ${lora.weight}`}
onClick={handleRecallLoRA.bind(null, lora)}
onClick={() => handleRecallLoRA(lora)}
/>
);
}
@@ -289,7 +289,7 @@ const ImageMetadataActions = (props: Props) => {
key={index}
label="ControlNet"
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
onClick={handleRecallControlNet.bind(null, controlnet)}
onClick={() => handleRecallControlNet(controlnet)}
/>
))}
{validIPAdapters.map((ipAdapter, index) => (
@@ -297,7 +297,7 @@ const ImageMetadataActions = (props: Props) => {
key={index}
label="IP Adapter"
value={`${ipAdapter.ip_adapter_model?.model_name} - ${ipAdapter.weight}`}
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
onClick={() => handleRecallIPAdapter(ipAdapter)}
/>
))}
{validT2IAdapters.map((t2iAdapter, index) => (
@@ -305,7 +305,7 @@ const ImageMetadataActions = (props: Props) => {
key={index}
label="T2I Adapter"
value={`${t2iAdapter.t2i_adapter_model?.model_name} - ${t2iAdapter.weight}`}
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
onClick={() => handleRecallT2IAdapter(t2iAdapter)}
/>
))}
</>

View File

@@ -1,6 +1,6 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { Flex, IconButton, Link, Text, Tooltip } from '@chakra-ui/react';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
@@ -27,11 +27,6 @@ const ImageMetadataItem = ({
}: MetadataItemProps) => {
const { t } = useTranslation();
const handleCopy = useCallback(
() => navigator.clipboard.writeText(value.toString()),
[value]
);
if (!value) {
return null;
}
@@ -58,7 +53,7 @@ const ImageMetadataItem = ({
size="xs"
variant="ghost"
fontSize={14}
onClick={handleCopy}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}

View File

@@ -76,13 +76,6 @@ const ParamLoRASelect = () => {
[dispatch, loraModels?.entities]
);
const filterFunc = useCallback(
(value: string, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
[]
);
if (loraModels?.ids.length === 0) {
return (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
@@ -101,7 +94,10 @@ const ParamLoRASelect = () => {
nothingFound="No matching LoRAs"
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
filter={filterFunc}
filter={(value, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
data-testid="add-lora"
/>

View File

@@ -1,4 +1,4 @@
import { useState, PropsWithChildren, memo, useCallback } from 'react';
import { useState, PropsWithChildren, memo } from 'react';
import { useSelector } from 'react-redux';
import { createSelector } from '@reduxjs/toolkit';
import { Flex, Image, Text } from '@chakra-ui/react';
@@ -59,13 +59,13 @@ export default memo(CurrentImageNode);
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => {
const [isHovering, setIsHovering] = useState(false);
const handleMouseEnter = useCallback(() => {
const handleMouseEnter = () => {
setIsHovering(true);
}, []);
};
const handleMouseLeave = useCallback(() => {
const handleMouseLeave = () => {
setIsHovering(false);
}, []);
};
return (
<NodeWrapper

View File

@@ -104,24 +104,6 @@ const FieldContextMenu = ({ nodeId, fieldName, kind, children }: Props) => {
nodeId,
]);
const renderMenuFunc = useCallback(
() =>
!menuItems.length ? null : (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MenuGroup
title={label || fieldTemplateTitle || t('nodes.unknownField')}
>
{menuItems}
</MenuGroup>
</MenuList>
),
[fieldTemplateTitle, label, menuItems, skipEvent, t]
);
return (
<IAIContextMenu<HTMLDivElement>
menuProps={{
@@ -132,7 +114,21 @@ const FieldContextMenu = ({ nodeId, fieldName, kind, children }: Props) => {
bg: 'transparent',
_hover: { bg: 'transparent' },
}}
renderMenu={renderMenuFunc}
renderMenu={() =>
!menuItems.length ? null : (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MenuGroup
title={label || fieldTemplateTitle || t('nodes.unknownField')}
>
{menuItems}
</MenuGroup>
</MenuList>
)
}
>
{children}
</IAIContextMenu>

View File

@@ -0,0 +1,14 @@
import {
ClipInputFieldTemplate,
ClipInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const ClipInputFieldComponent = (
_props: FieldComponentProps<ClipInputFieldValue, ClipInputFieldTemplate>
) => {
return null;
};
export default memo(ClipInputFieldComponent);

View File

@@ -0,0 +1,17 @@
import {
CollectionInputFieldTemplate,
CollectionInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const CollectionInputFieldComponent = (
_props: FieldComponentProps<
CollectionInputFieldValue,
CollectionInputFieldTemplate
>
) => {
return null;
};
export default memo(CollectionInputFieldComponent);

View File

@@ -0,0 +1,17 @@
import {
CollectionItemInputFieldTemplate,
CollectionItemInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const CollectionItemInputFieldComponent = (
_props: FieldComponentProps<
CollectionItemInputFieldValue,
CollectionItemInputFieldTemplate
>
) => {
return null;
};
export default memo(CollectionItemInputFieldComponent);

View File

@@ -0,0 +1,17 @@
import {
ConditioningInputFieldTemplate,
ConditioningInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const ConditioningInputFieldComponent = (
_props: FieldComponentProps<
ConditioningInputFieldValue,
ConditioningInputFieldTemplate
>
) => {
return null;
};
export default memo(ConditioningInputFieldComponent);

View File

@@ -0,0 +1,19 @@
import {
ControlInputFieldTemplate,
ControlInputFieldValue,
ControlPolymorphicInputFieldTemplate,
ControlPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const ControlInputFieldComponent = (
_props: FieldComponentProps<
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
>
) => {
return null;
};
export default memo(ControlInputFieldComponent);

View File

@@ -0,0 +1,17 @@
import {
DenoiseMaskInputFieldTemplate,
DenoiseMaskInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const DenoiseMaskInputFieldComponent = (
_props: FieldComponentProps<
DenoiseMaskInputFieldValue,
DenoiseMaskInputFieldTemplate
>
) => {
return null;
};
export default memo(DenoiseMaskInputFieldComponent);

View File

@@ -0,0 +1,19 @@
import {
IPAdapterInputFieldTemplate,
IPAdapterInputFieldValue,
FieldComponentProps,
IPAdapterPolymorphicInputFieldValue,
IPAdapterPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
const IPAdapterInputFieldComponent = (
_props: FieldComponentProps<
IPAdapterInputFieldValue | IPAdapterPolymorphicInputFieldValue,
IPAdapterInputFieldTemplate | IPAdapterPolymorphicInputFieldTemplate
>
) => {
return null;
};
export default memo(IPAdapterInputFieldComponent);

View File

@@ -0,0 +1,94 @@
import {
ImageCollectionInputFieldTemplate,
ImageCollectionInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { Flex } from '@chakra-ui/react';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import { NodesMultiImageDropData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const ImageCollectionInputFieldComponent = (
props: FieldComponentProps<
ImageCollectionInputFieldValue,
ImageCollectionInputFieldTemplate
>
) => {
const { nodeId, field } = props;
// const dispatch = useAppDispatch();
// const handleDrop = useCallback(
// ({ image_name }: ImageDTO) => {
// dispatch(
// fieldValueChanged({
// nodeId,
// fieldName: field.name,
// value: uniqBy([...(field.value ?? []), { image_name }], 'image_name'),
// })
// );
// },
// [dispatch, field.name, field.value, nodeId]
// );
const droppableData: NodesMultiImageDropData = {
id: `node-${nodeId}-${field.name}`,
actionType: 'SET_MULTI_NODES_IMAGE',
context: { nodeId: nodeId, fieldName: field.name },
};
const {
isOver,
setNodeRef: setDroppableRef,
active,
} = useDroppableTypesafe({
id: `node_${nodeId}`,
data: droppableData,
});
// const handleReset = useCallback(() => {
// dispatch(
// fieldValueChanged({
// nodeId,
// fieldName: field.name,
// value: undefined,
// })
// );
// }, [dispatch, field.name, nodeId]);
return (
<Flex
ref={setDroppableRef}
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'relative',
minH: '10rem',
}}
>
{field.value?.map(({ image_name }) => (
<ImageSubField key={image_name} imageName={image_name} />
))}
{isValidDrop(droppableData, active) && <IAIDropOverlay isOver={isOver} />}
</Flex>
);
};
export default memo(ImageCollectionInputFieldComponent);
type ImageSubFieldProps = { imageName: string };
const ImageSubField = (props: ImageSubFieldProps) => {
const { currentData: image } = useGetImageDTOQuery(props.imageName);
return (
<IAIDndImage imageDTO={image} isDropDisabled={true} isDragDisabled={true} />
);
};

View File

@@ -0,0 +1,19 @@
import {
LatentsInputFieldTemplate,
LatentsInputFieldValue,
FieldComponentProps,
LatentsPolymorphicInputFieldValue,
LatentsPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
const LatentsInputFieldComponent = (
_props: FieldComponentProps<
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
>
) => {
return null;
};
export default memo(LatentsInputFieldComponent);

View File

@@ -80,13 +80,6 @@ const LoRAModelInputFieldComponent = (
[dispatch, field.name, nodeId]
);
const filterFunc = useCallback(
(value: string, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
[]
);
if (loraModels?.ids.length === 0) {
return (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
@@ -108,7 +101,10 @@ const LoRAModelInputFieldComponent = (
nothingFound={t('models.noMatchingLoRAs')}
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
filter={filterFunc}
filter={(value, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
error={!selectedLoRAModel}
onChange={handleChange}
sx={{

View File

@@ -11,7 +11,7 @@ import {
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/modelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';

View File

@@ -19,7 +19,7 @@ import {
IntegerPolymorphicInputFieldTemplate,
IntegerPolymorphicInputFieldValue,
} from 'features/nodes/types/types';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
const NumberInputFieldComponent = (
props: FieldComponentProps<
@@ -43,23 +43,20 @@ const NumberInputFieldComponent = (
[fieldTemplate.type]
);
const handleValueChanged = useCallback(
(v: string) => {
setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
dispatch(
fieldNumberValueChanged({
nodeId,
fieldName: field.name,
value: isIntegerField ? Math.floor(Number(v)) : Number(v),
})
);
}
},
[dispatch, field.name, isIntegerField, nodeId]
);
const handleValueChanged = (v: string) => {
setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
dispatch(
fieldNumberValueChanged({
nodeId,
fieldName: field.name,
value: isIntegerField ? Math.floor(Number(v)) : Number(v),
})
);
}
};
useEffect(() => {
if (

View File

@@ -11,7 +11,7 @@ import {
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/modelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

View File

@@ -11,7 +11,7 @@ import {
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/modelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

View File

@@ -0,0 +1,19 @@
import {
T2IAdapterInputFieldTemplate,
T2IAdapterInputFieldValue,
T2IAdapterPolymorphicInputFieldTemplate,
T2IAdapterPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const T2IAdapterInputFieldComponent = (
_props: FieldComponentProps<
T2IAdapterInputFieldValue | T2IAdapterPolymorphicInputFieldValue,
T2IAdapterInputFieldTemplate | T2IAdapterPolymorphicInputFieldTemplate
>
) => {
return null;
};
export default memo(T2IAdapterInputFieldComponent);

View File

@@ -0,0 +1,14 @@
import {
UNetInputFieldTemplate,
UNetInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const UNetInputFieldComponent = (
_props: FieldComponentProps<UNetInputFieldValue, UNetInputFieldTemplate>
) => {
return null;
};
export default memo(UNetInputFieldComponent);

View File

@@ -0,0 +1,14 @@
import {
VaeInputFieldTemplate,
VaeInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const VaeInputFieldComponent = (
_props: FieldComponentProps<VaeInputFieldValue, VaeInputFieldTemplate>
) => {
return null;
};
export default memo(VaeInputFieldComponent);

View File

@@ -0,0 +1,27 @@
import { NODE_MIN_WIDTH } from 'features/nodes/types/constants';
import { memo } from 'react';
import { NodeResizeControl, NodeResizerProps } from 'reactflow';
// this causes https://github.com/invoke-ai/InvokeAI/issues/4140
// not using it for now
const NodeResizer = (props: NodeResizerProps) => {
const { ...rest } = props;
return (
<NodeResizeControl
style={{
position: 'absolute',
border: 'none',
background: 'transparent',
width: 15,
height: 15,
bottom: 0,
right: 0,
}}
minWidth={NODE_MIN_WIDTH}
{...rest}
></NodeResizeControl>
);
};
export default memo(NodeResizer);

View File

@@ -0,0 +1,78 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { InvocationTemplate, NodeData } from 'features/nodes/types/types';
import { memo } from 'react';
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
import NodeTitle from '../../flow/nodes/common/NodeTitle';
import ScrollableContent from '../ScrollableContent';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
data: lastSelectedNode?.data,
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const InspectorDetailsTab = () => {
const { data, template } = useAppSelector(selector);
const { t } = useTranslation();
if (!template || !data) {
return (
<IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />
);
}
return <Content data={data} template={template} />;
};
export default memo(InspectorDetailsTab);
const Content = (props: { data: NodeData; template: InvocationTemplate }) => {
const { data } = props;
return (
<Box
sx={{
position: 'relative',
w: 'full',
h: 'full',
}}
>
<ScrollableContent>
<Flex
sx={{
flexDir: 'column',
position: 'relative',
p: 1,
gap: 2,
w: 'full',
}}
>
<NodeTitle nodeId={data.id} />
<NotesTextarea nodeId={data.id} />
</Flex>
</ScrollableContent>
</Box>
);
};

View File

@@ -0,0 +1,51 @@
import { Box, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { workflowNotesChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const selector = createSelector(stateSelector, ({ nodes }) => {
const { notes } = nodes.workflow;
return {
notes,
};
});
const WorkflowNotesTab = () => {
const { notes } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChangeNotes = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(workflowNotesChanged(e.target.value));
},
[dispatch]
);
return (
<Box sx={{ pos: 'relative', h: 'full' }}>
<IAITextarea
onChange={handleChangeNotes}
value={notes}
fontSize="sm"
sx={{ h: 'full', resize: 'none' }}
/>
<Box sx={{ pos: 'absolute', bottom: 2, right: 2 }}>
<Text
sx={{
fontSize: 'xs',
opacity: 0.5,
userSelect: 'none',
}}
>
{notes.length}
</Text>
</Box>
</Box>
);
};
export default memo(WorkflowNotesTab);

View File

@@ -0,0 +1,11 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { AnyInvocationType } from 'services/events/types';
export const makeTemplateSelector = (type: AnyInvocationType) =>
createSelector(
stateSelector,
({ nodes }) => nodes.nodeTemplates[type],
defaultSelectorOptions
);

View File

@@ -0,0 +1,21 @@
import { Flex } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamBoundingBoxHeight from './ParamBoundingBoxHeight';
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
const ParamBoundingBoxCollapse = () => {
const { t } = useTranslation();
return (
<IAICollapse label={t('parameters.boundingBoxHeader')}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamBoundingBoxWidth />
<ParamBoundingBoxHeight />
</Flex>
</IAICollapse>
);
};
export default memo(ParamBoundingBoxCollapse);

View File

@@ -6,7 +6,7 @@ import IAISlider from 'common/components/IAISlider';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -36,28 +36,25 @@ const ParamBoundingBoxWidth = () => {
? 1024
: 512;
const handleChangeHeight = useCallback(
(v: number) => {
const handleChangeHeight = (v: number) => {
dispatch(
setBoundingBoxDimensions({
...boundingBoxDimensions,
height: Math.floor(v),
})
);
if (aspectRatio) {
const newWidth = roundToMultiple(v * aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
...boundingBoxDimensions,
width: newWidth,
height: Math.floor(v),
})
);
if (aspectRatio) {
const newWidth = roundToMultiple(v * aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
width: newWidth,
height: Math.floor(v),
})
);
}
},
[aspectRatio, boundingBoxDimensions, dispatch]
);
}
};
const handleResetHeight = useCallback(() => {
const handleResetHeight = () => {
dispatch(
setBoundingBoxDimensions({
...boundingBoxDimensions,
@@ -73,7 +70,7 @@ const ParamBoundingBoxWidth = () => {
})
);
}
}, [aspectRatio, boundingBoxDimensions, dispatch, initial]);
};
return (
<IAISlider

View File

@@ -6,7 +6,7 @@ import IAISlider from 'common/components/IAISlider';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -36,28 +36,25 @@ const ParamBoundingBoxWidth = () => {
const { t } = useTranslation();
const handleChangeWidth = useCallback(
(v: number) => {
const handleChangeWidth = (v: number) => {
dispatch(
setBoundingBoxDimensions({
...boundingBoxDimensions,
width: Math.floor(v),
})
);
if (aspectRatio) {
const newHeight = roundToMultiple(v / aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
...boundingBoxDimensions,
width: Math.floor(v),
height: newHeight,
})
);
if (aspectRatio) {
const newHeight = roundToMultiple(v / aspectRatio, 64);
dispatch(
setBoundingBoxDimensions({
width: Math.floor(v),
height: newHeight,
})
);
}
},
[aspectRatio, boundingBoxDimensions, dispatch]
);
}
};
const handleResetWidth = useCallback(() => {
const handleResetWidth = () => {
dispatch(
setBoundingBoxDimensions({
...boundingBoxDimensions,
@@ -73,7 +70,7 @@ const ParamBoundingBoxWidth = () => {
})
);
}
}, [aspectRatio, boundingBoxDimensions, dispatch, initial]);
};
return (
<IAISlider

View File

@@ -6,7 +6,7 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { setCanvasCoherenceMode } from 'features/parameters/store/generationSlice';
import { CanvasCoherenceModeParam } from 'features/parameters/types/parameterSchemas';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const coherenceModeSelectData: IAISelectDataType[] = [
@@ -22,16 +22,13 @@ const ParamCanvasCoherenceMode = () => {
);
const { t } = useTranslation();
const handleCoherenceModeChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const handleCoherenceModeChange = (v: string | null) => {
if (!v) {
return;
}
dispatch(setCanvasCoherenceMode(v as CanvasCoherenceModeParam));
},
[dispatch]
);
dispatch(setCanvasCoherenceMode(v as CanvasCoherenceModeParam));
};
return (
<IAIInformationalPopover feature="compositingCoherenceMode">

View File

@@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAISlider from 'common/components/IAISlider';
import { setCanvasCoherenceSteps } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamCanvasCoherenceSteps = () => {
@@ -13,17 +13,6 @@ const ParamCanvasCoherenceSteps = () => {
);
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => {
dispatch(setCanvasCoherenceSteps(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(setCanvasCoherenceSteps(20));
}, [dispatch]);
return (
<IAIInformationalPopover feature="compositingCoherenceSteps">
<IAISlider
@@ -33,11 +22,15 @@ const ParamCanvasCoherenceSteps = () => {
step={1}
sliderNumberInputProps={{ max: 999 }}
value={canvasCoherenceSteps}
onChange={handleChange}
onChange={(v) => {
dispatch(setCanvasCoherenceSteps(v));
}}
withInput
withSliderMarks
withReset
handleReset={handleReset}
handleReset={() => {
dispatch(setCanvasCoherenceSteps(20));
}}
/>
</IAIInformationalPopover>
);

Some files were not shown because too many files have changed in this diff Show More