mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 06:28:06 -05:00
Compare commits
1 Commits
ryan/groun
...
ryan/promp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2387a5d686 |
@@ -55,7 +55,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
FROM node:20-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack use pnpm@8.x
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
@@ -60,12 +60,8 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@@ -502,33 +498,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_t2i_adapter_field(
|
||||
exit_stack: ExitStack,
|
||||
context: InvocationContext,
|
||||
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
ext_manager: ExtensionsManager,
|
||||
) -> None:
|
||||
if t2i_adapters is None:
|
||||
return
|
||||
|
||||
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
|
||||
if isinstance(t2i_adapters, T2IAdapterField):
|
||||
t2i_adapters = [t2i_adapters]
|
||||
|
||||
for t2i_adapter_field in t2i_adapters:
|
||||
ext_manager.add_extension(
|
||||
T2IAdapterExt(
|
||||
node_context=context,
|
||||
model_id=t2i_adapter_field.t2i_adapter_model,
|
||||
image=context.images.get_pil(t2i_adapter_field.image.image_name),
|
||||
weight=t2i_adapter_field.weight,
|
||||
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def prep_ip_adapter_image_prompts(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@@ -738,7 +707,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
else:
|
||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||
|
||||
return mask, masked_latents, self.denoise_mask.gradient
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
|
||||
@staticmethod
|
||||
def prepare_noise_and_latents(
|
||||
@@ -796,6 +765,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
@@ -828,6 +801,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
attention_processor_cls=CustomAttnProcessor2_0,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
@@ -845,40 +833,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if self.unet.freeu_config:
|
||||
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
||||
|
||||
### seamless
|
||||
if self.unet.seamless_axes:
|
||||
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
||||
|
||||
### inpaint
|
||||
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
||||
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
|
||||
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
|
||||
# prevalent, we will have to revisit how we initialize the inpainting extensions.
|
||||
if unet_config.variant == ModelVariantType.Inpaint:
|
||||
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
||||
elif mask is not None:
|
||||
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
||||
|
||||
# Initialize context for modular denoise
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
attention_processor_cls=CustomAttnProcessor2_0,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# context for loading additional models
|
||||
with ExitStack() as exit_stack:
|
||||
# later should be smth like:
|
||||
@@ -886,7 +840,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
||||
# ext_manager.add_extension(ext)
|
||||
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
||||
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
@@ -918,10 +871,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
|
||||
# We invert the mask here for compatibility with the old backend implementation.
|
||||
if mask is not None:
|
||||
mask = 1 - mask
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
# below. Investigate whether this is appropriate.
|
||||
@@ -966,7 +915,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ExitStack() as exit_stack,
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
|
||||
@@ -242,23 +242,6 @@ class ConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class BoundingBoxField(BaseModel):
|
||||
"""A bounding box primitive value."""
|
||||
|
||||
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
|
||||
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
|
||||
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
|
||||
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
|
||||
|
||||
score: Optional[float] = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
|
||||
"when the bounding box was produced by a detector and has an associated confidence score.",
|
||||
)
|
||||
|
||||
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
|
||||
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
|
||||
GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
|
||||
|
||||
|
||||
@invocation(
|
||||
"grounding_dino",
|
||||
title="Grounding DINO (Text Prompt Object Detection)",
|
||||
tags=["prompt", "object detection"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class GroundingDinoInvocation(BaseInvocation):
|
||||
"""Runs a Grounding DINO model (https://arxiv.org/pdf/2303.05499). Performs zero-shot bounding-box object detection
|
||||
from a text prompt.
|
||||
|
||||
Reference:
|
||||
- https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
||||
- https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||
"""
|
||||
|
||||
prompt: str = InputField(description="The prompt describing the object to segment.")
|
||||
image: ImageField = InputField(description="The image to segment.")
|
||||
detection_threshold: float = InputField(
|
||||
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
default=0.3,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
|
||||
# The model expects a 3-channel RGB image.
|
||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
detections = self._detect(
|
||||
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
|
||||
)
|
||||
|
||||
# Convert detections to BoundingBoxCollectionOutput.
|
||||
bounding_boxes: list[BoundingBoxField] = []
|
||||
for detection in detections:
|
||||
bounding_boxes.append(
|
||||
BoundingBoxField(
|
||||
x_min=detection.box.xmin,
|
||||
x_max=detection.box.xmax,
|
||||
y_min=detection.box.ymin,
|
||||
y_max=detection.box.ymax,
|
||||
score=detection.score,
|
||||
)
|
||||
)
|
||||
return BoundingBoxCollectionOutput(collection=bounding_boxes)
|
||||
|
||||
@staticmethod
|
||||
def _load_grounding_dino(model_path: Path):
|
||||
grounding_dino_pipeline = pipeline(
|
||||
model=str(model_path),
|
||||
task="zero-shot-object-detection",
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
||||
|
||||
def _detect(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
labels: list[str],
|
||||
threshold: float = 0.3,
|
||||
) -> list[DetectionResult]:
|
||||
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
|
||||
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
|
||||
# actually makes a difference.
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
with context.models.load_remote_model(
|
||||
source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
|
||||
) as detector:
|
||||
assert isinstance(detector, GroundingDinoPipeline)
|
||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||
@@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -119,28 +118,3 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
height=mask.shape[1],
|
||||
width=mask.shape[2],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"tensor_mask_to_image",
|
||||
title="Tensor Mask to Image",
|
||||
tags=["mask"],
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Convert a mask tensor to an image."""
|
||||
|
||||
mask: TensorField = InputField(description="The mask tensor to convert.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.tensors.load(self.mask.tensor_name)
|
||||
# Ensure that the mask is binary.
|
||||
if mask.dtype != torch.bool:
|
||||
mask = mask > 0.5
|
||||
mask_np = mask.float().cpu().detach().numpy() * 255
|
||||
mask_np = mask_np.astype(np.uint8)
|
||||
|
||||
mask_pil = Image.fromarray(mask_np, mode="L")
|
||||
image_dto = context.images.save(image=mask_pil)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -7,7 +7,6 @@ import torch
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
BoundingBoxField,
|
||||
ColorField,
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
@@ -470,24 +469,3 @@ class ConditioningCollectionInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region BoundingBox
|
||||
|
||||
|
||||
@invocation_output("bounding_box_output")
|
||||
class BoundingBoxOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single bounding box"""
|
||||
|
||||
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
|
||||
|
||||
|
||||
@invocation_output("bounding_box_collection_output")
|
||||
class BoundingBoxCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of bounding boxes"""
|
||||
|
||||
collection: list[BoundingBoxField] = OutputField(
|
||||
description="The output bounding boxes.",
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
85
invokeai/app/invocations/prompt_augmentation.py
Normal file
85
invokeai/app/invocations/prompt_augmentation.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
AUGMENT_PROMPT_INSTRUCTION = """Your task is to translate a short image caption and a style caption to a more detailed caption for the same image. The detailed caption should adhere to the following:
|
||||
- be 1 sentence long
|
||||
- use descriptive language that relates to the subject of interest
|
||||
- it may add new details, but shouldn't change the subject of the original caption
|
||||
Here are some examples:
|
||||
Original caption: "A cat on a table"
|
||||
Detailed caption: "A fluffy cat with a curious expression, sitting on a wooden table next to a vase of flowers."
|
||||
Original caption: "medieval armor"
|
||||
Detailed caption: "The gleaming suit of medieval armor stands proudly in the museum, its intricate engravings telling tales of long-forgotten battles and chivalry."
|
||||
Original caption: "A panda bear as a mad scientist"
|
||||
Detailed caption: "Clad in a tiny lab coat and goggles, the panda bear feverishly mixes colorful potions, embodying the eccentricity of a mad scientist in its whimsical laboratory."
|
||||
Here is the prompt to translate:
|
||||
Original caption: "{}"
|
||||
Detailed caption:"""
|
||||
|
||||
|
||||
@invocation("promp_augment", title="Prompt Augmentation", tags=["prompt"], category="conditioning", version="1.0.0")
|
||||
class PrompAugmentationInvocation(BaseInvocation):
|
||||
"""Use an LLM to augment a text prompt."""
|
||||
|
||||
prompt: str = InputField(description="The text prompt to augment.")
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
# TODO(ryand): Address the following situations in the input prompt:
|
||||
# - Prompt contains a TI embeddings.
|
||||
# - Prompt contains .and() compel syntax. (Is ther any other compel syntax we need to handle?)
|
||||
# - Prompt contains quotation marks that could cause confusion when embedded in an LLM instruct prompt.
|
||||
|
||||
# Load the model and tokenizer.
|
||||
model_source = "microsoft/Phi-3-mini-4k-instruct"
|
||||
|
||||
def model_loader(model_path: Path):
|
||||
return AutoModelForCausalLM.from_pretrained(
|
||||
model_path, torch_dtype=TorchDevice.choose_torch_dtype(), local_files_only=True
|
||||
)
|
||||
|
||||
def tokenizer_loader(model_path: Path):
|
||||
return AutoTokenizer.from_pretrained(model_path, local_files_only=True)
|
||||
|
||||
with (
|
||||
context.models.load_remote_model(source=model_source, loader=model_loader) as model,
|
||||
context.models.load_remote_model(source=model_source, loader=tokenizer_loader) as tokenizer,
|
||||
):
|
||||
# Tokenize the input prompt.
|
||||
augmented_prompt = self._run_instruct_model(model, tokenizer, self.prompt)
|
||||
|
||||
return StringOutput(value=augmented_prompt)
|
||||
|
||||
def _run_instruct_model(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str) -> str:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": AUGMENT_PROMPT_INSTRUCTION.format(prompt),
|
||||
}
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
inputs,
|
||||
max_new_tokens=200,
|
||||
temperature=0.9,
|
||||
do_sample=True,
|
||||
)
|
||||
text = tokenizer.batch_decode(outputs)[0]
|
||||
assert isinstance(text, str)
|
||||
|
||||
output = text.split("<|assistant|>")[-1].strip()
|
||||
output = output.split("<|end|>")[0].strip()
|
||||
|
||||
return output
|
||||
@@ -1,155 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
||||
from transformers.models.sam import SamModel
|
||||
from transformers.models.sam.processing_sam import SamProcessor
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
|
||||
|
||||
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
|
||||
|
||||
|
||||
@invocation(
|
||||
"segment_anything_model",
|
||||
title="Segment Anything Model",
|
||||
tags=["prompt", "segmentation"],
|
||||
category="segmentation",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SegmentAnythingModelInvocation(BaseInvocation):
|
||||
"""Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643).
|
||||
|
||||
Reference:
|
||||
- https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
||||
- https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||
"""
|
||||
|
||||
image: ImageField = InputField(description="The image to segment.")
|
||||
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
|
||||
apply_polygon_refinement: bool = InputField(
|
||||
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
|
||||
default=True,
|
||||
)
|
||||
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
|
||||
description="The filtering to apply to the detected masks before merging them into a final output.",
|
||||
default="all",
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
# The models expect a 3-channel RGB image.
|
||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
if len(self.bounding_boxes) == 0:
|
||||
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
|
||||
else:
|
||||
masks = self._segment(context=context, image=image_pil)
|
||||
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
|
||||
|
||||
# masks contains bool values, so we merge them via max-reduce.
|
||||
combined_mask, _ = torch.stack(masks).max(dim=0)
|
||||
|
||||
mask_tensor_name = context.tensors.save(combined_mask)
|
||||
height, width = combined_mask.shape
|
||||
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
|
||||
|
||||
@staticmethod
|
||||
def _load_sam_model(model_path: Path):
|
||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||
model_path,
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(sam_model, SamModel)
|
||||
|
||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
||||
assert isinstance(sam_processor, SamProcessor)
|
||||
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
||||
|
||||
def _segment(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
) -> list[torch.Tensor]:
|
||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||
# Convert the bounding boxes to the SAM input format.
|
||||
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
|
||||
|
||||
with (
|
||||
context.models.load_remote_model(
|
||||
source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
|
||||
) as sam_pipeline,
|
||||
):
|
||||
assert isinstance(sam_pipeline, SegmentAnythingModel)
|
||||
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
|
||||
|
||||
masks = self._process_masks(masks)
|
||||
if self.apply_polygon_refinement:
|
||||
masks = self._apply_polygon_refinement(masks)
|
||||
|
||||
return masks
|
||||
|
||||
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""Convert the tensor output from the Segment Anything model from a tensor of shape
|
||||
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
|
||||
"""
|
||||
assert masks.dtype == torch.bool
|
||||
# [num_masks, channels, height, width] -> [num_masks, height, width]
|
||||
masks, _ = masks.max(dim=1)
|
||||
# Split the first dimension into a list of masks.
|
||||
return list(masks.cpu().unbind(dim=0))
|
||||
|
||||
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
"""Apply polygon refinement to the masks.
|
||||
|
||||
Convert each mask to a polygon, then back to a mask. This has the following effect:
|
||||
- Smooth the edges of the mask slightly.
|
||||
- Ensure that each mask consists of a single closed polygon
|
||||
- Removes small mask pieces.
|
||||
- Removes holes from the mask.
|
||||
"""
|
||||
# Convert tensor masks to np masks.
|
||||
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
|
||||
|
||||
# Apply polygon refinement.
|
||||
for idx, mask in enumerate(np_masks):
|
||||
shape = mask.shape
|
||||
assert len(shape) == 2 # Assert length to satisfy type checker.
|
||||
polygon = mask_to_polygon(mask)
|
||||
mask = polygon_to_mask(polygon, shape)
|
||||
np_masks[idx] = mask
|
||||
|
||||
# Convert np masks back to tensor masks.
|
||||
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
|
||||
|
||||
return masks
|
||||
|
||||
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
|
||||
"""Filter the detected masks based on the specified mask filter."""
|
||||
assert len(masks) == len(bounding_boxes)
|
||||
|
||||
if self.mask_filter == "all":
|
||||
return masks
|
||||
elif self.mask_filter == "largest":
|
||||
# Find the largest mask.
|
||||
return [max(masks, key=lambda x: float(x.sum()))]
|
||||
elif self.mask_filter == "highest_box_score":
|
||||
# Find the index of the bounding box with the highest score.
|
||||
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
|
||||
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
|
||||
# reasonable fallback since the expected score range is [0.0, 1.0].
|
||||
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
|
||||
return [masks[max_score_idx]]
|
||||
else:
|
||||
raise ValueError(f"Invalid mask filter: {self.mask_filter}")
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,22 +0,0 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class BoundingBox(BaseModel):
|
||||
"""Bounding box helper class."""
|
||||
|
||||
xmin: int
|
||||
ymin: int
|
||||
xmax: int
|
||||
ymax: int
|
||||
|
||||
|
||||
class DetectionResult(BaseModel):
|
||||
"""Detection result from Grounding DINO."""
|
||||
|
||||
score: float
|
||||
label: str
|
||||
box: BoundingBox
|
||||
model_config = ConfigDict(
|
||||
# Allow arbitrary types for mask, since it will be a numpy array.
|
||||
arbitrary_types_allowed=True
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||||
|
||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class GroundingDinoPipeline(RawModel):
|
||||
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
|
||||
management system.
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
|
||||
self._pipeline = pipeline
|
||||
|
||||
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
|
||||
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
|
||||
results = [DetectionResult.model_validate(result) for result in results]
|
||||
return results
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
|
||||
# CUDA.
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._pipeline.model.to(device=device, dtype=dtype)
|
||||
self._pipeline.device = self._pipeline.model.device
|
||||
|
||||
def calc_size(self) -> int:
|
||||
# HACK(ryand): Fix the circular import issue.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._pipeline.model)
|
||||
@@ -1,50 +0,0 @@
|
||||
# This file contains utilities for Grounded-SAM mask refinement based on:
|
||||
# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
|
||||
"""Convert a binary mask to a polygon.
|
||||
|
||||
Returns:
|
||||
list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon.
|
||||
"""
|
||||
# Find contours in the binary mask.
|
||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# Find the contour with the largest area.
|
||||
largest_contour = max(contours, key=cv2.contourArea)
|
||||
|
||||
# Extract the vertices of the contour.
|
||||
polygon = largest_contour.reshape(-1, 2).tolist()
|
||||
|
||||
return polygon
|
||||
|
||||
|
||||
def polygon_to_mask(
|
||||
polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
|
||||
) -> npt.NDArray[np.uint8]:
|
||||
"""Convert a polygon to a segmentation mask.
|
||||
|
||||
Args:
|
||||
polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
|
||||
image_shape (tuple): Shape of the image (height, width) for the mask.
|
||||
fill_value (int): Value to fill the polygon with.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Segmentation mask with the polygon filled (with value 255).
|
||||
"""
|
||||
# Create an empty mask.
|
||||
mask = np.zeros(image_shape, dtype=np.uint8)
|
||||
|
||||
# Convert polygon to an array of points.
|
||||
pts = np.array(polygon, dtype=np.int32)
|
||||
|
||||
# Fill the polygon with white color (255).
|
||||
cv2.fillPoly(mask, [pts], color=(fill_value,))
|
||||
|
||||
return mask
|
||||
@@ -1,53 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.models.sam import SamModel
|
||||
from transformers.models.sam.processing_sam import SamProcessor
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class SegmentAnythingModel(RawModel):
|
||||
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
|
||||
|
||||
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
|
||||
self._sam_model = sam_model
|
||||
self._sam_processor = sam_processor
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._sam_model.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
# HACK(ryand): Fix the circular import issue.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._sam_model)
|
||||
|
||||
def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
|
||||
"""Run the SAM model.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to segment.
|
||||
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
|
||||
[xmin, ymin, xmax, ymax].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
|
||||
"""
|
||||
# Add batch dimension of 1 to the bounding boxes.
|
||||
boxes = [bounding_boxes]
|
||||
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
|
||||
outputs = self._sam_model(**inputs)
|
||||
masks = self._sam_processor.post_process_masks(
|
||||
masks=outputs.pred_masks,
|
||||
original_sizes=inputs.original_sizes,
|
||||
reshaped_input_sizes=inputs.reshaped_input_sizes,
|
||||
)
|
||||
|
||||
# There should be only one batch.
|
||||
assert len(masks) == 1
|
||||
return masks[0]
|
||||
@@ -98,9 +98,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
},
|
||||
}
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
try:
|
||||
|
||||
@@ -11,8 +11,6 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
@@ -36,17 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
elif isinstance(model, CLIPTokenizer):
|
||||
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
||||
return 0
|
||||
elif isinstance(
|
||||
model,
|
||||
(
|
||||
TextualInversionModelRaw,
|
||||
IPAdapter,
|
||||
LoRAModelRaw,
|
||||
SpandrelImageToImageModel,
|
||||
GroundingDinoPipeline,
|
||||
SegmentAnythingModel,
|
||||
),
|
||||
):
|
||||
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
|
||||
return model.calc_size()
|
||||
else:
|
||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||
|
||||
@@ -187,171 +187,164 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
# endregion
|
||||
# region ControlNet
|
||||
StarterModel(
|
||||
name="QRCode Monster v2 (SD1.5)",
|
||||
name="QRCode Monster",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster::v2",
|
||||
description="ControlNet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="QRCode Monster (SDXL)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="monster-labs/control_v1p_sdxl_qrcode_monster",
|
||||
description="ControlNet model that generates scannable creative QR codes",
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster",
|
||||
description="Controlnet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_canny",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="inpaint",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_inpaint",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="mlsd",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_mlsd",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1p_sd15_depth",
|
||||
description="ControlNet weights trained on sd-1.5 with depth conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with depth conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="normal_bae",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_normalbae",
|
||||
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="seg",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_seg",
|
||||
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with seg image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_lineart",
|
||||
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart_anime",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with anime image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_openpose",
|
||||
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_scribble",
|
||||
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_softedge",
|
||||
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="shuffle",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_shuffle",
|
||||
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1e_sd15_tile",
|
||||
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="ip2p",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_ip2p",
|
||||
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-canny-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
source="xinsir/controlnet-canny-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlNet-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
source="diffusers/controlnet-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge-dexined-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
|
||||
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
|
||||
description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-16bit-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
source="diffusers/controlnet-zoe-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-openpose-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
source="xinsir/controlnet-openpose-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-scribble-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
source="xinsir/controlnet-scribble-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-tile-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
source="xinsir/controlnet-tile-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
# endregion
|
||||
|
||||
@@ -62,7 +62,13 @@ def filter_files(
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area to be careful of.
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
#
|
||||
# On July 24, 2024, this regex filter was modified to support downloading the `microsoft/Phi-3-mini-4k-instruct`
|
||||
# model. I am making this note in case it is relevant as we continue to improve this logic and make it less
|
||||
# brittle.
|
||||
# - Before: r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
|
||||
# - After: r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
|
||||
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
paths.append(file)
|
||||
|
||||
# limit search to subfolder if requested
|
||||
|
||||
@@ -7,9 +7,11 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( # noqa: F401
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"PipelineIntermediateState",
|
||||
"StableDiffusionGeneratorPipeline",
|
||||
"InvokeAIDiffuserComponent",
|
||||
"set_seamless",
|
||||
]
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
|
||||
|
||||
class InpaintExt(ExtensionBase):
|
||||
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
|
||||
models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mask: torch.Tensor,
|
||||
is_gradient_mask: bool,
|
||||
):
|
||||
"""Initialize InpaintExt.
|
||||
Args:
|
||||
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
||||
inpainted.
|
||||
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||
1.
|
||||
"""
|
||||
super().__init__()
|
||||
self._mask = mask
|
||||
self._is_gradient_mask = is_gradient_mask
|
||||
|
||||
# Noise, which used to noisify unmasked part of image
|
||||
# if noise provided to context, then it will be used
|
||||
# if no noise provided, then noise will be generated based on seed
|
||||
self._noise: Optional[torch.Tensor] = None
|
||||
|
||||
@staticmethod
|
||||
def _is_normal_model(unet: UNet2DConditionModel):
|
||||
"""Checks if the provided UNet belongs to a regular model.
|
||||
The `in_channels` of a UNet vary depending on model type:
|
||||
- normal - 4
|
||||
- depth - 5
|
||||
- inpaint - 9
|
||||
"""
|
||||
return unet.conv_in.in_channels == 4
|
||||
|
||||
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if t.dim() == 0:
|
||||
# some schedulers expect t to be one-dimensional.
|
||||
# TODO: file diffusers bug about inconsistency?
|
||||
t = einops.repeat(t, "-> batch", batch=batch_size)
|
||||
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
||||
# get very confused about what is happening from step to step when we do that.
|
||||
mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
|
||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if self._is_gradient_mask:
|
||||
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
||||
mask_bool = mask < 1 - threshold
|
||||
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||
else:
|
||||
masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
|
||||
return masked_input
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def init_tensors(self, ctx: DenoiseContext):
|
||||
if not self._is_normal_model(ctx.unet):
|
||||
raise ValueError(
|
||||
"InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
|
||||
"inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
|
||||
"fixed by removing and re-adding the model (so that it gets re-probed)."
|
||||
)
|
||||
|
||||
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||
|
||||
self._noise = ctx.inputs.noise
|
||||
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
# We still need noise for inpainting, so we generate it from the seed here.
|
||||
if self._noise is None:
|
||||
self._noise = torch.randn(
|
||||
ctx.latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
|
||||
).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||
|
||||
# Use negative order to make extensions with default order work with patched latents
|
||||
@callback(ExtensionCallbackType.PRE_STEP, order=-100)
|
||||
def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
|
||||
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
|
||||
|
||||
# TODO: redo this with preview events rewrite
|
||||
# Use negative order to make extensions with default order work with patched latents
|
||||
@callback(ExtensionCallbackType.POST_STEP, order=-100)
|
||||
def apply_mask_to_step_output(self, ctx: DenoiseContext):
|
||||
timestep = ctx.scheduler.timesteps[-1]
|
||||
if hasattr(ctx.step_output, "denoised"):
|
||||
ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
|
||||
elif hasattr(ctx.step_output, "pred_original_sample"):
|
||||
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
|
||||
else:
|
||||
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
|
||||
|
||||
# Restore unmasked part after the last step is completed
|
||||
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
||||
def restore_unmasked(self, ctx: DenoiseContext):
|
||||
if self._is_gradient_mask:
|
||||
ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
|
||||
else:
|
||||
ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
|
||||
@@ -1,88 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
|
||||
|
||||
class InpaintModelExt(ExtensionBase):
|
||||
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
|
||||
models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mask: Optional[torch.Tensor],
|
||||
masked_latents: Optional[torch.Tensor],
|
||||
is_gradient_mask: bool,
|
||||
):
|
||||
"""Initialize InpaintModelExt.
|
||||
Args:
|
||||
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
||||
inpainted.
|
||||
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
|
||||
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
|
||||
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||
1.
|
||||
"""
|
||||
super().__init__()
|
||||
if mask is not None and masked_latents is None:
|
||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||
|
||||
# Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
|
||||
self._mask = None
|
||||
if mask is not None:
|
||||
self._mask = 1 - mask
|
||||
self._masked_latents = masked_latents
|
||||
self._is_gradient_mask = is_gradient_mask
|
||||
|
||||
@staticmethod
|
||||
def _is_inpaint_model(unet: UNet2DConditionModel):
|
||||
"""Checks if the provided UNet belongs to a regular model.
|
||||
The `in_channels` of a UNet vary depending on model type:
|
||||
- normal - 4
|
||||
- depth - 5
|
||||
- inpaint - 9
|
||||
"""
|
||||
return unet.conv_in.in_channels == 9
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def init_tensors(self, ctx: DenoiseContext):
|
||||
if not self._is_inpaint_model(ctx.unet):
|
||||
raise ValueError("InpaintModelExt should be used only on inpaint models!")
|
||||
|
||||
if self._mask is None:
|
||||
self._mask = torch.ones_like(ctx.latents[:1, :1])
|
||||
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||
|
||||
if self._masked_latents is None:
|
||||
self._masked_latents = torch.zeros_like(ctx.latents[:1])
|
||||
self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||
|
||||
# Do last so that other extensions works with normal latents
|
||||
@callback(ExtensionCallbackType.PRE_UNET, order=1000)
|
||||
def append_inpaint_layers(self, ctx: DenoiseContext):
|
||||
batch_size = ctx.unet_kwargs.sample.shape[0]
|
||||
b_mask = torch.cat([self._mask] * batch_size)
|
||||
b_masked_latents = torch.cat([self._masked_latents] * batch_size)
|
||||
ctx.unet_kwargs.sample = torch.cat(
|
||||
[ctx.unet_kwargs.sample, b_mask, b_masked_latents],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Restore unmasked part as inpaint model can change unmasked part slightly
|
||||
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
||||
def restore_unmasked(self, ctx: DenoiseContext):
|
||||
if self._is_gradient_mask:
|
||||
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
|
||||
else:
|
||||
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
|
||||
@@ -1,71 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
|
||||
|
||||
class SeamlessExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
seamless_axes: List[str],
|
||||
):
|
||||
super().__init__()
|
||||
self._seamless_axes = seamless_axes
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
with self.static_patch_model(
|
||||
model=unet,
|
||||
seamless_axes=self._seamless_axes,
|
||||
):
|
||||
yield
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def static_patch_model(
|
||||
model: torch.nn.Module,
|
||||
seamless_axes: List[str],
|
||||
):
|
||||
if not seamless_axes:
|
||||
yield
|
||||
return
|
||||
|
||||
x_mode = "circular" if "x" in seamless_axes else "constant"
|
||||
y_mode = "circular" if "y" in seamless_axes else "constant"
|
||||
|
||||
# override conv_forward
|
||||
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
|
||||
def _conv_forward_asymmetric(
|
||||
self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
):
|
||||
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
||||
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
||||
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
|
||||
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
|
||||
return torch.nn.functional.conv2d(
|
||||
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
|
||||
)
|
||||
|
||||
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
|
||||
try:
|
||||
for layer in model.modules():
|
||||
if not isinstance(layer, torch.nn.Conv2d):
|
||||
continue
|
||||
|
||||
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
|
||||
layer.lora_layer = lambda *x: 0
|
||||
original_layers.append((layer, layer._conv_forward))
|
||||
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for layer, orig_conv_forward in original_layers:
|
||||
layer._conv_forward = orig_conv_forward
|
||||
@@ -1,120 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers import T2IAdapter
|
||||
from PIL.Image import Image
|
||||
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
|
||||
|
||||
class T2IAdapterExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
node_context: InvocationContext,
|
||||
model_id: ModelIdentifierField,
|
||||
image: Image,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
):
|
||||
super().__init__()
|
||||
self._node_context = node_context
|
||||
self._model_id = model_id
|
||||
self._image = image
|
||||
self._weight = weight
|
||||
self._resize_mode = resize_mode
|
||||
self._begin_step_percent = begin_step_percent
|
||||
self._end_step_percent = end_step_percent
|
||||
|
||||
self._adapter_state: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
model_config = self._node_context.models.get_config(self._model_id.key)
|
||||
if model_config.base == BaseModelType.StableDiffusion1:
|
||||
self._max_unet_downscale = 8
|
||||
elif model_config.base == BaseModelType.StableDiffusionXL:
|
||||
self._max_unet_downscale = 4
|
||||
else:
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
|
||||
|
||||
@callback(ExtensionCallbackType.SETUP)
|
||||
def setup(self, ctx: DenoiseContext):
|
||||
t2i_model: T2IAdapter
|
||||
with self._node_context.models.load(self._model_id) as t2i_model:
|
||||
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
|
||||
|
||||
self._adapter_state = self._run_model(
|
||||
model=t2i_model,
|
||||
image=self._image,
|
||||
latents_height=latents_height,
|
||||
latents_width=latents_width,
|
||||
)
|
||||
|
||||
def _run_model(
|
||||
self,
|
||||
model: T2IAdapter,
|
||||
image: Image,
|
||||
latents_height: int,
|
||||
latents_width: int,
|
||||
):
|
||||
# Resize the T2I-Adapter input image.
|
||||
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
||||
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
||||
input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
|
||||
input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor
|
||||
|
||||
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
||||
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
||||
# T2I-Adapter model.
|
||||
#
|
||||
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
||||
# of the same requirements (e.g. preserving binary masks during resize).
|
||||
t2i_image = prepare_control_image(
|
||||
image=image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=input_width,
|
||||
height=input_height,
|
||||
num_channels=model.config["in_channels"],
|
||||
device=model.device,
|
||||
dtype=model.dtype,
|
||||
resize_mode=self._resize_mode,
|
||||
)
|
||||
|
||||
return model(t2i_image)
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_UNET)
|
||||
def pre_unet_step(self, ctx: DenoiseContext):
|
||||
# skip if model not active in current step
|
||||
total_steps = len(ctx.inputs.timesteps)
|
||||
first_step = math.floor(self._begin_step_percent * total_steps)
|
||||
last_step = math.ceil(self._end_step_percent * total_steps)
|
||||
if ctx.step_index < first_step or ctx.step_index > last_step:
|
||||
return
|
||||
|
||||
weight = self._weight
|
||||
if isinstance(weight, list):
|
||||
weight = weight[ctx.step_index]
|
||||
|
||||
adapter_state = self._adapter_state
|
||||
if ctx.conditioning_mode == ConditioningMode.Both:
|
||||
adapter_state = [torch.cat([v] * 2) for v in adapter_state]
|
||||
|
||||
if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
|
||||
ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
|
||||
else:
|
||||
for i, value in enumerate(adapter_state):
|
||||
ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
|
||||
51
invokeai/backend/stable_diffusion/seamless.py
Normal file
51
invokeai/backend/stable_diffusion/seamless.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
|
||||
if not seamless_axes:
|
||||
yield
|
||||
return
|
||||
|
||||
# override conv_forward
|
||||
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
|
||||
def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
||||
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
||||
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
||||
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
|
||||
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
|
||||
return torch.nn.functional.conv2d(
|
||||
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
|
||||
)
|
||||
|
||||
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
|
||||
|
||||
try:
|
||||
x_mode = "circular" if "x" in seamless_axes else "constant"
|
||||
y_mode = "circular" if "y" in seamless_axes else "constant"
|
||||
|
||||
conv_layers: List[torch.nn.Conv2d] = []
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
conv_layers.append(module)
|
||||
|
||||
for layer in conv_layers:
|
||||
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
|
||||
layer.lora_layer = lambda *x: 0
|
||||
original_layers.append((layer, layer._conv_forward))
|
||||
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for layer, orig_conv_forward in original_layers:
|
||||
layer._conv_forward = orig_conv_forward
|
||||
@@ -77,6 +77,10 @@
|
||||
"title": "استعادة الوجوه",
|
||||
"desc": "استعادة الصورة الحالية"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "تحسين الحجم",
|
||||
"desc": "تحسين حجم الصورة الحالية"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "عرض المعلومات",
|
||||
"desc": "عرض معلومات البيانات الخاصة بالصورة الحالية"
|
||||
@@ -251,6 +255,8 @@
|
||||
"type": "نوع",
|
||||
"strength": "قوة",
|
||||
"upscaling": "تصغير",
|
||||
"upscale": "تصغير",
|
||||
"upscaleImage": "تصغير الصورة",
|
||||
"scale": "مقياس",
|
||||
"imageFit": "ملائمة الصورة الأولية لحجم الخرج",
|
||||
"scaleBeforeProcessing": "تحجيم قبل المعالجة",
|
||||
|
||||
@@ -187,6 +187,10 @@
|
||||
"title": "Gesicht restaurieren",
|
||||
"desc": "Das aktuelle Bild restaurieren"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Hochskalieren",
|
||||
"desc": "Das aktuelle Bild hochskalieren"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Info anzeigen",
|
||||
"desc": "Metadaten des aktuellen Bildes anzeigen"
|
||||
@@ -429,6 +433,8 @@
|
||||
"type": "Art",
|
||||
"strength": "Stärke",
|
||||
"upscaling": "Hochskalierung",
|
||||
"upscale": "Hochskalieren (Shift + U)",
|
||||
"upscaleImage": "Bild hochskalieren",
|
||||
"scale": "Maßstab",
|
||||
"imageFit": "Ausgangsbild an Ausgabegröße anpassen",
|
||||
"scaleBeforeProcessing": "Skalieren vor der Verarbeitung",
|
||||
|
||||
@@ -105,7 +105,6 @@
|
||||
"negativePrompt": "Negative Prompt",
|
||||
"discordLabel": "Discord",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"dontShowMeThese": "Don't show me these",
|
||||
"editor": "Editor",
|
||||
"error": "Error",
|
||||
"file": "File",
|
||||
@@ -382,9 +381,7 @@
|
||||
"featuresWillReset": "If you delete this image, those features will immediately be reset.",
|
||||
"galleryImageSize": "Image Size",
|
||||
"gallerySettings": "Gallery Settings",
|
||||
"go": "Go",
|
||||
"image": "image",
|
||||
"jump": "Jump",
|
||||
"loading": "Loading",
|
||||
"loadMore": "Load More",
|
||||
"newestFirst": "Newest First",
|
||||
@@ -1100,8 +1097,6 @@
|
||||
"displayInProgress": "Display Progress Images",
|
||||
"enableImageDebugging": "Enable Image Debugging",
|
||||
"enableInformationalPopovers": "Enable Informational Popovers",
|
||||
"informationalPopoversDisabled": "Informational Popovers Disabled",
|
||||
"informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
|
||||
"enableInvisibleWatermark": "Enable Invisible Watermark",
|
||||
"enableNSFWChecker": "Enable NSFW Checker",
|
||||
"general": "General",
|
||||
@@ -1509,30 +1504,6 @@
|
||||
"seamlessTilingYAxis": {
|
||||
"heading": "Seamless Tiling Y Axis",
|
||||
"paragraphs": ["Seamlessly tile an image along the vertical axis."]
|
||||
},
|
||||
"upscaleModel": {
|
||||
"heading": "Upscale Model",
|
||||
"paragraphs": [
|
||||
"The upscale model scales the image to the output size before details are added. Any supported upscale model may be used, but some are specialized for different kinds of images, like photos or line drawings."
|
||||
]
|
||||
},
|
||||
"scale": {
|
||||
"heading": "Scale",
|
||||
"paragraphs": [
|
||||
"Scale controls the output image size, and is based on a multiple of the input image resolution. For example a 2x upscale on a 1024x1024 image would produce a 2048 x 2048 output."
|
||||
]
|
||||
},
|
||||
"creativity": {
|
||||
"heading": "Creativity",
|
||||
"paragraphs": [
|
||||
"Creativity controls the amount of freedom granted to the model when adding details. Low creativity stays close to the original image, while high creativity allows for more change. When using a prompt, high creativity increases the influence of the prompt."
|
||||
]
|
||||
},
|
||||
"structure": {
|
||||
"heading": "Structure",
|
||||
"paragraphs": [
|
||||
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
|
||||
]
|
||||
}
|
||||
},
|
||||
"unifiedCanvas": {
|
||||
|
||||
@@ -151,6 +151,10 @@
|
||||
"title": "Restaurar rostros",
|
||||
"desc": "Restaurar rostros en la imagen actual"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Aumentar resolución",
|
||||
"desc": "Aumentar la resolución de la imagen actual"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Mostrar información",
|
||||
"desc": "Mostar metadatos de la imagen actual"
|
||||
@@ -356,6 +360,8 @@
|
||||
"type": "Tipo",
|
||||
"strength": "Fuerza",
|
||||
"upscaling": "Aumento de resolución",
|
||||
"upscale": "Aumentar resolución",
|
||||
"upscaleImage": "Aumentar la resolución de la imagen",
|
||||
"scale": "Escala",
|
||||
"imageFit": "Ajuste tamaño de imagen inicial al tamaño objetivo",
|
||||
"scaleBeforeProcessing": "Redimensionar antes de procesar",
|
||||
@@ -402,12 +408,7 @@
|
||||
"showProgressInViewer": "Mostrar las imágenes del progreso en el visor",
|
||||
"ui": "Interfaz del usuario",
|
||||
"generation": "Generación",
|
||||
"beta": "Beta",
|
||||
"reloadingIn": "Recargando en",
|
||||
"intermediatesClearedFailed": "Error limpiando los intermediarios",
|
||||
"intermediatesCleared_one": "Borrado {{count}} intermediario",
|
||||
"intermediatesCleared_many": "Borrados {{count}} intermediarios",
|
||||
"intermediatesCleared_other": "Borrados {{count}} intermediarios"
|
||||
"beta": "Beta"
|
||||
},
|
||||
"toast": {
|
||||
"uploadFailed": "Error al subir archivo",
|
||||
@@ -425,12 +426,7 @@
|
||||
"parameterSet": "Conjunto de parámetros",
|
||||
"parameterNotSet": "Parámetro no configurado",
|
||||
"problemCopyingImage": "No se puede copiar la imagen",
|
||||
"errorCopied": "Error al copiar",
|
||||
"baseModelChanged": "Modelo base cambiado",
|
||||
"addedToBoard": "Añadido al tablero",
|
||||
"baseModelChangedCleared_one": "Borrado o desactivado {{count}} submodelo incompatible",
|
||||
"baseModelChangedCleared_many": "Borrados o desactivados {{count}} submodelos incompatibles",
|
||||
"baseModelChangedCleared_other": "Borrados o desactivados {{count}} submodelos incompatibles"
|
||||
"errorCopied": "Error al copiar"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@@ -544,13 +540,7 @@
|
||||
"downloadBoard": "Descargar panel",
|
||||
"deleteBoardOnly": "Borrar solo el panel",
|
||||
"myBoard": "Mi panel",
|
||||
"noMatching": "No hay paneles que coincidan",
|
||||
"imagesWithCount_one": "{{count}} imagen",
|
||||
"imagesWithCount_many": "{{count}} imágenes",
|
||||
"imagesWithCount_other": "{{count}} imágenes",
|
||||
"assetsWithCount_one": "{{count}} activo",
|
||||
"assetsWithCount_many": "{{count}} activos",
|
||||
"assetsWithCount_other": "{{count}} activos"
|
||||
"noMatching": "No hay paneles que coincidan"
|
||||
},
|
||||
"accordions": {
|
||||
"compositing": {
|
||||
@@ -600,27 +590,6 @@
|
||||
"balanced": "Equilibrado",
|
||||
"beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
|
||||
"detectResolution": "Detectar resolución",
|
||||
"beginEndStepPercentShort": "Inicio / Final %",
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
|
||||
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||
"addControlNet": "Añadir $t(common.controlNet)",
|
||||
"addIPAdapter": "Añadir $t(common.ipAdapter)",
|
||||
"controlAdapter_one": "Adaptador de control",
|
||||
"controlAdapter_many": "Adaptadores de control",
|
||||
"controlAdapter_other": "Adaptadores de control",
|
||||
"addT2IAdapter": "Añadir $t(common.t2iAdapter)"
|
||||
},
|
||||
"queue": {
|
||||
"back": "Atrás",
|
||||
"front": "Delante",
|
||||
"batchQueuedDesc_one": "Se agregó {{count}} sesión a {{direction}} la cola",
|
||||
"batchQueuedDesc_many": "Se agregaron {{count}} sesiones a {{direction}} la cola",
|
||||
"batchQueuedDesc_other": "Se agregaron {{count}} sesiones a {{direction}} la cola"
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invitar compañeros de equipo",
|
||||
"shareAccess": "Compartir acceso",
|
||||
"professionalUpsell": "Disponible en la edición profesional de Invoke. Haz clic aquí o visita invoke.com/pricing para obtener más detalles."
|
||||
"beginEndStepPercentShort": "Inicio / Final %"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,6 +130,10 @@
|
||||
"title": "Restaurer les visages",
|
||||
"desc": "Restaurer l'image actuelle"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Agrandir",
|
||||
"desc": "Agrandir l'image actuelle"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Afficher les informations",
|
||||
"desc": "Afficher les informations de métadonnées de l'image actuelle"
|
||||
@@ -304,6 +308,8 @@
|
||||
"type": "Type",
|
||||
"strength": "Force",
|
||||
"upscaling": "Agrandissement",
|
||||
"upscale": "Agrandir",
|
||||
"upscaleImage": "Image en Agrandissement",
|
||||
"scale": "Echelle",
|
||||
"imageFit": "Ajuster Image Initiale à la Taille de Sortie",
|
||||
"scaleBeforeProcessing": "Echelle Avant Traitement",
|
||||
|
||||
@@ -90,6 +90,10 @@
|
||||
"desc": "שחזור התמונה הנוכחית",
|
||||
"title": "שחזור פרצופים"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "הגדלת קנה מידה",
|
||||
"desc": "הגדל את התמונה הנוכחית"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "הצג מידע",
|
||||
"desc": "הצגת פרטי מטא-נתונים של התמונה הנוכחית"
|
||||
@@ -259,6 +263,8 @@
|
||||
"seed": "זרע",
|
||||
"type": "סוג",
|
||||
"strength": "חוזק",
|
||||
"upscale": "הגדלת קנה מידה",
|
||||
"upscaleImage": "הגדלת קנה מידת התמונה",
|
||||
"denoisingStrength": "חוזק מנטרל הרעש",
|
||||
"scaleBeforeProcessing": "שנה קנה מידה לפני עיבוד",
|
||||
"scaledWidth": "קנה מידה לאחר שינוי W",
|
||||
|
||||
@@ -150,11 +150,7 @@
|
||||
"showArchivedBoards": "Mostra le bacheche archiviate",
|
||||
"searchImages": "Ricerca per metadati",
|
||||
"displayBoardSearch": "Mostra la ricerca nelle Bacheche",
|
||||
"displaySearch": "Mostra la ricerca",
|
||||
"selectAllOnPage": "Seleziona tutto nella pagina",
|
||||
"selectAllOnBoard": "Seleziona tutto nella bacheca",
|
||||
"exitBoardSearch": "Esci da Ricerca bacheca",
|
||||
"exitSearch": "Esci dalla ricerca"
|
||||
"displaySearch": "Mostra la ricerca"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Tasti di scelta rapida",
|
||||
@@ -214,6 +210,10 @@
|
||||
"title": "Restaura volti",
|
||||
"desc": "Restaura l'immagine corrente"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Amplia",
|
||||
"desc": "Amplia l'immagine corrente"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Mostra informazioni",
|
||||
"desc": "Mostra le informazioni sui metadati dell'immagine corrente"
|
||||
@@ -377,10 +377,6 @@
|
||||
"toggleViewer": {
|
||||
"title": "Attiva/disattiva il visualizzatore di immagini",
|
||||
"desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
|
||||
},
|
||||
"postProcess": {
|
||||
"desc": "Elabora l'immagine corrente utilizzando il modello di post-elaborazione selezionato",
|
||||
"title": "Elabora immagine"
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
@@ -509,6 +505,8 @@
|
||||
"type": "Tipo",
|
||||
"strength": "Forza",
|
||||
"upscaling": "Ampliamento",
|
||||
"upscale": "Amplia (Shift + U)",
|
||||
"upscaleImage": "Amplia Immagine",
|
||||
"scale": "Scala",
|
||||
"imageFit": "Adatta l'immagine iniziale alle dimensioni di output",
|
||||
"scaleBeforeProcessing": "Scala prima dell'elaborazione",
|
||||
@@ -593,10 +591,7 @@
|
||||
"infillColorValue": "Colore di riempimento",
|
||||
"globalSettings": "Impostazioni globali",
|
||||
"globalPositivePromptPlaceholder": "Prompt positivo globale",
|
||||
"globalNegativePromptPlaceholder": "Prompt negativo globale",
|
||||
"processImage": "Elabora Immagine",
|
||||
"sendToUpscale": "Invia a Ampliare",
|
||||
"postProcessing": "Post-elaborazione (Shift + U)"
|
||||
"globalNegativePromptPlaceholder": "Prompt negativo globale"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Modelli",
|
||||
@@ -969,10 +964,7 @@
|
||||
"boards": "Bacheche",
|
||||
"private": "Bacheche private",
|
||||
"shared": "Bacheche condivise",
|
||||
"addPrivateBoard": "Aggiungi una Bacheca Privata",
|
||||
"noBoards": "Nessuna bacheca {{boardType}}",
|
||||
"hideBoards": "Nascondi bacheche",
|
||||
"viewBoards": "Visualizza bacheche"
|
||||
"addPrivateBoard": "Aggiungi una Bacheca Privata"
|
||||
},
|
||||
"controlnet": {
|
||||
"contentShuffleDescription": "Rimescola il contenuto di un'immagine",
|
||||
@@ -1692,30 +1684,7 @@
|
||||
"models": "Modelli",
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"queue": "Coda",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
|
||||
"upscaling": "Ampliamento",
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
"creativity": "Creatività",
|
||||
"structure": "Struttura",
|
||||
"upscaleModel": "Modello di Ampliamento",
|
||||
"scale": "Scala",
|
||||
"missingModelsWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare i modelli richiesti:",
|
||||
"mainModelDesc": "Modello principale (architettura SD1.5 o SDXL)",
|
||||
"tileControlNetModelDesc": "Modello Tile ControlNet per l'architettura del modello principale scelto",
|
||||
"upscaleModelDesc": "Modello per l'ampliamento (da immagine a immagine)",
|
||||
"missingUpscaleInitialImage": "Immagine iniziale mancante per l'ampliamento",
|
||||
"missingUpscaleModel": "Modello per l’ampliamento mancante",
|
||||
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
|
||||
"postProcessingModel": "Modello di post-elaborazione",
|
||||
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine)."
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invita collaboratori",
|
||||
"shareAccess": "Condividi l'accesso",
|
||||
"professional": "Professionale",
|
||||
"professionalUpsell": "Disponibile nell'edizione Professional di Invoke. Fai clic qui o visita invoke.com/pricing per ulteriori dettagli."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,6 +199,10 @@
|
||||
"title": "顔の修復",
|
||||
"desc": "現在の画像を修復"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "アップスケール",
|
||||
"desc": "現在の画像をアップスケール"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "情報を見る",
|
||||
"desc": "現在の画像のメタデータ情報を表示"
|
||||
@@ -423,6 +427,8 @@
|
||||
"shuffle": "シャッフル",
|
||||
"strength": "強度",
|
||||
"upscaling": "アップスケーリング",
|
||||
"upscale": "アップスケール",
|
||||
"upscaleImage": "画像をアップスケール",
|
||||
"scale": "Scale",
|
||||
"scaleBeforeProcessing": "処理前のスケール",
|
||||
"scaledWidth": "幅のスケール",
|
||||
|
||||
@@ -258,6 +258,10 @@
|
||||
"desc": "캔버스 브러시를 선택",
|
||||
"title": "브러시 선택"
|
||||
},
|
||||
"upscale": {
|
||||
"desc": "현재 이미지를 업스케일",
|
||||
"title": "업스케일"
|
||||
},
|
||||
"previousImage": {
|
||||
"title": "이전 이미지",
|
||||
"desc": "갤러리에 이전 이미지 표시"
|
||||
|
||||
@@ -168,6 +168,10 @@
|
||||
"title": "Herstel gezichten",
|
||||
"desc": "Herstelt de huidige afbeelding"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Schaal op",
|
||||
"desc": "Schaalt de huidige afbeelding op"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Toon info",
|
||||
"desc": "Toont de metagegevens van de huidige afbeelding"
|
||||
@@ -408,6 +412,8 @@
|
||||
"type": "Soort",
|
||||
"strength": "Sterkte",
|
||||
"upscaling": "Opschalen",
|
||||
"upscale": "Vergroot (Shift + U)",
|
||||
"upscaleImage": "Schaal afbeelding op",
|
||||
"scale": "Schaal",
|
||||
"imageFit": "Pas initiële afbeelding in uitvoergrootte",
|
||||
"scaleBeforeProcessing": "Schalen voor verwerking",
|
||||
|
||||
@@ -78,6 +78,10 @@
|
||||
"title": "Popraw twarze",
|
||||
"desc": "Uruchamia proces poprawiania twarzy dla aktywnego obrazu"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Powiększ",
|
||||
"desc": "Uruchamia proces powiększania aktywnego obrazu"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Pokaż informacje",
|
||||
"desc": "Pokazuje metadane zapisane w aktywnym obrazie"
|
||||
@@ -228,6 +232,8 @@
|
||||
"type": "Metoda",
|
||||
"strength": "Siła",
|
||||
"upscaling": "Powiększanie",
|
||||
"upscale": "Powiększ",
|
||||
"upscaleImage": "Powiększ obraz",
|
||||
"scale": "Skala",
|
||||
"imageFit": "Przeskaluj oryginalny obraz",
|
||||
"scaleBeforeProcessing": "Tryb skalowania",
|
||||
|
||||
@@ -160,6 +160,10 @@
|
||||
"title": "Restaurar Rostos",
|
||||
"desc": "Restaurar a imagem atual"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Redimensionar",
|
||||
"desc": "Redimensionar a imagem atual"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Mostrar Informações",
|
||||
"desc": "Mostrar metadados de informações da imagem atual"
|
||||
@@ -271,6 +275,8 @@
|
||||
"showOptionsPanel": "Mostrar Painel de Opções",
|
||||
"strength": "Força",
|
||||
"upscaling": "Redimensionando",
|
||||
"upscale": "Redimensionar",
|
||||
"upscaleImage": "Redimensionar Imagem",
|
||||
"scaleBeforeProcessing": "Escala Antes do Processamento",
|
||||
"images": "Imagems",
|
||||
"steps": "Passos",
|
||||
|
||||
@@ -80,6 +80,10 @@
|
||||
"title": "Restaurar Rostos",
|
||||
"desc": "Restaurar a imagem atual"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Redimensionar",
|
||||
"desc": "Redimensionar a imagem atual"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Mostrar Informações",
|
||||
"desc": "Mostrar metadados de informações da imagem atual"
|
||||
@@ -264,6 +268,8 @@
|
||||
"type": "Tipo",
|
||||
"strength": "Força",
|
||||
"upscaling": "Redimensionando",
|
||||
"upscale": "Redimensionar",
|
||||
"upscaleImage": "Redimensionar Imagem",
|
||||
"scale": "Escala",
|
||||
"imageFit": "Caber Imagem Inicial No Tamanho de Saída",
|
||||
"scaleBeforeProcessing": "Escala Antes do Processamento",
|
||||
|
||||
@@ -214,6 +214,10 @@
|
||||
"title": "Восстановить лица",
|
||||
"desc": "Восстановить лица на текущем изображении"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Увеличение",
|
||||
"desc": "Увеличить текущеее изображение"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Показать метаданные",
|
||||
"desc": "Показать метаданные из текущего изображения"
|
||||
@@ -508,6 +512,8 @@
|
||||
"type": "Тип",
|
||||
"strength": "Сила",
|
||||
"upscaling": "Увеличение",
|
||||
"upscale": "Увеличить",
|
||||
"upscaleImage": "Увеличить изображение",
|
||||
"scale": "Масштаб",
|
||||
"imageFit": "Уместить изображение",
|
||||
"scaleBeforeProcessing": "Масштабировать",
|
||||
|
||||
@@ -90,6 +90,10 @@
|
||||
"title": "Återskapa ansikten",
|
||||
"desc": "Återskapa nuvarande bild"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Skala upp",
|
||||
"desc": "Skala upp nuvarande bild"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Visa info",
|
||||
"desc": "Visa metadata för nuvarande bild"
|
||||
|
||||
@@ -416,6 +416,10 @@
|
||||
"desc": "Maske/Taban katmanları arasında geçiş yapar",
|
||||
"title": "Katmanı Gizle-Göster"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Büyüt",
|
||||
"desc": "Seçili görseli büyüt"
|
||||
},
|
||||
"setSeed": {
|
||||
"title": "Tohumu Kullan",
|
||||
"desc": "Seçili görselin tohumunu kullan"
|
||||
@@ -637,6 +641,7 @@
|
||||
"copyImage": "Görseli Kopyala",
|
||||
"height": "Boy",
|
||||
"width": "En",
|
||||
"upscale": "Büyüt (Shift + U)",
|
||||
"useSize": "Boyutu Kullan",
|
||||
"symmetry": "Bakışım",
|
||||
"tileSize": "Döşeme Boyutu",
|
||||
@@ -652,6 +657,7 @@
|
||||
"showOptionsPanel": "Yan Paneli Göster (O ya da T)",
|
||||
"shuffle": "Kar",
|
||||
"usePrompt": "İstemi Kullan",
|
||||
"upscaleImage": "Görseli Büyüt",
|
||||
"setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (çok küçük olabilir)",
|
||||
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (çok büyük olabilir)",
|
||||
"cfgRescaleMultiplier": "CFG Rescale Çarpanı",
|
||||
|
||||
@@ -85,6 +85,10 @@
|
||||
"title": "Відновити обличчя",
|
||||
"desc": "Відновити обличчя на поточному зображенні"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Збільшення",
|
||||
"desc": "Збільшити поточне зображення"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Показати метадані",
|
||||
"desc": "Показати метадані з поточного зображення"
|
||||
@@ -272,6 +276,8 @@
|
||||
"type": "Тип",
|
||||
"strength": "Сила",
|
||||
"upscaling": "Збільшення",
|
||||
"upscale": "Збільшити",
|
||||
"upscaleImage": "Збільшити зображення",
|
||||
"scale": "Масштаб",
|
||||
"imageFit": "Вмістити зображення",
|
||||
"scaleBeforeProcessing": "Масштабувати",
|
||||
|
||||
@@ -193,6 +193,10 @@
|
||||
"title": "面部修复",
|
||||
"desc": "对当前图像进行面部修复"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "放大",
|
||||
"desc": "对当前图像进行放大"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "显示信息",
|
||||
"desc": "显示当前图像的元数据"
|
||||
@@ -418,6 +422,8 @@
|
||||
"type": "种类",
|
||||
"strength": "强度",
|
||||
"upscaling": "放大",
|
||||
"upscale": "放大 (Shift + U)",
|
||||
"upscaleImage": "放大图像",
|
||||
"scale": "等级",
|
||||
"imageFit": "使生成图像长宽适配初始图像",
|
||||
"scaleBeforeProcessing": "处理前缩放",
|
||||
|
||||
@@ -10,32 +10,32 @@ import {
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
// Type inference doesn't work for this if you inline it in the listener for some reason
|
||||
const matchAnyBoardDeleted = isAnyOf(
|
||||
imagesApi.endpoints.deleteBoard.matchFulfilled,
|
||||
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
|
||||
);
|
||||
|
||||
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
|
||||
/**
|
||||
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
|
||||
* a board, or change a the archived board visibility flag, we may need to reset the auto-add board.
|
||||
*/
|
||||
startAppListening({
|
||||
matcher: matchAnyBoardDeleted,
|
||||
matcher: isAnyOf(
|
||||
// If a board is deleted, we'll need to reset the auto-add board
|
||||
imagesApi.endpoints.deleteBoard.matchFulfilled,
|
||||
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
|
||||
),
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const deletedBoardId = action.meta.arg.originalArgs;
|
||||
const queryArgs = selectListBoardsQueryArgs(state);
|
||||
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
|
||||
const { autoAddBoardId, selectedBoardId } = state.gallery;
|
||||
|
||||
// If the deleted board was currently selected, we should reset the selected board to uncategorized
|
||||
if (deletedBoardId === selectedBoardId) {
|
||||
if (!queryResult.data) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!queryResult.data.find((board) => board.board_id === selectedBoardId)) {
|
||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
|
||||
// If the deleted board was selected for auto-add, we should reset the auto-add board to uncategorized
|
||||
if (deletedBoardId === autoAddBoardId) {
|
||||
if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
|
||||
dispatch(autoAddBoardIdChanged('none'));
|
||||
}
|
||||
},
|
||||
@@ -46,8 +46,14 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
||||
matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const queryArgs = selectListBoardsQueryArgs(state);
|
||||
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
|
||||
const { shouldShowArchivedBoards } = state.gallery;
|
||||
|
||||
if (!queryResult.data) {
|
||||
return;
|
||||
}
|
||||
|
||||
const wasArchived = action.meta.arg.originalArgs.changes.archived === true;
|
||||
|
||||
if (wasArchived && !shouldShowArchivedBoards) {
|
||||
@@ -65,7 +71,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
||||
const shouldShowArchivedBoards = action.payload;
|
||||
|
||||
// We only need to take action if we have just hidden archived boards.
|
||||
if (shouldShowArchivedBoards) {
|
||||
if (!shouldShowArchivedBoards) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -80,16 +86,14 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
||||
|
||||
// Handle the case where selected board is archived
|
||||
const selectedBoard = queryResult.data.find((b) => b.board_id === selectedBoardId);
|
||||
if (!selectedBoard || selectedBoard.archived) {
|
||||
// If we can't find the selected board or it's archived, we should reset the selected board to uncategorized
|
||||
if (selectedBoard && selectedBoard.archived) {
|
||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
|
||||
// Handle the case where auto-add board is archived
|
||||
const autoAddBoard = queryResult.data.find((b) => b.board_id === autoAddBoardId);
|
||||
if (!autoAddBoard || autoAddBoard.archived) {
|
||||
// If we can't find the auto-add board or it's archived, we should reset the selected board to uncategorized
|
||||
if (autoAddBoard && autoAddBoard.archived) {
|
||||
dispatch(autoAddBoardIdChanged('none'));
|
||||
}
|
||||
},
|
||||
|
||||
@@ -10,12 +10,9 @@ import {
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Portal,
|
||||
Spacer,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
import type { ReactElement } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -74,7 +71,7 @@ type ContentProps = {
|
||||
|
||||
const Content = ({ data, feature }: ContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const heading = useMemo<string | undefined>(() => t(`popovers.${feature}.heading`), [feature, t]);
|
||||
|
||||
const paragraphs = useMemo<string[]>(
|
||||
@@ -85,25 +82,16 @@ const Content = ({ data, feature }: ContentProps) => {
|
||||
[feature, t]
|
||||
);
|
||||
|
||||
const onClickLearnMore = useCallback(() => {
|
||||
const handleClick = useCallback(() => {
|
||||
if (!data?.href) {
|
||||
return;
|
||||
}
|
||||
window.open(data.href);
|
||||
}, [data?.href]);
|
||||
|
||||
const onClickDontShowMeThese = useCallback(() => {
|
||||
dispatch(setShouldEnableInformationalPopovers(false));
|
||||
toast({
|
||||
title: t('settings.informationalPopoversDisabled'),
|
||||
description: t('settings.informationalPopoversDisabledDesc'),
|
||||
status: 'info',
|
||||
});
|
||||
}, [dispatch, t]);
|
||||
|
||||
return (
|
||||
<PopoverContent maxW={300}>
|
||||
<PopoverCloseButton top={2} />
|
||||
<PopoverContent w={96}>
|
||||
<PopoverCloseButton />
|
||||
<PopoverBody>
|
||||
<Flex gap={2} flexDirection="column" alignItems="flex-start">
|
||||
{heading && (
|
||||
@@ -128,19 +116,20 @@ const Content = ({ data, feature }: ContentProps) => {
|
||||
{paragraphs.map((p) => (
|
||||
<Text key={p}>{p}</Text>
|
||||
))}
|
||||
|
||||
<Divider />
|
||||
<Flex alignItems="center" justifyContent="space-between" w="full">
|
||||
<Button onClick={onClickDontShowMeThese} variant="link" size="sm">
|
||||
{t('common.dontShowMeThese')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
{data?.href && (
|
||||
<Button onClick={onClickLearnMore} leftIcon={<PiArrowSquareOutBold />} variant="link" size="sm">
|
||||
{data?.href && (
|
||||
<>
|
||||
<Divider />
|
||||
<Button
|
||||
pt={1}
|
||||
onClick={handleClick}
|
||||
leftIcon={<PiArrowSquareOutBold />}
|
||||
alignSelf="flex-end"
|
||||
variant="link"
|
||||
>
|
||||
{t('common.learnMore') ?? heading}
|
||||
</Button>
|
||||
)}
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
|
||||
@@ -53,11 +53,7 @@ export type Feature =
|
||||
| 'refinerCfgScale'
|
||||
| 'scaleBeforeProcessing'
|
||||
| 'seamlessTilingXAxis'
|
||||
| 'seamlessTilingYAxis'
|
||||
| 'upscaleModel'
|
||||
| 'scale'
|
||||
| 'creativity'
|
||||
| 'structure';
|
||||
| 'seamlessTilingYAxis';
|
||||
|
||||
export type PopoverData = PopoverProps & {
|
||||
image?: string;
|
||||
|
||||
@@ -3,8 +3,6 @@ import { ELLIPSIS, useGalleryPagination } from 'features/gallery/hooks/useGaller
|
||||
import { useCallback } from 'react';
|
||||
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||
|
||||
import { JumpTo } from './JumpTo';
|
||||
|
||||
export const GalleryPagination = () => {
|
||||
const { goPrev, goNext, isPrevEnabled, isNextEnabled, pageButtons, goToPage, currentPage, total } =
|
||||
useGalleryPagination();
|
||||
@@ -22,7 +20,7 @@ export const GalleryPagination = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex justifyContent="center" alignItems="center" w="full" gap={1} pt={2}>
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label="prev"
|
||||
@@ -32,9 +30,25 @@ export const GalleryPagination = () => {
|
||||
variant="ghost"
|
||||
/>
|
||||
<Spacer />
|
||||
{pageButtons.map((page, i) => (
|
||||
<PageButton key={`${page}_${i}`} page={page} currentPage={currentPage} goToPage={goToPage} />
|
||||
))}
|
||||
{pageButtons.map((page, i) => {
|
||||
if (page === ELLIPSIS) {
|
||||
return (
|
||||
<Button size="sm" key={`ellipsis_${i}`} variant="link" isDisabled>
|
||||
...
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Button
|
||||
size="sm"
|
||||
key={page}
|
||||
onClick={goToPage.bind(null, page - 1)}
|
||||
variant={currentPage === page - 1 ? 'solid' : 'outline'}
|
||||
>
|
||||
{page}
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
<Spacer />
|
||||
<IconButton
|
||||
size="sm"
|
||||
@@ -44,28 +58,6 @@ export const GalleryPagination = () => {
|
||||
isDisabled={!isNextEnabled}
|
||||
variant="ghost"
|
||||
/>
|
||||
<JumpTo />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type PageButtonProps = {
|
||||
page: number | typeof ELLIPSIS;
|
||||
currentPage: number;
|
||||
goToPage: (page: number) => void;
|
||||
};
|
||||
|
||||
const PageButton = ({ page, currentPage, goToPage }: PageButtonProps) => {
|
||||
if (page === ELLIPSIS) {
|
||||
return (
|
||||
<Button size="sm" variant="link" isDisabled>
|
||||
...
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Button size="sm" onClick={goToPage.bind(null, page - 1)} variant={currentPage === page - 1 ? 'solid' : 'outline'}>
|
||||
{page}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
CompositeNumberInput,
|
||||
Flex,
|
||||
FormControl,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const JumpTo = () => {
|
||||
const { t } = useTranslation();
|
||||
const { goToPage, currentPage, pages } = useGalleryPagination();
|
||||
const [newPage, setNewPage] = useState(currentPage);
|
||||
const { isOpen, onToggle, onClose } = useDisclosure();
|
||||
const ref = useRef<HTMLInputElement>(null);
|
||||
|
||||
const onOpen = useCallback(() => {
|
||||
setNewPage(currentPage);
|
||||
setTimeout(() => {
|
||||
const input = ref.current?.querySelector('input');
|
||||
input?.focus();
|
||||
input?.select();
|
||||
}, 0);
|
||||
}, [currentPage]);
|
||||
|
||||
const onChangeJumpTo = useCallback((v: number) => {
|
||||
setNewPage(v - 1);
|
||||
}, []);
|
||||
|
||||
const onClickGo = useCallback(() => {
|
||||
goToPage(newPage);
|
||||
onClose();
|
||||
}, [newPage, goToPage, onClose]);
|
||||
|
||||
useHotkeys(
|
||||
'enter',
|
||||
() => {
|
||||
onClickGo();
|
||||
},
|
||||
{ enabled: isOpen, enableOnFormTags: ['input'] },
|
||||
[isOpen, onClickGo]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'esc',
|
||||
() => {
|
||||
setNewPage(currentPage);
|
||||
onClose();
|
||||
},
|
||||
{ enabled: isOpen, enableOnFormTags: ['input'] },
|
||||
[isOpen, onClose]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setNewPage(currentPage);
|
||||
}, [currentPage]);
|
||||
|
||||
return (
|
||||
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen}>
|
||||
<PopoverTrigger>
|
||||
<Button aria-label={t('gallery.jump')} size="sm" onClick={onToggle} variant="outline">
|
||||
{t('gallery.jump')}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<FormControl>
|
||||
<CompositeNumberInput
|
||||
ref={ref}
|
||||
size="sm"
|
||||
maxW="60px"
|
||||
value={newPage + 1}
|
||||
min={1}
|
||||
max={pages}
|
||||
step={1}
|
||||
onChange={onChangeJumpTo}
|
||||
/>
|
||||
</FormControl>
|
||||
<Button h="full" size="sm" onClick={onClickGo}>
|
||||
{t('gallery.go')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
@@ -1,7 +1,6 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { offsetChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { throttle } from 'lodash-es';
|
||||
import { useCallback, useEffect, useMemo } from 'react';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
|
||||
@@ -81,41 +80,32 @@ export const useGalleryPagination = () => {
|
||||
return offset > 0;
|
||||
}, [count, offset]);
|
||||
|
||||
const onOffsetChanged = useCallback(
|
||||
(arg: Parameters<typeof offsetChanged>[0]) => {
|
||||
dispatch(offsetChanged(arg));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const throttledOnOffsetChanged = useMemo(() => throttle(onOffsetChanged, 500), [onOffsetChanged]);
|
||||
|
||||
const goNext = useCallback(
|
||||
(withHotkey?: 'arrow' | 'alt+arrow') => {
|
||||
throttledOnOffsetChanged({ offset: offset + (limit || 0), withHotkey });
|
||||
dispatch(offsetChanged({ offset: offset + (limit || 0), withHotkey }));
|
||||
},
|
||||
[throttledOnOffsetChanged, offset, limit]
|
||||
[dispatch, offset, limit]
|
||||
);
|
||||
|
||||
const goPrev = useCallback(
|
||||
(withHotkey?: 'arrow' | 'alt+arrow') => {
|
||||
throttledOnOffsetChanged({ offset: Math.max(offset - (limit || 0), 0), withHotkey });
|
||||
dispatch(offsetChanged({ offset: Math.max(offset - (limit || 0), 0), withHotkey }));
|
||||
},
|
||||
[throttledOnOffsetChanged, offset, limit]
|
||||
[dispatch, offset, limit]
|
||||
);
|
||||
|
||||
const goToPage = useCallback(
|
||||
(page: number) => {
|
||||
throttledOnOffsetChanged({ offset: page * (limit || 0) });
|
||||
dispatch(offsetChanged({ offset: page * (limit || 0) }));
|
||||
},
|
||||
[throttledOnOffsetChanged, limit]
|
||||
[dispatch, limit]
|
||||
);
|
||||
const goToFirst = useCallback(() => {
|
||||
throttledOnOffsetChanged({ offset: 0 });
|
||||
}, [throttledOnOffsetChanged]);
|
||||
dispatch(offsetChanged({ offset: 0 }));
|
||||
}, [dispatch]);
|
||||
const goToLast = useCallback(() => {
|
||||
throttledOnOffsetChanged({ offset: (pages - 1) * (limit || 0) });
|
||||
}, [throttledOnOffsetChanged, pages, limit]);
|
||||
dispatch(offsetChanged({ offset: (pages - 1) * (limit || 0) }));
|
||||
}, [dispatch, pages, limit]);
|
||||
|
||||
// handle when total/pages decrease and user is on high page number (ie bulk removing or deleting)
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
||||
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | null) => {
|
||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(
|
||||
modelKey ?? skipToken,
|
||||
isControlNetOrT2IAdapterModelConfig
|
||||
);
|
||||
|
||||
export const useControlNetOrT2IAdapterDefaultSettings = (
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
|
||||
) => {
|
||||
const defaultSettingsDefaults = useMemo(() => {
|
||||
return {
|
||||
preprocessor: {
|
||||
@@ -14,5 +19,5 @@ export const useControlNetOrT2IAdapterDefaultSettings = (
|
||||
};
|
||||
}, [modelConfig?.default_settings]);
|
||||
|
||||
return defaultSettingsDefaults;
|
||||
return { defaultSettingsDefaults, isLoading };
|
||||
};
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
|
||||
@@ -19,7 +22,9 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
||||
};
|
||||
});
|
||||
|
||||
export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
|
||||
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
|
||||
|
||||
const {
|
||||
initialSteps,
|
||||
initialCfg,
|
||||
@@ -76,5 +81,5 @@ export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
|
||||
initialHeight,
|
||||
]);
|
||||
|
||||
return defaultSettingsDefaults;
|
||||
return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||
@@ -50,8 +50,6 @@ export const modelManagerV2Slice = createSlice({
|
||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
|
||||
modelManagerV2Slice.actions;
|
||||
|
||||
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateModelManagerState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { HuggingFaceResults } from './HuggingFaceResults';
|
||||
|
||||
export const HuggingFaceForm = memo(() => {
|
||||
export const HuggingFaceForm = () => {
|
||||
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
|
||||
const [displayResults, setDisplayResults] = useState(false);
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
@@ -66,6 +66,4 @@ export const HuggingFaceForm = memo(() => {
|
||||
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
HuggingFaceForm.displayName = 'HuggingFaceForm';
|
||||
};
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
result: string;
|
||||
};
|
||||
export const HuggingFaceResultItem = memo(({ result }: Props) => {
|
||||
export const HuggingFaceResultItem = ({ result }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [installModel] = useInstallModel();
|
||||
@@ -27,6 +27,4 @@ export const HuggingFaceResultItem = memo(({ result }: Props) => {
|
||||
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
HuggingFaceResultItem.displayName = 'HuggingFaceResultItem';
|
||||
};
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
@@ -21,7 +21,7 @@ type HuggingFaceResultsProps = {
|
||||
results: string[];
|
||||
};
|
||||
|
||||
export const HuggingFaceResults = memo(({ results }: HuggingFaceResultsProps) => {
|
||||
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
|
||||
@@ -93,6 +93,4 @@ export const HuggingFaceResults = memo(({ results }: HuggingFaceResultsProps) =>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
HuggingFaceResults.displayName = 'HuggingFaceResults';
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
|
||||
@@ -10,7 +10,7 @@ type SimpleImportModelConfig = {
|
||||
inplace: boolean;
|
||||
};
|
||||
|
||||
export const InstallModelForm = memo(() => {
|
||||
export const InstallModelForm = () => {
|
||||
const [installModel, { isLoading }] = useInstallModel();
|
||||
|
||||
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
|
||||
@@ -74,6 +74,4 @@ export const InstallModelForm = memo(() => {
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
});
|
||||
|
||||
InstallModelForm.displayName = 'InstallModelForm';
|
||||
};
|
||||
|
||||
@@ -2,12 +2,12 @@ import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
|
||||
|
||||
export const ModelInstallQueue = memo(() => {
|
||||
export const ModelInstallQueue = () => {
|
||||
const { data } = useListModelInstallsQuery();
|
||||
|
||||
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
|
||||
@@ -61,6 +61,4 @@ export const ModelInstallQueue = memo(() => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelInstallQueue.displayName = 'ModelInstallQueue';
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
||||
import type { ModelInstallJob } from 'services/api/types';
|
||||
@@ -25,7 +25,7 @@ const formatBytes = (bytes: number) => {
|
||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||
};
|
||||
|
||||
export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
||||
const { installJob } = props;
|
||||
|
||||
const [deleteImportModel] = useCancelModelInstallMutation();
|
||||
@@ -124,9 +124,7 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelInstallQueueItem.displayName = 'ModelInstallQueueItem';
|
||||
};
|
||||
|
||||
type TooltipLabelProps = {
|
||||
installJob: ModelInstallJob;
|
||||
@@ -134,7 +132,7 @@ type TooltipLabelProps = {
|
||||
source: string;
|
||||
};
|
||||
|
||||
const TooltipLabel = memo(({ name, source, installJob }: TooltipLabelProps) => {
|
||||
const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
|
||||
const progressString = useMemo(() => {
|
||||
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
||||
return '';
|
||||
@@ -158,6 +156,4 @@ const TooltipLabel = memo(({ name, source, installJob }: TooltipLabelProps) => {
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipLabel.displayName = 'TooltipLabel';
|
||||
};
|
||||
|
||||
@@ -2,13 +2,13 @@ import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel,
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setScanPath } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { ScanModelsResults } from './ScanFolderResults';
|
||||
|
||||
export const ScanModelsForm = memo(() => {
|
||||
export const ScanModelsForm = () => {
|
||||
const scanPath = useAppSelector((state) => state.modelmanagerV2.scanPath);
|
||||
const dispatch = useAppDispatch();
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
@@ -56,6 +56,4 @@ export const ScanModelsForm = memo(() => {
|
||||
{data && <ScanModelsResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ScanModelsForm.displayName = 'ScanModelsForm';
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||
@@ -8,7 +8,7 @@ type Props = {
|
||||
result: ScanFolderResponse[number];
|
||||
installModel: (source: string) => void;
|
||||
};
|
||||
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
|
||||
export const ScanModelResultItem = ({ result, installModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleInstall = useCallback(() => {
|
||||
@@ -30,6 +30,4 @@ export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ScanModelResultItem.displayName = 'ScanModelResultItem';
|
||||
};
|
||||
|
||||
@@ -14,7 +14,7 @@ import {
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEvent, ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||
@@ -25,7 +25,7 @@ type ScanModelResultsProps = {
|
||||
results: ScanFolderResponse;
|
||||
};
|
||||
|
||||
export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
|
||||
export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [inplace, setInplace] = useState(true);
|
||||
@@ -116,6 +116,4 @@ export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
ScanModelsResults.displayName = 'ScanModelsResults';
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
@@ -9,7 +9,7 @@ import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
type Props = {
|
||||
result: GetStarterModelsResponse[number];
|
||||
};
|
||||
export const StarterModelsResultItem = memo(({ result }: Props) => {
|
||||
export const StarterModelsResultItem = ({ result }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const allSources = useMemo(() => {
|
||||
const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }];
|
||||
@@ -47,6 +47,4 @@ export const StarterModelsResultItem = memo(({ result }: Props) => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
StarterModelsResultItem.displayName = 'StarterModelsResultItem';
|
||||
};
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
|
||||
import { memo } from 'react';
|
||||
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { StarterModelsResults } from './StarterModelsResults';
|
||||
|
||||
export const StarterModelsForm = memo(() => {
|
||||
export const StarterModelsForm = () => {
|
||||
const { isLoading, data } = useGetStarterModelsQuery();
|
||||
|
||||
return (
|
||||
@@ -14,6 +13,4 @@ export const StarterModelsForm = memo(() => {
|
||||
{data && <StarterModelsResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
StarterModelsForm.displayName = 'StarterModelsForm';
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
@@ -12,7 +12,7 @@ type StarterModelsResultsProps = {
|
||||
results: NonNullable<GetStarterModelsResponse>;
|
||||
};
|
||||
|
||||
export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => {
|
||||
export const StarterModelsResults = ({ results }: StarterModelsResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
|
||||
@@ -79,6 +79,4 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
StarterModelsResults.displayName = 'StarterModelsResults';
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@in
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
|
||||
import { atom } from 'nanostores';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||
@@ -12,7 +12,7 @@ import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
|
||||
export const $installModelsTab = atom(0);
|
||||
|
||||
export const InstallModels = memo(() => {
|
||||
export const InstallModels = () => {
|
||||
const { t } = useTranslation();
|
||||
const index = useStore($installModelsTab);
|
||||
const onChange = useCallback((index: number) => {
|
||||
@@ -49,6 +49,4 @@ export const InstallModels = memo(() => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
InstallModels.displayName = 'InstallModels';
|
||||
};
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
|
||||
|
||||
export const ModelManager = memo(() => {
|
||||
export const ModelManager = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const handleClickAddModel = useCallback(() => {
|
||||
@@ -29,6 +29,4 @@ export const ModelManager = memo(() => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelManager.displayName = 'ModelManager';
|
||||
};
|
||||
|
||||
@@ -21,8 +21,7 @@ import { FetchingModelsLoader } from './FetchingModelsLoader';
|
||||
import { ModelListWrapper } from './ModelListWrapper';
|
||||
|
||||
const ModelList = () => {
|
||||
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectModelManagerV2Slice, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -24,21 +23,15 @@ const sx: SystemStyleObject = {
|
||||
"&[aria-selected='true']": { bg: 'base.700' },
|
||||
};
|
||||
|
||||
const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||
const ModelListItem = (props: ModelListItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selectIsSelected = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectModelManagerV2Slice,
|
||||
(modelManagerV2Slice) => modelManagerV2Slice.selectedModelKey === model.key
|
||||
),
|
||||
[model.key]
|
||||
);
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const [deleteModel] = useDeleteModelsMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const { model } = props;
|
||||
|
||||
const handleSelectModel = useCallback(() => {
|
||||
dispatch(setSelectedModelKey(model.key));
|
||||
}, [model.key, dispatch]);
|
||||
@@ -50,6 +43,11 @@ const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||
},
|
||||
[onOpen]
|
||||
);
|
||||
|
||||
const isSelected = useMemo(() => {
|
||||
return selectedModelKey === model.key;
|
||||
}, [selectedModelKey, model.key]);
|
||||
|
||||
const handleModelDelete = useCallback(() => {
|
||||
deleteModel({ key: model.key })
|
||||
.unwrap()
|
||||
|
||||
@@ -3,12 +3,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { t } from 'i18next';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import { ModelTypeFilter } from './ModelTypeFilter';
|
||||
|
||||
export const ModelListNavigation = memo(() => {
|
||||
export const ModelListNavigation = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||
|
||||
@@ -49,6 +49,4 @@ export const ModelListNavigation = memo(() => {
|
||||
</InputGroup>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelListNavigation.displayName = 'ModelListNavigation';
|
||||
};
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { StickyScrollable } from 'features/system/components/StickyScrollable';
|
||||
import { memo } from 'react';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelListItem from './ModelListItem';
|
||||
@@ -9,7 +8,7 @@ type ModelListWrapperProps = {
|
||||
modelList: AnyModelConfig[];
|
||||
};
|
||||
|
||||
export const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
||||
export const ModelListWrapper = (props: ModelListWrapperProps) => {
|
||||
const { title, modelList } = props;
|
||||
return (
|
||||
<StickyScrollable title={title} contentSx={{ gap: 1, p: 2 }}>
|
||||
@@ -18,6 +17,4 @@ export const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
||||
))}
|
||||
</StickyScrollable>
|
||||
);
|
||||
});
|
||||
|
||||
ModelListWrapper.displayName = 'ModelListWrapper';
|
||||
};
|
||||
|
||||
@@ -2,12 +2,12 @@ import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-libr
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFunnelBold } from 'react-icons/pi';
|
||||
import { objectKeys } from 'tsafe';
|
||||
|
||||
export const ModelTypeFilter = memo(() => {
|
||||
export const ModelTypeFilter = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
||||
@@ -57,6 +57,4 @@ export const ModelTypeFilter = memo(() => {
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
});
|
||||
|
||||
ModelTypeFilter.displayName = 'ModelTypeFilter';
|
||||
};
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { InstallModels } from './InstallModels';
|
||||
import { Model } from './ModelPanel/Model';
|
||||
|
||||
export const ModelPane = memo(() => {
|
||||
export const ModelPane = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={4} borderRadius="base" w="50%" h="full">
|
||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
ModelPane.displayName = 'ModelPane';
|
||||
};
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
|
||||
preprocessor: FormField<string>;
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig;
|
||||
};
|
||||
|
||||
export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const defaultSettingsDefaults = useControlNetOrT2IAdapterDefaultSettings(modelConfig);
|
||||
const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } =
|
||||
useControlNetOrT2IAdapterDefaultSettings(selectedModelKey);
|
||||
|
||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||
|
||||
@@ -32,12 +30,16 @@ export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Prop
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
|
||||
(data) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const body = {
|
||||
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
|
||||
};
|
||||
|
||||
updateModel({
|
||||
key: modelConfig.key,
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
.unwrap()
|
||||
@@ -59,9 +61,13 @@ export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Prop
|
||||
}
|
||||
});
|
||||
},
|
||||
[updateModel, modelConfig.key, t, reset]
|
||||
[selectedModelKey, reset, updateModel, t]
|
||||
);
|
||||
|
||||
if (isLoadingDefaultSettings) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
||||
@@ -83,6 +89,4 @@ export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Prop
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
ControlNetOrT2IAdapterDefaultSettings.displayName = 'ControlNetOrT2IAdapterDefaultSettings';
|
||||
};
|
||||
|
||||
@@ -4,7 +4,7 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
|
||||
import type { ControlNetOrT2IAdapterDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -28,7 +28,7 @@ const OPTIONS = [
|
||||
|
||||
type DefaultSchedulerType = ControlNetOrT2IAdapterDefaultSettingsFormData['preprocessor'];
|
||||
|
||||
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) => {
|
||||
export function DefaultPreprocessor(props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
@@ -63,6 +63,4 @@ export const DefaultPreprocessor = memo((props: UseControllerProps<ControlNetOrT
|
||||
<Combobox isDisabled={isDisabled} value={value} options={OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultPreprocessor.displayName = 'DefaultPreprocessor';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultCfgRescaleMultiplierType = MainModelDefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||
|
||||
export const DefaultCfgRescaleMultiplier = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||
@@ -74,6 +74,4 @@ export const DefaultCfgRescaleMultiplier = memo((props: UseControllerProps<MainM
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultCfgRescaleMultiplier.displayName = 'DefaultCfgRescaleMultiplier';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultCfgType = MainModelDefaultSettingsFormData['cfgScale'];
|
||||
|
||||
export const DefaultCfgScale = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultCfgScale(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
||||
@@ -74,6 +74,4 @@ export const DefaultCfgScale = memo((props: UseControllerProps<MainModelDefaultS
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultCfgScale.displayName = 'DefaultCfgScale';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -16,7 +16,7 @@ type Props = {
|
||||
optimalDimension: number;
|
||||
};
|
||||
|
||||
export const DefaultHeight = memo(({ control, optimalDimension }: Props) => {
|
||||
export function DefaultHeight({ control, optimalDimension }: Props) {
|
||||
const { field } = useController({ control, name: 'height' });
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.height.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.height.sliderMax);
|
||||
@@ -78,6 +78,4 @@ export const DefaultHeight = memo(({ control, optimalDimension }: Props) => {
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultHeight.displayName = 'DefaultHeight';
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -13,7 +13,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultSchedulerType = MainModelDefaultSettingsFormData['scheduler'];
|
||||
|
||||
export const DefaultScheduler = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultScheduler(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
@@ -51,6 +51,4 @@ export const DefaultScheduler = memo((props: UseControllerProps<MainModelDefault
|
||||
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultScheduler.displayName = 'DefaultScheduler';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultSteps = MainModelDefaultSettingsFormData['steps'];
|
||||
|
||||
export const DefaultSteps = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultSteps(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
||||
@@ -74,6 +74,4 @@ export const DefaultSteps = memo((props: UseControllerProps<MainModelDefaultSett
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultSteps.displayName = 'DefaultSteps';
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -15,7 +15,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultVaeType = MainModelDefaultSettingsFormData['vae'];
|
||||
|
||||
export const DefaultVae = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
@@ -64,6 +64,4 @@ export const DefaultVae = memo((props: UseControllerProps<MainModelDefaultSettin
|
||||
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultVae.displayName = 'DefaultVae';
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -17,7 +17,7 @@ const options = [
|
||||
|
||||
type DefaultVaePrecisionType = MainModelDefaultSettingsFormData['vaePrecision'];
|
||||
|
||||
export const DefaultVaePrecision = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultVaePrecision(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
@@ -52,6 +52,4 @@ export const DefaultVaePrecision = memo((props: UseControllerProps<MainModelDefa
|
||||
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultVaePrecision.displayName = 'DefaultVaePrecision';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -16,7 +16,7 @@ type Props = {
|
||||
optimalDimension: number;
|
||||
};
|
||||
|
||||
export const DefaultWidth = memo(({ control, optimalDimension }: Props) => {
|
||||
export function DefaultWidth({ control, optimalDimension }: Props) {
|
||||
const { field } = useController({ control, name: 'width' });
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.width.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.width.sliderMax);
|
||||
@@ -78,6 +78,4 @@ export const DefaultWidth = memo(({ control, optimalDimension }: Props) => {
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultWidth.displayName = 'DefaultWidth';
|
||||
}
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
||||
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
||||
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
|
||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
@@ -37,16 +35,16 @@ export type MainModelDefaultSettingsFormData = {
|
||||
height: FormField<number>;
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: MainModelConfig;
|
||||
};
|
||||
|
||||
export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
export const MainModelDefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const defaultSettingsDefaults = useMainModelDefaultSettings(modelConfig);
|
||||
const optimalDimension = useMemo(() => getOptimalDimension(modelConfig), [modelConfig]);
|
||||
const {
|
||||
defaultSettingsDefaults,
|
||||
isLoading: isLoadingDefaultSettings,
|
||||
optimalDimension,
|
||||
} = useMainModelDefaultSettings(selectedModelKey);
|
||||
|
||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||
|
||||
const { handleSubmit, control, formState, reset } = useForm<MainModelDefaultSettingsFormData>({
|
||||
@@ -96,6 +94,10 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
[selectedModelKey, reset, updateModel, t]
|
||||
);
|
||||
|
||||
if (isLoadingDefaultSettings) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
||||
@@ -124,6 +126,4 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
MainModelDefaultSettings.displayName = 'MainModelDefaultSettings';
|
||||
};
|
||||
|
||||
@@ -1,47 +1,120 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { Button, Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiExclamationMarkBold } from 'react-icons/pi';
|
||||
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import ModelImageUpload from './Fields/ModelImageUpload';
|
||||
import { ModelEdit } from './ModelEdit';
|
||||
import { ModelView } from './ModelView';
|
||||
|
||||
export const Model = memo(() => {
|
||||
export const Model = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelConfigs, isLoading } = useGetModelConfigsQuery();
|
||||
const modelConfig = useMemo(() => {
|
||||
if (!modelConfigs) {
|
||||
return null;
|
||||
}
|
||||
if (selectedModelKey === null) {
|
||||
return null;
|
||||
}
|
||||
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
if (!modelConfig) {
|
||||
return null;
|
||||
}
|
||||
const form = useForm<UpdateModelArg['body']>({
|
||||
defaultValues: data,
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
return modelConfig;
|
||||
}, [modelConfigs, selectedModelKey]);
|
||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||
(values) => {
|
||||
if (!data?.key) {
|
||||
return;
|
||||
}
|
||||
|
||||
const responseBody: UpdateModelArg = {
|
||||
key: data.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
form.reset(payload, { keepDefaultValues: true });
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
toast({
|
||||
id: 'MODEL_UPDATED',
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
});
|
||||
})
|
||||
.catch((_) => {
|
||||
form.reset();
|
||||
toast({
|
||||
id: 'MODEL_UPDATE_FAILED',
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
},
|
||||
[dispatch, data?.key, form, t, updateModel]
|
||||
);
|
||||
|
||||
const handleClickCancel = useCallback(() => {
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
}, [dispatch]);
|
||||
|
||||
if (isLoading) {
|
||||
return <IAINoContentFallbackWithSpinner label={t('common.loading')} />;
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
if (!modelConfig) {
|
||||
return <IAINoContentFallback label={t('common.somethingWentWrong')} icon={PiExclamationMarkBold} />;
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
|
||||
if (selectedModelMode === 'view') {
|
||||
return <ModelView modelConfig={modelConfig} />;
|
||||
}
|
||||
|
||||
return <ModelEdit modelConfig={modelConfig} />;
|
||||
});
|
||||
|
||||
Model.displayName = 'Model';
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Flex alignItems="flex-start" gap={4}>
|
||||
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
|
||||
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
|
||||
<Flex gap={2}>
|
||||
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
|
||||
{data.name}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
{selectedModelMode === 'view' && <ModelConvertButton modelKey={selectedModelKey} />}
|
||||
{selectedModelMode === 'view' && <ModelEditButton />}
|
||||
{selectedModelMode === 'edit' && (
|
||||
<Button size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
)}
|
||||
{selectedModelMode === 'edit' && (
|
||||
<Button
|
||||
size="sm"
|
||||
colorScheme="invokeYellow"
|
||||
leftIcon={<PiCheckBold />}
|
||||
onClick={form.handleSubmit(onSubmit)}
|
||||
isLoading={isSubmitting}
|
||||
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
)}
|
||||
</Flex>
|
||||
{data.source && (
|
||||
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
||||
{t('modelManager.source')}: {data?.source}
|
||||
</Text>
|
||||
)}
|
||||
<Text noOfLines={3}>{data.description}</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit form={form} onSubmit={onSubmit} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import { FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
|
||||
interface Props {
|
||||
label: string;
|
||||
value: string | null | undefined;
|
||||
}
|
||||
|
||||
export const ModelAttrView = memo(({ label, value }: Props) => {
|
||||
export const ModelAttrView = ({ label, value }: Props) => {
|
||||
return (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={0}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
@@ -15,6 +14,4 @@ export const ModelAttrView = memo(({ label, value }: Props) => {
|
||||
</Text>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
ModelAttrView.displayName = 'ModelAttrView';
|
||||
};
|
||||
|
||||
@@ -8,46 +8,52 @@ import {
|
||||
UnorderedList,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useConvertModelMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
import { useConvertModelMutation, useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
interface ModelConvertProps {
|
||||
modelConfig: CheckpointModelConfig;
|
||||
modelKey: string | null;
|
||||
}
|
||||
|
||||
export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
|
||||
export const ModelConvertButton = (props: ModelConvertProps) => {
|
||||
const { modelKey } = props;
|
||||
const { t } = useTranslation();
|
||||
const { data } = useGetModelConfigQuery(modelKey ?? skipToken);
|
||||
const [convertModel, { isLoading }] = useConvertModelMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const modelConvertHandler = useCallback(() => {
|
||||
if (!modelConfig || isLoading) {
|
||||
if (!data || isLoading) {
|
||||
return;
|
||||
}
|
||||
|
||||
const toastId = `CONVERTING_MODEL_${modelConfig.key}`;
|
||||
const toastId = `CONVERTING_MODEL_${data.key}`;
|
||||
toast({
|
||||
id: toastId,
|
||||
title: `${t('modelManager.convertingModelBegin')}: ${modelConfig.name}`,
|
||||
title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`,
|
||||
status: 'info',
|
||||
});
|
||||
|
||||
convertModel(modelConfig.key)
|
||||
convertModel(data?.key)
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${modelConfig.name}`, status: 'success' });
|
||||
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${data?.name}`, status: 'success' });
|
||||
})
|
||||
.catch(() => {
|
||||
toast({
|
||||
id: toastId,
|
||||
title: `${t('modelManager.modelConversionFailed')}: ${modelConfig.name}`,
|
||||
title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`,
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
}, [modelConfig, isLoading, t, convertModel]);
|
||||
}, [data, isLoading, t, convertModel]);
|
||||
|
||||
if (data?.format !== 'checkpoint') {
|
||||
return;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -62,7 +68,7 @@ export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
|
||||
🧨 {t('modelManager.convert')}
|
||||
</Button>
|
||||
<ConfirmationAlertDialog
|
||||
title={`${t('modelManager.convert')} ${modelConfig.name}`}
|
||||
title={`${t('modelManager.convert')} ${data?.name}`}
|
||||
acceptCallback={modelConvertHandler}
|
||||
acceptButtonText={`${t('modelManager.convert')}`}
|
||||
isOpen={isOpen}
|
||||
@@ -90,6 +96,4 @@ export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
|
||||
</ConfirmationAlertDialog>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
ModelConvertButton.displayName = 'ModelConvertButton';
|
||||
};
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import {
|
||||
Button,
|
||||
Checkbox,
|
||||
Flex,
|
||||
FormControl,
|
||||
@@ -8,154 +7,96 @@ import {
|
||||
Heading,
|
||||
Input,
|
||||
SimpleGrid,
|
||||
Text,
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { type SubmitHandler, useForm } from 'react-hook-form';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { SubmitHandler, UseFormReturn } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||
import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
form: UseFormReturn<UpdateModelArg['body']>;
|
||||
onSubmit: SubmitHandler<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const stringFieldOptions = {
|
||||
validate: (value?: string | null) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||
};
|
||||
|
||||
export const ModelEdit = memo(({ modelConfig }: Props) => {
|
||||
export const ModelEdit = ({ form }: Props) => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const { t } = useTranslation();
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const form = useForm<UpdateModelArg['body']>({
|
||||
defaultValues: modelConfig,
|
||||
mode: 'onChange',
|
||||
});
|
||||
if (isLoading) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||
(values) => {
|
||||
const responseBody: UpdateModelArg = {
|
||||
key: modelConfig.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
form.reset(payload, { keepDefaultValues: true });
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
toast({
|
||||
id: 'MODEL_UPDATED',
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
});
|
||||
})
|
||||
.catch((_) => {
|
||||
form.reset();
|
||||
toast({
|
||||
id: 'MODEL_UPDATE_FAILED',
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
},
|
||||
[dispatch, modelConfig.key, form, t, updateModel]
|
||||
);
|
||||
|
||||
const handleClickCancel = useCallback(() => {
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
}, [dispatch]);
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<ModelHeader modelConfig={modelConfig}>
|
||||
<Button flexShrink={0} size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
flexShrink={0}
|
||||
size="sm"
|
||||
colorScheme="invokeYellow"
|
||||
leftIcon={<PiCheckBold />}
|
||||
onClick={form.handleSubmit(onSubmit)}
|
||||
isLoading={isSubmitting}
|
||||
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</ModelHeader>
|
||||
<Flex flexDir="column" h="full">
|
||||
<form>
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<FormControl
|
||||
flexDir="column"
|
||||
alignItems="flex-start"
|
||||
gap={1}
|
||||
isInvalid={Boolean(form.formState.errors.name)}
|
||||
>
|
||||
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
||||
<Input {...form.register('name', stringFieldOptions)} size="md" />
|
||||
<Flex flexDir="column" h="full">
|
||||
<form>
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(form.formState.errors.name)}>
|
||||
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
||||
<Input {...form.register('name', stringFieldOptions)} size="md" />
|
||||
|
||||
{form.formState.errors.name?.message && (
|
||||
<FormErrorMessage>{form.formState.errors.name?.message}</FormErrorMessage>
|
||||
)}
|
||||
{form.formState.errors.name?.message && (
|
||||
<FormErrorMessage>{form.formState.errors.name?.message}</FormErrorMessage>
|
||||
)}
|
||||
</FormControl>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={3} mt="4">
|
||||
<Flex gap="4" alignItems="center">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea {...form.register('description')} minH={32} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={3} mt="4">
|
||||
<Flex gap="4" alignItems="center">
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
</FormControl>
|
||||
{data.type === 'main' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea {...form.register('description')} minH={32} />
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
</FormControl>
|
||||
{modelConfig.type === 'main' && (
|
||||
)}
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||
</FormControl>
|
||||
)}
|
||||
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
|
||||
<>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={form.control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...form.register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={form.control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...form.register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelEdit.displayName = 'ModelEdit';
|
||||
};
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
|
||||
export const ModelEditButton = memo(() => {
|
||||
export const ModelEditButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@@ -18,6 +18,4 @@ export const ModelEditButton = memo(() => {
|
||||
{t('modelManager.edit')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
|
||||
ModelEditButton.displayName = 'ModelEditButton';
|
||||
};
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import { Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import ModelImageUpload from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
modelConfig: AnyModelConfig;
|
||||
}>;
|
||||
|
||||
export const ModelHeader = memo(({ modelConfig, children }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex alignItems="flex-start" gap={4}>
|
||||
<ModelImageUpload model_key={modelConfig.key} model_image={modelConfig.cover_image} />
|
||||
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
|
||||
<Flex gap={2}>
|
||||
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
|
||||
{modelConfig.name}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
{children}
|
||||
</Flex>
|
||||
{modelConfig.source && (
|
||||
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
||||
{t('modelManager.source')}: {modelConfig.source}
|
||||
</Text>
|
||||
)}
|
||||
<Text noOfLines={3}>{modelConfig.description}</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelHeader.displayName = 'ModelHeader';
|
||||
@@ -1,67 +1,55 @@
|
||||
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
||||
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||
import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
};
|
||||
|
||||
export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
export const ModelView = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
if (isLoading) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<ModelHeader modelConfig={modelConfig}>
|
||||
{modelConfig.format === 'checkpoint' && modelConfig.type === 'main' && (
|
||||
<ModelConvertButton modelConfig={modelConfig} />
|
||||
)}
|
||||
<ModelEditButton />
|
||||
</ModelHeader>
|
||||
<Flex flexDir="column" h="full" gap={4}>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelConfig.type} />
|
||||
<ModelAttrView label={t('common.format')} value={modelConfig.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelConfig.path} />
|
||||
{modelConfig.type === 'main' && (
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelConfig.variant} />
|
||||
)}
|
||||
{modelConfig.type === 'main' && modelConfig.format === 'diffusers' && modelConfig.repo_variant && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelConfig.repo_variant} />
|
||||
)}
|
||||
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
|
||||
<>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelConfig.config_path} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelConfig.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelConfig.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
{modelConfig.type === 'ip_adapter' && modelConfig.format === 'invokeai' && (
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelConfig.image_encoder_model_id} />
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
|
||||
<MainModelDefaultSettings modelConfig={modelConfig} />
|
||||
<Flex flexDir="column" h="full" gap={4}>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
||||
{data.type === 'main' && <ModelAttrView label={t('modelManager.variant')} value={data.variant} />}
|
||||
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
||||
)}
|
||||
{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
|
||||
<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
{(modelConfig.type === 'main' || modelConfig.type === 'lora') && <TriggerPhrases modelConfig={modelConfig} />}
|
||||
</Box>
|
||||
</Flex>
|
||||
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && <MainModelDefaultSettings />}
|
||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && <ControlNetOrT2IAdapterDefaultSettings />}
|
||||
{(data.type === 'main' || data.type === 'lora') && <TriggerPhrases />}
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelView.displayName = 'ModelView';
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Switch, typedMemo } from '@invoke-ai/ui-library';
|
||||
import { Switch } from '@invoke-ai/ui-library';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
@@ -6,7 +6,7 @@ import { useController } from 'react-hook-form';
|
||||
|
||||
import type { FormField } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
|
||||
export const SettingToggle = typedMemo(<T, F extends Record<string, FormField<T>>>(props: UseControllerProps<F>) => {
|
||||
export function SettingToggle<T, F extends Record<string, FormField<T>>>(props: UseControllerProps<F>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const value = useMemo(() => {
|
||||
@@ -25,6 +25,4 @@ export const SettingToggle = typedMemo(<T, F extends Record<string, FormField<T>
|
||||
);
|
||||
|
||||
return <Switch size="sm" isChecked={value} onChange={onChange} />;
|
||||
});
|
||||
|
||||
SettingToggle.displayName = 'SettingToggle';
|
||||
}
|
||||
|
||||
@@ -9,19 +9,19 @@ import {
|
||||
TagCloseButton,
|
||||
TagLabel,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { LoRAModelConfig, MainModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import { isLoRAModelConfig, isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
modelConfig: MainModelConfig | LoRAModelConfig;
|
||||
};
|
||||
|
||||
export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
export const TriggerPhrases = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { currentData: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const [phrase, setPhrase] = useState('');
|
||||
|
||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||
@@ -31,6 +31,9 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
}, []);
|
||||
|
||||
const triggerPhrases = useMemo(() => {
|
||||
if (!modelConfig || (!isNonRefinerMainModelConfig(modelConfig) && !isLoRAModelConfig(modelConfig))) {
|
||||
return [];
|
||||
}
|
||||
return modelConfig?.trigger_phrases || [];
|
||||
}, [modelConfig]);
|
||||
|
||||
@@ -45,6 +48,10 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
}, [phrase, triggerPhrases]);
|
||||
|
||||
const addTriggerPhrase = useCallback(async () => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
||||
return;
|
||||
}
|
||||
@@ -52,18 +59,22 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
setPhrase('');
|
||||
|
||||
await updateModel({
|
||||
key: modelConfig.key,
|
||||
key: selectedModelKey,
|
||||
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
||||
}).unwrap();
|
||||
}, [phrase, triggerPhrases, updateModel, modelConfig.key]);
|
||||
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);
|
||||
|
||||
const removeTriggerPhrase = useCallback(
|
||||
async (phraseToRemove: string) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
||||
|
||||
await updateModel({ key: modelConfig.key, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||
await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||
},
|
||||
[triggerPhrases, updateModel, modelConfig]
|
||||
[updateModel, selectedModelKey, triggerPhrases]
|
||||
);
|
||||
|
||||
const onTriggerPhraseAddFormSubmit = useCallback(
|
||||
@@ -92,9 +103,7 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
{t('common.add')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{errors.map((error) => (
|
||||
<FormErrorMessage key={error}>{error}</FormErrorMessage>
|
||||
))}
|
||||
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</form>
|
||||
@@ -109,6 +118,4 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TriggerPhrases.displayName = 'TriggerPhrases';
|
||||
};
|
||||
|
||||
@@ -59,19 +59,17 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
||||
for (const edge of copiedEdges) {
|
||||
if (edge.source === node.id) {
|
||||
edge.source = id;
|
||||
} else if (edge.target === node.id) {
|
||||
edge.id = edge.id.replace(node.data.id, id);
|
||||
}
|
||||
if (edge.target === node.id) {
|
||||
edge.target = id;
|
||||
edge.id = edge.id.replace(node.data.id, id);
|
||||
}
|
||||
}
|
||||
node.id = id;
|
||||
node.data.id = id;
|
||||
});
|
||||
|
||||
copiedEdges.forEach((edge) => {
|
||||
// Copied edges need a fresh id too
|
||||
edge.id = uuidv4();
|
||||
});
|
||||
|
||||
const nodeChanges: NodeChange[] = [];
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
// Deselect existing nodes
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { creativityChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -26,9 +25,7 @@ const ParamCreativity = () => {
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="creativity">
|
||||
<FormLabel>{t('upscaling.creativity')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<FormLabel>{t('upscaling.creativity')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={creativity}
|
||||
defaultValue={initial}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
||||
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -38,9 +37,7 @@ const ParamSpandrelModel = () => {
|
||||
|
||||
return (
|
||||
<FormControl orientation="vertical">
|
||||
<InformationalPopover feature="upscaleModel">
|
||||
<FormLabel>{t('upscaling.upscaleModel')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<FormLabel>{t('upscaling.upscaleModel')}</FormLabel>
|
||||
<Tooltip label={tooltipLabel}>
|
||||
<Box w="full">
|
||||
<Combobox
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { structureChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -26,9 +25,7 @@ const ParamStructure = () => {
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="structure">
|
||||
<FormLabel>{t('upscaling.structure')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<FormLabel>{t('upscaling.structure')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={structure}
|
||||
defaultValue={initial}
|
||||
|
||||
@@ -64,7 +64,7 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
const badges = useAppSelector(selectBadges);
|
||||
const { t } = useTranslation();
|
||||
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
||||
id: `'advanced-settings-${activeTabName}`,
|
||||
id: 'advanced-settings',
|
||||
defaultIsOpen: false,
|
||||
});
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param
|
||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { filter } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -27,7 +26,6 @@ const formLabelProps: FormLabelProps = {
|
||||
export const GenerationSettingsAccordion = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const modelConfig = useSelectedModelConfig();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectLoraSlice, (lora) => {
|
||||
@@ -44,8 +42,8 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
defaultIsOpen: false,
|
||||
});
|
||||
const { isOpen: isOpenAccordion, onToggle: onToggleAccordion } = useStandaloneAccordionToggle({
|
||||
id: `generation-settings-${activeTabName}`,
|
||||
defaultIsOpen: activeTabName !== 'upscaling',
|
||||
id: 'generation-settings',
|
||||
defaultIsOpen: true,
|
||||
});
|
||||
|
||||
return (
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user