mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 20:17:59 -05:00
Compare commits
3 Commits
v4.2.7
...
ryan/cloth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bead83a8bc | ||
|
|
3573d39860 | ||
|
|
36d72baaaa |
158
clothing_workflow.ipynb
Normal file
158
clothing_workflow.ipynb
Normal file
@@ -0,0 +1,158 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aeb428d0-0817-462c-b5d8-455a0615d305",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from PIL import Image\n",
|
||||
"import numpy as np\n",
|
||||
"import cv2\n",
|
||||
"\n",
|
||||
"from invokeai.backend.vto_workflow.overlay_pattern import generate_dress_mask, multiply_images\n",
|
||||
"from invokeai.backend.vto_workflow.extract_channel import extract_channel, ImageChannel\n",
|
||||
"from invokeai.backend.vto_workflow.seamless_mapping import map_seamless_tiles\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6140d4b7-8238-431c-848e-6f6ae27652f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
" # Load the model image.\n",
|
||||
"model_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg\")\n",
|
||||
"\n",
|
||||
"# Load the pattern image.\n",
|
||||
"pattern_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fb7186ba-dc0c-4520-ac30-49073a65601a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mask = generate_dress_mask(model_image)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9b935de4-94c5-4be5-bf8e-a5a6e445c811",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize mask\n",
|
||||
"model_image_np = np.array(model_image)\n",
|
||||
"masked_model_image = (model_image_np * np.expand_dims(mask, -1).astype(np.float32)).astype(np.uint8)\n",
|
||||
"mask_image = Image.fromarray(masked_model_image)\n",
|
||||
"mask_image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e51bb545",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"shadows = extract_channel(np.array(model_image), ImageChannel.LAB_L)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec43de4a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize masked shadows\n",
|
||||
"masked_shadows = (shadows * mask).astype(np.uint8)\n",
|
||||
"masked_shadows_image = Image.fromarray(masked_shadows)\n",
|
||||
"masked_shadows_image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dbb53794",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Tile the pattern.\n",
|
||||
"expanded_pattern = map_seamless_tiles(seamless_tile=pattern_image, target_hw=(model_image.height, model_image.width), num_repeats_h=10.0)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f4f22d02",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Multiply the pattern by the shadows.\n",
|
||||
"pattern_with_shadows = multiply_images(expanded_pattern, shadows)\n",
|
||||
"pattern_with_shadows"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97db42b0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de32f7e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Merge the pattern with the model image.\n",
|
||||
"pattern_with_shadows_np = np.array(pattern_with_shadows)\n",
|
||||
"merged_image = np.where(mask[:, :, None], pattern_with_shadows_np,model_image_np)\n",
|
||||
"merged_image = Image.fromarray(merged_image)\n",
|
||||
"merged_image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ff1d4044",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -55,7 +55,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
FROM node:20-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack use pnpm@8.x
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -98,9 +98,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
},
|
||||
}
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
try:
|
||||
|
||||
@@ -187,171 +187,164 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
# endregion
|
||||
# region ControlNet
|
||||
StarterModel(
|
||||
name="QRCode Monster v2 (SD1.5)",
|
||||
name="QRCode Monster",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster::v2",
|
||||
description="ControlNet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="QRCode Monster (SDXL)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="monster-labs/control_v1p_sdxl_qrcode_monster",
|
||||
description="ControlNet model that generates scannable creative QR codes",
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster",
|
||||
description="Controlnet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_canny",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="inpaint",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_inpaint",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="mlsd",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_mlsd",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1p_sd15_depth",
|
||||
description="ControlNet weights trained on sd-1.5 with depth conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with depth conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="normal_bae",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_normalbae",
|
||||
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="seg",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_seg",
|
||||
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with seg image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_lineart",
|
||||
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart_anime",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with anime image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_openpose",
|
||||
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_scribble",
|
||||
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_softedge",
|
||||
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="shuffle",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_shuffle",
|
||||
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1e_sd15_tile",
|
||||
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
|
||||
description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="ip2p",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_ip2p",
|
||||
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-canny-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
source="xinsir/controlnet-canny-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlNet-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
source="diffusers/controlnet-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge-dexined-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
|
||||
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
|
||||
description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-16bit-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
source="diffusers/controlnet-zoe-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-openpose-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
source="xinsir/controlnet-openpose-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-scribble-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
source="xinsir/controlnet-scribble-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-tile-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
source="xinsir/controlnet-tile-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
# endregion
|
||||
|
||||
0
invokeai/backend/vto_workflow/__init__.py
Normal file
0
invokeai/backend/vto_workflow/__init__.py
Normal file
62
invokeai/backend/vto_workflow/clipseg.py
Normal file
62
invokeai/backend/vto_workflow/clipseg.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor
|
||||
|
||||
|
||||
def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]:
|
||||
# Load the model.
|
||||
clipseg_processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
return clipseg_processor, clipseg_model
|
||||
|
||||
|
||||
def run_clipseg(
|
||||
images: list[Image.Image],
|
||||
prompt: str,
|
||||
clipseg_processor,
|
||||
clipseg_model,
|
||||
clipseg_temp: float,
|
||||
device: torch.device,
|
||||
) -> list[Image.Image]:
|
||||
"""Run ClipSeg on a list of images.
|
||||
|
||||
Args:
|
||||
clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'
|
||||
and include more of the background. Recommended range: 0.5 to 1.0.
|
||||
"""
|
||||
|
||||
orig_image_sizes = [img.size for img in images]
|
||||
|
||||
prompts = [prompt] * len(images)
|
||||
# TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model?
|
||||
inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors="pt")
|
||||
|
||||
# Move inputs and clipseg_model to the correct device and dtype.
|
||||
inputs = {k: v.to(device=device) for k, v in inputs.items()}
|
||||
clipseg_model = clipseg_model.to(device=device)
|
||||
|
||||
outputs = clipseg_model(**inputs)
|
||||
|
||||
logits = outputs.logits
|
||||
if logits.ndim == 2:
|
||||
# The model squeezes the batch dimension if it's 1, so we need to unsqueeze it.
|
||||
logits = logits.unsqueeze(0)
|
||||
probs = torch.nn.functional.sigmoid(logits / clipseg_temp)
|
||||
# Normalize each mask to 0-255. Note that each mask is normalized independently.
|
||||
probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Make mask greyscale.
|
||||
masks: list[Image.Image] = []
|
||||
for prob, orig_size in zip(probs, orig_image_sizes, strict=True):
|
||||
mask = Image.fromarray(prob.cpu().detach().numpy()).convert("L")
|
||||
mask = mask.resize(orig_size)
|
||||
masks.append(mask)
|
||||
|
||||
return masks
|
||||
|
||||
|
||||
def select_device() -> torch.device:
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
|
||||
return torch.device("cpu")
|
||||
57
invokeai/backend/vto_workflow/extract_channel.py
Normal file
57
invokeai/backend/vto_workflow/extract_channel.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from enum import Enum
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
class ImageChannel(Enum):
|
||||
RGB_R = "RGB_R"
|
||||
RGB_G = "RGB_G"
|
||||
RGB_B = "RGB_B"
|
||||
|
||||
LAB_L = "LAB_L"
|
||||
LAB_A = "LAB_A"
|
||||
LAB_B = "LAB_B"
|
||||
|
||||
HSV_H = "HSV_H"
|
||||
HSV_S = "HSV_S"
|
||||
HSV_V = "HSV_V"
|
||||
|
||||
|
||||
def extract_channel(image: npt.NDArray[np.uint8], channel: ImageChannel) -> npt.NDArray[np.uint8]:
|
||||
"""Extract a channel from an image.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): Shape (H, W, 3) of dtype uint8.
|
||||
channel (ImageChannel): The channel to extract.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Shape (H, W) of dtype uint8.
|
||||
"""
|
||||
if channel == ImageChannel.RGB_R:
|
||||
return image[:, :, 0]
|
||||
elif channel == ImageChannel.RGB_G:
|
||||
return image[:, :, 1]
|
||||
elif channel == ImageChannel.RGB_B:
|
||||
return image[:, :, 2]
|
||||
elif channel == ImageChannel.LAB_L:
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
||||
return lab[:, :, 0]
|
||||
elif channel == ImageChannel.LAB_A:
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
||||
return lab[:, :, 1]
|
||||
elif channel == ImageChannel.LAB_B:
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
||||
return lab[:, :, 2]
|
||||
elif channel == ImageChannel.HSV_H:
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||
return hsv[:, :, 0]
|
||||
elif channel == ImageChannel.HSV_S:
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||
return hsv[:, :, 1]
|
||||
elif channel == ImageChannel.HSV_V:
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||
return hsv[:, :, 2]
|
||||
else:
|
||||
raise ValueError(f"Unknown channel: {channel}")
|
||||
71
invokeai/backend/vto_workflow/overlay_pattern.py
Normal file
71
invokeai/backend/vto_workflow/overlay_pattern.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.vto_workflow.clipseg import load_clipseg_model, run_clipseg
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_dress_mask(model_image):
|
||||
"""Return a mask of the dress in the image.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Shape (H, W) of dtype bool. True where the dress is, False elsewhere.
|
||||
"""
|
||||
clipseg_processor, clipseg_model = load_clipseg_model()
|
||||
|
||||
masks = run_clipseg(
|
||||
images=[model_image],
|
||||
prompt="a dress",
|
||||
clipseg_processor=clipseg_processor,
|
||||
clipseg_model=clipseg_model,
|
||||
clipseg_temp=1.0,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
mask_np = np.array(masks[0])
|
||||
thresh = 128
|
||||
binary_mask = mask_np > thresh
|
||||
return binary_mask
|
||||
|
||||
|
||||
def multiply_images(image_1: Image.Image, image_2: Image.Image) -> Image.Image:
|
||||
"""Multiply two images together.
|
||||
|
||||
Args:
|
||||
image_1 (Image.Image): The first image.
|
||||
image_2 (Image.Image): The second image.
|
||||
|
||||
Returns:
|
||||
Image.Image: The product of the two images.
|
||||
"""
|
||||
image_1_np = np.array(image_1, dtype=np.float32)
|
||||
if image_1_np.ndim == 2:
|
||||
# If the image is greyscale, add a channel dimension.
|
||||
image_1_np = np.expand_dims(image_1_np, axis=-1)
|
||||
image_2_np = np.array(image_2, dtype=np.float32)
|
||||
if image_2_np.ndim == 2:
|
||||
# If the image is greyscale, add a channel dimension.
|
||||
image_2_np = np.expand_dims(image_2_np, axis=-1)
|
||||
product_np = image_1_np * image_2_np // 255
|
||||
product_np = np.clip(product_np, 0, 255).astype(np.uint8)
|
||||
product = Image.fromarray(product_np)
|
||||
return product
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
# Load the model image.
|
||||
model_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg")
|
||||
|
||||
# Load the pattern image.
|
||||
pattern_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg")
|
||||
|
||||
# Generate a mask for the dress.
|
||||
mask = generate_dress_mask(model_image)
|
||||
|
||||
print("hi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
31
invokeai/backend/vto_workflow/seamless_mapping.py
Normal file
31
invokeai/backend/vto_workflow/seamless_mapping.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def map_seamless_tiles(seamless_tile: Image.Image, target_hw: tuple[int, int], num_repeats_h: float) -> Image.Image:
|
||||
"""Map seamless tiles to a target size with a given number of repeats along the height dimension."""
|
||||
# TODO(ryand): Add option to flip odd rows and columns if the tile is not seamless.
|
||||
# - May also want the option to decide on a per-axis basis.
|
||||
|
||||
target_h, target_w = target_hw
|
||||
|
||||
# Calculate the height of the tile that is necessary to achieve the desired number of repeats.
|
||||
# Take the ceiling so that the last tile overflows the target height.
|
||||
target_tile_h = math.ceil(target_h / num_repeats_h)
|
||||
|
||||
# Resize the tile to the target height.
|
||||
# Determine the target_tile_w that preserves the original aspect ratio.
|
||||
target_tile_w = int(target_tile_h / seamless_tile.height * seamless_tile.width)
|
||||
resized_tile = seamless_tile.resize((target_tile_w, target_tile_h))
|
||||
|
||||
# Repeat the tile along the height and width dimensions.
|
||||
num_repeats_h_int = math.ceil(num_repeats_h)
|
||||
num_repeats_w_int = math.ceil(target_w / target_tile_w)
|
||||
seamless_tiles_np = np.array(resized_tile)
|
||||
repeated_tiles_np = np.tile(seamless_tiles_np, (num_repeats_h_int, num_repeats_w_int, 1))
|
||||
|
||||
# Crop the repeated tiles to the target size.
|
||||
output_pattern = Image.fromarray(repeated_tiles_np[:target_h, :target_w])
|
||||
return output_pattern
|
||||
@@ -1509,30 +1509,6 @@
|
||||
"seamlessTilingYAxis": {
|
||||
"heading": "Seamless Tiling Y Axis",
|
||||
"paragraphs": ["Seamlessly tile an image along the vertical axis."]
|
||||
},
|
||||
"upscaleModel": {
|
||||
"heading": "Upscale Model",
|
||||
"paragraphs": [
|
||||
"The upscale model scales the image to the output size before details are added. Any supported upscale model may be used, but some are specialized for different kinds of images, like photos or line drawings."
|
||||
]
|
||||
},
|
||||
"scale": {
|
||||
"heading": "Scale",
|
||||
"paragraphs": [
|
||||
"Scale controls the output image size, and is based on a multiple of the input image resolution. For example a 2x upscale on a 1024x1024 image would produce a 2048 x 2048 output."
|
||||
]
|
||||
},
|
||||
"creativity": {
|
||||
"heading": "Creativity",
|
||||
"paragraphs": [
|
||||
"Creativity controls the amount of freedom granted to the model when adding details. Low creativity stays close to the original image, while high creativity allows for more change. When using a prompt, high creativity increases the influence of the prompt."
|
||||
]
|
||||
},
|
||||
"structure": {
|
||||
"heading": "Structure",
|
||||
"paragraphs": [
|
||||
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
|
||||
]
|
||||
}
|
||||
},
|
||||
"unifiedCanvas": {
|
||||
|
||||
@@ -53,11 +53,7 @@ export type Feature =
|
||||
| 'refinerCfgScale'
|
||||
| 'scaleBeforeProcessing'
|
||||
| 'seamlessTilingXAxis'
|
||||
| 'seamlessTilingYAxis'
|
||||
| 'upscaleModel'
|
||||
| 'scale'
|
||||
| 'creativity'
|
||||
| 'structure';
|
||||
| 'seamlessTilingYAxis';
|
||||
|
||||
export type PopoverData = PopoverProps & {
|
||||
image?: string;
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
||||
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | null) => {
|
||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(
|
||||
modelKey ?? skipToken,
|
||||
isControlNetOrT2IAdapterModelConfig
|
||||
);
|
||||
|
||||
export const useControlNetOrT2IAdapterDefaultSettings = (
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
|
||||
) => {
|
||||
const defaultSettingsDefaults = useMemo(() => {
|
||||
return {
|
||||
preprocessor: {
|
||||
@@ -14,5 +19,5 @@ export const useControlNetOrT2IAdapterDefaultSettings = (
|
||||
};
|
||||
}, [modelConfig?.default_settings]);
|
||||
|
||||
return defaultSettingsDefaults;
|
||||
return { defaultSettingsDefaults, isLoading };
|
||||
};
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
|
||||
@@ -19,7 +22,9 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
||||
};
|
||||
});
|
||||
|
||||
export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
|
||||
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
|
||||
|
||||
const {
|
||||
initialSteps,
|
||||
initialCfg,
|
||||
@@ -76,5 +81,5 @@ export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
|
||||
initialHeight,
|
||||
]);
|
||||
|
||||
return defaultSettingsDefaults;
|
||||
return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||
@@ -50,8 +50,6 @@ export const modelManagerV2Slice = createSlice({
|
||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
|
||||
modelManagerV2Slice.actions;
|
||||
|
||||
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateModelManagerState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { HuggingFaceResults } from './HuggingFaceResults';
|
||||
|
||||
export const HuggingFaceForm = memo(() => {
|
||||
export const HuggingFaceForm = () => {
|
||||
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
|
||||
const [displayResults, setDisplayResults] = useState(false);
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
@@ -66,6 +66,4 @@ export const HuggingFaceForm = memo(() => {
|
||||
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
HuggingFaceForm.displayName = 'HuggingFaceForm';
|
||||
};
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
result: string;
|
||||
};
|
||||
export const HuggingFaceResultItem = memo(({ result }: Props) => {
|
||||
export const HuggingFaceResultItem = ({ result }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [installModel] = useInstallModel();
|
||||
@@ -27,6 +27,4 @@ export const HuggingFaceResultItem = memo(({ result }: Props) => {
|
||||
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
HuggingFaceResultItem.displayName = 'HuggingFaceResultItem';
|
||||
};
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
@@ -21,7 +21,7 @@ type HuggingFaceResultsProps = {
|
||||
results: string[];
|
||||
};
|
||||
|
||||
export const HuggingFaceResults = memo(({ results }: HuggingFaceResultsProps) => {
|
||||
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
|
||||
@@ -93,6 +93,4 @@ export const HuggingFaceResults = memo(({ results }: HuggingFaceResultsProps) =>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
HuggingFaceResults.displayName = 'HuggingFaceResults';
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
|
||||
@@ -10,7 +10,7 @@ type SimpleImportModelConfig = {
|
||||
inplace: boolean;
|
||||
};
|
||||
|
||||
export const InstallModelForm = memo(() => {
|
||||
export const InstallModelForm = () => {
|
||||
const [installModel, { isLoading }] = useInstallModel();
|
||||
|
||||
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
|
||||
@@ -74,6 +74,4 @@ export const InstallModelForm = memo(() => {
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
});
|
||||
|
||||
InstallModelForm.displayName = 'InstallModelForm';
|
||||
};
|
||||
|
||||
@@ -2,12 +2,12 @@ import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
|
||||
|
||||
export const ModelInstallQueue = memo(() => {
|
||||
export const ModelInstallQueue = () => {
|
||||
const { data } = useListModelInstallsQuery();
|
||||
|
||||
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
|
||||
@@ -61,6 +61,4 @@ export const ModelInstallQueue = memo(() => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelInstallQueue.displayName = 'ModelInstallQueue';
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
||||
import type { ModelInstallJob } from 'services/api/types';
|
||||
@@ -25,7 +25,7 @@ const formatBytes = (bytes: number) => {
|
||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||
};
|
||||
|
||||
export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
||||
const { installJob } = props;
|
||||
|
||||
const [deleteImportModel] = useCancelModelInstallMutation();
|
||||
@@ -124,9 +124,7 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelInstallQueueItem.displayName = 'ModelInstallQueueItem';
|
||||
};
|
||||
|
||||
type TooltipLabelProps = {
|
||||
installJob: ModelInstallJob;
|
||||
@@ -134,7 +132,7 @@ type TooltipLabelProps = {
|
||||
source: string;
|
||||
};
|
||||
|
||||
const TooltipLabel = memo(({ name, source, installJob }: TooltipLabelProps) => {
|
||||
const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
|
||||
const progressString = useMemo(() => {
|
||||
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
||||
return '';
|
||||
@@ -158,6 +156,4 @@ const TooltipLabel = memo(({ name, source, installJob }: TooltipLabelProps) => {
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipLabel.displayName = 'TooltipLabel';
|
||||
};
|
||||
|
||||
@@ -2,13 +2,13 @@ import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel,
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setScanPath } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { ScanModelsResults } from './ScanFolderResults';
|
||||
|
||||
export const ScanModelsForm = memo(() => {
|
||||
export const ScanModelsForm = () => {
|
||||
const scanPath = useAppSelector((state) => state.modelmanagerV2.scanPath);
|
||||
const dispatch = useAppDispatch();
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
@@ -56,6 +56,4 @@ export const ScanModelsForm = memo(() => {
|
||||
{data && <ScanModelsResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ScanModelsForm.displayName = 'ScanModelsForm';
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||
@@ -8,7 +8,7 @@ type Props = {
|
||||
result: ScanFolderResponse[number];
|
||||
installModel: (source: string) => void;
|
||||
};
|
||||
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
|
||||
export const ScanModelResultItem = ({ result, installModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleInstall = useCallback(() => {
|
||||
@@ -30,6 +30,4 @@ export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ScanModelResultItem.displayName = 'ScanModelResultItem';
|
||||
};
|
||||
|
||||
@@ -14,7 +14,7 @@ import {
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEvent, ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||
@@ -25,7 +25,7 @@ type ScanModelResultsProps = {
|
||||
results: ScanFolderResponse;
|
||||
};
|
||||
|
||||
export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
|
||||
export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [inplace, setInplace] = useState(true);
|
||||
@@ -116,6 +116,4 @@ export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
ScanModelsResults.displayName = 'ScanModelsResults';
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
@@ -9,7 +9,7 @@ import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
type Props = {
|
||||
result: GetStarterModelsResponse[number];
|
||||
};
|
||||
export const StarterModelsResultItem = memo(({ result }: Props) => {
|
||||
export const StarterModelsResultItem = ({ result }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const allSources = useMemo(() => {
|
||||
const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }];
|
||||
@@ -47,6 +47,4 @@ export const StarterModelsResultItem = memo(({ result }: Props) => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
StarterModelsResultItem.displayName = 'StarterModelsResultItem';
|
||||
};
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
|
||||
import { memo } from 'react';
|
||||
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { StarterModelsResults } from './StarterModelsResults';
|
||||
|
||||
export const StarterModelsForm = memo(() => {
|
||||
export const StarterModelsForm = () => {
|
||||
const { isLoading, data } = useGetStarterModelsQuery();
|
||||
|
||||
return (
|
||||
@@ -14,6 +13,4 @@ export const StarterModelsForm = memo(() => {
|
||||
{data && <StarterModelsResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
StarterModelsForm.displayName = 'StarterModelsForm';
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
@@ -12,7 +12,7 @@ type StarterModelsResultsProps = {
|
||||
results: NonNullable<GetStarterModelsResponse>;
|
||||
};
|
||||
|
||||
export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => {
|
||||
export const StarterModelsResults = ({ results }: StarterModelsResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
|
||||
@@ -79,6 +79,4 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
StarterModelsResults.displayName = 'StarterModelsResults';
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@in
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
|
||||
import { atom } from 'nanostores';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||
@@ -12,7 +12,7 @@ import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
|
||||
export const $installModelsTab = atom(0);
|
||||
|
||||
export const InstallModels = memo(() => {
|
||||
export const InstallModels = () => {
|
||||
const { t } = useTranslation();
|
||||
const index = useStore($installModelsTab);
|
||||
const onChange = useCallback((index: number) => {
|
||||
@@ -49,6 +49,4 @@ export const InstallModels = memo(() => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
InstallModels.displayName = 'InstallModels';
|
||||
};
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
|
||||
|
||||
export const ModelManager = memo(() => {
|
||||
export const ModelManager = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const handleClickAddModel = useCallback(() => {
|
||||
@@ -29,6 +29,4 @@ export const ModelManager = memo(() => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelManager.displayName = 'ModelManager';
|
||||
};
|
||||
|
||||
@@ -21,8 +21,7 @@ import { FetchingModelsLoader } from './FetchingModelsLoader';
|
||||
import { ModelListWrapper } from './ModelListWrapper';
|
||||
|
||||
const ModelList = () => {
|
||||
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectModelManagerV2Slice, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -24,21 +23,15 @@ const sx: SystemStyleObject = {
|
||||
"&[aria-selected='true']": { bg: 'base.700' },
|
||||
};
|
||||
|
||||
const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||
const ModelListItem = (props: ModelListItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selectIsSelected = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectModelManagerV2Slice,
|
||||
(modelManagerV2Slice) => modelManagerV2Slice.selectedModelKey === model.key
|
||||
),
|
||||
[model.key]
|
||||
);
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const [deleteModel] = useDeleteModelsMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const { model } = props;
|
||||
|
||||
const handleSelectModel = useCallback(() => {
|
||||
dispatch(setSelectedModelKey(model.key));
|
||||
}, [model.key, dispatch]);
|
||||
@@ -50,6 +43,11 @@ const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||
},
|
||||
[onOpen]
|
||||
);
|
||||
|
||||
const isSelected = useMemo(() => {
|
||||
return selectedModelKey === model.key;
|
||||
}, [selectedModelKey, model.key]);
|
||||
|
||||
const handleModelDelete = useCallback(() => {
|
||||
deleteModel({ key: model.key })
|
||||
.unwrap()
|
||||
|
||||
@@ -3,12 +3,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { t } from 'i18next';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import { ModelTypeFilter } from './ModelTypeFilter';
|
||||
|
||||
export const ModelListNavigation = memo(() => {
|
||||
export const ModelListNavigation = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||
|
||||
@@ -49,6 +49,4 @@ export const ModelListNavigation = memo(() => {
|
||||
</InputGroup>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelListNavigation.displayName = 'ModelListNavigation';
|
||||
};
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { StickyScrollable } from 'features/system/components/StickyScrollable';
|
||||
import { memo } from 'react';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelListItem from './ModelListItem';
|
||||
@@ -9,7 +8,7 @@ type ModelListWrapperProps = {
|
||||
modelList: AnyModelConfig[];
|
||||
};
|
||||
|
||||
export const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
||||
export const ModelListWrapper = (props: ModelListWrapperProps) => {
|
||||
const { title, modelList } = props;
|
||||
return (
|
||||
<StickyScrollable title={title} contentSx={{ gap: 1, p: 2 }}>
|
||||
@@ -18,6 +17,4 @@ export const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
||||
))}
|
||||
</StickyScrollable>
|
||||
);
|
||||
});
|
||||
|
||||
ModelListWrapper.displayName = 'ModelListWrapper';
|
||||
};
|
||||
|
||||
@@ -2,12 +2,12 @@ import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-libr
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFunnelBold } from 'react-icons/pi';
|
||||
import { objectKeys } from 'tsafe';
|
||||
|
||||
export const ModelTypeFilter = memo(() => {
|
||||
export const ModelTypeFilter = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
||||
@@ -57,6 +57,4 @@ export const ModelTypeFilter = memo(() => {
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
});
|
||||
|
||||
ModelTypeFilter.displayName = 'ModelTypeFilter';
|
||||
};
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { InstallModels } from './InstallModels';
|
||||
import { Model } from './ModelPanel/Model';
|
||||
|
||||
export const ModelPane = memo(() => {
|
||||
export const ModelPane = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={4} borderRadius="base" w="50%" h="full">
|
||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
ModelPane.displayName = 'ModelPane';
|
||||
};
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
|
||||
preprocessor: FormField<string>;
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig;
|
||||
};
|
||||
|
||||
export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const defaultSettingsDefaults = useControlNetOrT2IAdapterDefaultSettings(modelConfig);
|
||||
const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } =
|
||||
useControlNetOrT2IAdapterDefaultSettings(selectedModelKey);
|
||||
|
||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||
|
||||
@@ -32,12 +30,16 @@ export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Prop
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
|
||||
(data) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const body = {
|
||||
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
|
||||
};
|
||||
|
||||
updateModel({
|
||||
key: modelConfig.key,
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
.unwrap()
|
||||
@@ -59,9 +61,13 @@ export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Prop
|
||||
}
|
||||
});
|
||||
},
|
||||
[updateModel, modelConfig.key, t, reset]
|
||||
[selectedModelKey, reset, updateModel, t]
|
||||
);
|
||||
|
||||
if (isLoadingDefaultSettings) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
||||
@@ -83,6 +89,4 @@ export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Prop
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
ControlNetOrT2IAdapterDefaultSettings.displayName = 'ControlNetOrT2IAdapterDefaultSettings';
|
||||
};
|
||||
|
||||
@@ -4,7 +4,7 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
|
||||
import type { ControlNetOrT2IAdapterDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -28,7 +28,7 @@ const OPTIONS = [
|
||||
|
||||
type DefaultSchedulerType = ControlNetOrT2IAdapterDefaultSettingsFormData['preprocessor'];
|
||||
|
||||
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) => {
|
||||
export function DefaultPreprocessor(props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
@@ -63,6 +63,4 @@ export const DefaultPreprocessor = memo((props: UseControllerProps<ControlNetOrT
|
||||
<Combobox isDisabled={isDisabled} value={value} options={OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultPreprocessor.displayName = 'DefaultPreprocessor';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultCfgRescaleMultiplierType = MainModelDefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||
|
||||
export const DefaultCfgRescaleMultiplier = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||
@@ -74,6 +74,4 @@ export const DefaultCfgRescaleMultiplier = memo((props: UseControllerProps<MainM
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultCfgRescaleMultiplier.displayName = 'DefaultCfgRescaleMultiplier';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultCfgType = MainModelDefaultSettingsFormData['cfgScale'];
|
||||
|
||||
export const DefaultCfgScale = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultCfgScale(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
||||
@@ -74,6 +74,4 @@ export const DefaultCfgScale = memo((props: UseControllerProps<MainModelDefaultS
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultCfgScale.displayName = 'DefaultCfgScale';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -16,7 +16,7 @@ type Props = {
|
||||
optimalDimension: number;
|
||||
};
|
||||
|
||||
export const DefaultHeight = memo(({ control, optimalDimension }: Props) => {
|
||||
export function DefaultHeight({ control, optimalDimension }: Props) {
|
||||
const { field } = useController({ control, name: 'height' });
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.height.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.height.sliderMax);
|
||||
@@ -78,6 +78,4 @@ export const DefaultHeight = memo(({ control, optimalDimension }: Props) => {
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultHeight.displayName = 'DefaultHeight';
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -13,7 +13,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultSchedulerType = MainModelDefaultSettingsFormData['scheduler'];
|
||||
|
||||
export const DefaultScheduler = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultScheduler(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
@@ -51,6 +51,4 @@ export const DefaultScheduler = memo((props: UseControllerProps<MainModelDefault
|
||||
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultScheduler.displayName = 'DefaultScheduler';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultSteps = MainModelDefaultSettingsFormData['steps'];
|
||||
|
||||
export const DefaultSteps = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultSteps(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
||||
@@ -74,6 +74,4 @@ export const DefaultSteps = memo((props: UseControllerProps<MainModelDefaultSett
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultSteps.displayName = 'DefaultSteps';
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -15,7 +15,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
||||
|
||||
type DefaultVaeType = MainModelDefaultSettingsFormData['vae'];
|
||||
|
||||
export const DefaultVae = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
@@ -64,6 +64,4 @@ export const DefaultVae = memo((props: UseControllerProps<MainModelDefaultSettin
|
||||
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultVae.displayName = 'DefaultVae';
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -17,7 +17,7 @@ const options = [
|
||||
|
||||
type DefaultVaePrecisionType = MainModelDefaultSettingsFormData['vaePrecision'];
|
||||
|
||||
export const DefaultVaePrecision = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||
export function DefaultVaePrecision(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
@@ -52,6 +52,4 @@ export const DefaultVaePrecision = memo((props: UseControllerProps<MainModelDefa
|
||||
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultVaePrecision.displayName = 'DefaultVaePrecision';
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -16,7 +16,7 @@ type Props = {
|
||||
optimalDimension: number;
|
||||
};
|
||||
|
||||
export const DefaultWidth = memo(({ control, optimalDimension }: Props) => {
|
||||
export function DefaultWidth({ control, optimalDimension }: Props) {
|
||||
const { field } = useController({ control, name: 'width' });
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.width.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.width.sliderMax);
|
||||
@@ -78,6 +78,4 @@ export const DefaultWidth = memo(({ control, optimalDimension }: Props) => {
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
DefaultWidth.displayName = 'DefaultWidth';
|
||||
}
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
||||
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
||||
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
|
||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
@@ -37,16 +35,16 @@ export type MainModelDefaultSettingsFormData = {
|
||||
height: FormField<number>;
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: MainModelConfig;
|
||||
};
|
||||
|
||||
export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
export const MainModelDefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const defaultSettingsDefaults = useMainModelDefaultSettings(modelConfig);
|
||||
const optimalDimension = useMemo(() => getOptimalDimension(modelConfig), [modelConfig]);
|
||||
const {
|
||||
defaultSettingsDefaults,
|
||||
isLoading: isLoadingDefaultSettings,
|
||||
optimalDimension,
|
||||
} = useMainModelDefaultSettings(selectedModelKey);
|
||||
|
||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||
|
||||
const { handleSubmit, control, formState, reset } = useForm<MainModelDefaultSettingsFormData>({
|
||||
@@ -96,6 +94,10 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
[selectedModelKey, reset, updateModel, t]
|
||||
);
|
||||
|
||||
if (isLoadingDefaultSettings) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
||||
@@ -124,6 +126,4 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
MainModelDefaultSettings.displayName = 'MainModelDefaultSettings';
|
||||
};
|
||||
|
||||
@@ -1,47 +1,120 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { Button, Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiExclamationMarkBold } from 'react-icons/pi';
|
||||
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import ModelImageUpload from './Fields/ModelImageUpload';
|
||||
import { ModelEdit } from './ModelEdit';
|
||||
import { ModelView } from './ModelView';
|
||||
|
||||
export const Model = memo(() => {
|
||||
export const Model = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelConfigs, isLoading } = useGetModelConfigsQuery();
|
||||
const modelConfig = useMemo(() => {
|
||||
if (!modelConfigs) {
|
||||
return null;
|
||||
}
|
||||
if (selectedModelKey === null) {
|
||||
return null;
|
||||
}
|
||||
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
if (!modelConfig) {
|
||||
return null;
|
||||
}
|
||||
const form = useForm<UpdateModelArg['body']>({
|
||||
defaultValues: data,
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
return modelConfig;
|
||||
}, [modelConfigs, selectedModelKey]);
|
||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||
(values) => {
|
||||
if (!data?.key) {
|
||||
return;
|
||||
}
|
||||
|
||||
const responseBody: UpdateModelArg = {
|
||||
key: data.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
form.reset(payload, { keepDefaultValues: true });
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
toast({
|
||||
id: 'MODEL_UPDATED',
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
});
|
||||
})
|
||||
.catch((_) => {
|
||||
form.reset();
|
||||
toast({
|
||||
id: 'MODEL_UPDATE_FAILED',
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
},
|
||||
[dispatch, data?.key, form, t, updateModel]
|
||||
);
|
||||
|
||||
const handleClickCancel = useCallback(() => {
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
}, [dispatch]);
|
||||
|
||||
if (isLoading) {
|
||||
return <IAINoContentFallbackWithSpinner label={t('common.loading')} />;
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
if (!modelConfig) {
|
||||
return <IAINoContentFallback label={t('common.somethingWentWrong')} icon={PiExclamationMarkBold} />;
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
|
||||
if (selectedModelMode === 'view') {
|
||||
return <ModelView modelConfig={modelConfig} />;
|
||||
}
|
||||
|
||||
return <ModelEdit modelConfig={modelConfig} />;
|
||||
});
|
||||
|
||||
Model.displayName = 'Model';
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Flex alignItems="flex-start" gap={4}>
|
||||
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
|
||||
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
|
||||
<Flex gap={2}>
|
||||
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
|
||||
{data.name}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
{selectedModelMode === 'view' && <ModelConvertButton modelKey={selectedModelKey} />}
|
||||
{selectedModelMode === 'view' && <ModelEditButton />}
|
||||
{selectedModelMode === 'edit' && (
|
||||
<Button size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
)}
|
||||
{selectedModelMode === 'edit' && (
|
||||
<Button
|
||||
size="sm"
|
||||
colorScheme="invokeYellow"
|
||||
leftIcon={<PiCheckBold />}
|
||||
onClick={form.handleSubmit(onSubmit)}
|
||||
isLoading={isSubmitting}
|
||||
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
)}
|
||||
</Flex>
|
||||
{data.source && (
|
||||
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
||||
{t('modelManager.source')}: {data?.source}
|
||||
</Text>
|
||||
)}
|
||||
<Text noOfLines={3}>{data.description}</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit form={form} onSubmit={onSubmit} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import { FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
|
||||
interface Props {
|
||||
label: string;
|
||||
value: string | null | undefined;
|
||||
}
|
||||
|
||||
export const ModelAttrView = memo(({ label, value }: Props) => {
|
||||
export const ModelAttrView = ({ label, value }: Props) => {
|
||||
return (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={0}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
@@ -15,6 +14,4 @@ export const ModelAttrView = memo(({ label, value }: Props) => {
|
||||
</Text>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
ModelAttrView.displayName = 'ModelAttrView';
|
||||
};
|
||||
|
||||
@@ -8,46 +8,52 @@ import {
|
||||
UnorderedList,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useConvertModelMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
import { useConvertModelMutation, useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
interface ModelConvertProps {
|
||||
modelConfig: CheckpointModelConfig;
|
||||
modelKey: string | null;
|
||||
}
|
||||
|
||||
export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
|
||||
export const ModelConvertButton = (props: ModelConvertProps) => {
|
||||
const { modelKey } = props;
|
||||
const { t } = useTranslation();
|
||||
const { data } = useGetModelConfigQuery(modelKey ?? skipToken);
|
||||
const [convertModel, { isLoading }] = useConvertModelMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const modelConvertHandler = useCallback(() => {
|
||||
if (!modelConfig || isLoading) {
|
||||
if (!data || isLoading) {
|
||||
return;
|
||||
}
|
||||
|
||||
const toastId = `CONVERTING_MODEL_${modelConfig.key}`;
|
||||
const toastId = `CONVERTING_MODEL_${data.key}`;
|
||||
toast({
|
||||
id: toastId,
|
||||
title: `${t('modelManager.convertingModelBegin')}: ${modelConfig.name}`,
|
||||
title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`,
|
||||
status: 'info',
|
||||
});
|
||||
|
||||
convertModel(modelConfig.key)
|
||||
convertModel(data?.key)
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${modelConfig.name}`, status: 'success' });
|
||||
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${data?.name}`, status: 'success' });
|
||||
})
|
||||
.catch(() => {
|
||||
toast({
|
||||
id: toastId,
|
||||
title: `${t('modelManager.modelConversionFailed')}: ${modelConfig.name}`,
|
||||
title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`,
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
}, [modelConfig, isLoading, t, convertModel]);
|
||||
}, [data, isLoading, t, convertModel]);
|
||||
|
||||
if (data?.format !== 'checkpoint') {
|
||||
return;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -62,7 +68,7 @@ export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
|
||||
🧨 {t('modelManager.convert')}
|
||||
</Button>
|
||||
<ConfirmationAlertDialog
|
||||
title={`${t('modelManager.convert')} ${modelConfig.name}`}
|
||||
title={`${t('modelManager.convert')} ${data?.name}`}
|
||||
acceptCallback={modelConvertHandler}
|
||||
acceptButtonText={`${t('modelManager.convert')}`}
|
||||
isOpen={isOpen}
|
||||
@@ -90,6 +96,4 @@ export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
|
||||
</ConfirmationAlertDialog>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
ModelConvertButton.displayName = 'ModelConvertButton';
|
||||
};
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import {
|
||||
Button,
|
||||
Checkbox,
|
||||
Flex,
|
||||
FormControl,
|
||||
@@ -8,154 +7,96 @@ import {
|
||||
Heading,
|
||||
Input,
|
||||
SimpleGrid,
|
||||
Text,
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { type SubmitHandler, useForm } from 'react-hook-form';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { SubmitHandler, UseFormReturn } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||
import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
form: UseFormReturn<UpdateModelArg['body']>;
|
||||
onSubmit: SubmitHandler<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const stringFieldOptions = {
|
||||
validate: (value?: string | null) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||
};
|
||||
|
||||
export const ModelEdit = memo(({ modelConfig }: Props) => {
|
||||
export const ModelEdit = ({ form }: Props) => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const { t } = useTranslation();
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const form = useForm<UpdateModelArg['body']>({
|
||||
defaultValues: modelConfig,
|
||||
mode: 'onChange',
|
||||
});
|
||||
if (isLoading) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||
(values) => {
|
||||
const responseBody: UpdateModelArg = {
|
||||
key: modelConfig.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
form.reset(payload, { keepDefaultValues: true });
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
toast({
|
||||
id: 'MODEL_UPDATED',
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
});
|
||||
})
|
||||
.catch((_) => {
|
||||
form.reset();
|
||||
toast({
|
||||
id: 'MODEL_UPDATE_FAILED',
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
},
|
||||
[dispatch, modelConfig.key, form, t, updateModel]
|
||||
);
|
||||
|
||||
const handleClickCancel = useCallback(() => {
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
}, [dispatch]);
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<ModelHeader modelConfig={modelConfig}>
|
||||
<Button flexShrink={0} size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
flexShrink={0}
|
||||
size="sm"
|
||||
colorScheme="invokeYellow"
|
||||
leftIcon={<PiCheckBold />}
|
||||
onClick={form.handleSubmit(onSubmit)}
|
||||
isLoading={isSubmitting}
|
||||
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</ModelHeader>
|
||||
<Flex flexDir="column" h="full">
|
||||
<form>
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<FormControl
|
||||
flexDir="column"
|
||||
alignItems="flex-start"
|
||||
gap={1}
|
||||
isInvalid={Boolean(form.formState.errors.name)}
|
||||
>
|
||||
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
||||
<Input {...form.register('name', stringFieldOptions)} size="md" />
|
||||
<Flex flexDir="column" h="full">
|
||||
<form>
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(form.formState.errors.name)}>
|
||||
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
||||
<Input {...form.register('name', stringFieldOptions)} size="md" />
|
||||
|
||||
{form.formState.errors.name?.message && (
|
||||
<FormErrorMessage>{form.formState.errors.name?.message}</FormErrorMessage>
|
||||
)}
|
||||
{form.formState.errors.name?.message && (
|
||||
<FormErrorMessage>{form.formState.errors.name?.message}</FormErrorMessage>
|
||||
)}
|
||||
</FormControl>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={3} mt="4">
|
||||
<Flex gap="4" alignItems="center">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea {...form.register('description')} minH={32} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={3} mt="4">
|
||||
<Flex gap="4" alignItems="center">
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
</FormControl>
|
||||
{data.type === 'main' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea {...form.register('description')} minH={32} />
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
</FormControl>
|
||||
{modelConfig.type === 'main' && (
|
||||
)}
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||
</FormControl>
|
||||
)}
|
||||
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
|
||||
<>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={form.control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...form.register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={form.control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...form.register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelEdit.displayName = 'ModelEdit';
|
||||
};
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
|
||||
export const ModelEditButton = memo(() => {
|
||||
export const ModelEditButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@@ -18,6 +18,4 @@ export const ModelEditButton = memo(() => {
|
||||
{t('modelManager.edit')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
|
||||
ModelEditButton.displayName = 'ModelEditButton';
|
||||
};
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import { Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import ModelImageUpload from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
modelConfig: AnyModelConfig;
|
||||
}>;
|
||||
|
||||
export const ModelHeader = memo(({ modelConfig, children }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex alignItems="flex-start" gap={4}>
|
||||
<ModelImageUpload model_key={modelConfig.key} model_image={modelConfig.cover_image} />
|
||||
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
|
||||
<Flex gap={2}>
|
||||
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
|
||||
{modelConfig.name}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
{children}
|
||||
</Flex>
|
||||
{modelConfig.source && (
|
||||
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
||||
{t('modelManager.source')}: {modelConfig.source}
|
||||
</Text>
|
||||
)}
|
||||
<Text noOfLines={3}>{modelConfig.description}</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelHeader.displayName = 'ModelHeader';
|
||||
@@ -1,67 +1,55 @@
|
||||
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
||||
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||
import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
};
|
||||
|
||||
export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
export const ModelView = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
if (isLoading) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<ModelHeader modelConfig={modelConfig}>
|
||||
{modelConfig.format === 'checkpoint' && modelConfig.type === 'main' && (
|
||||
<ModelConvertButton modelConfig={modelConfig} />
|
||||
)}
|
||||
<ModelEditButton />
|
||||
</ModelHeader>
|
||||
<Flex flexDir="column" h="full" gap={4}>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelConfig.type} />
|
||||
<ModelAttrView label={t('common.format')} value={modelConfig.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelConfig.path} />
|
||||
{modelConfig.type === 'main' && (
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelConfig.variant} />
|
||||
)}
|
||||
{modelConfig.type === 'main' && modelConfig.format === 'diffusers' && modelConfig.repo_variant && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelConfig.repo_variant} />
|
||||
)}
|
||||
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
|
||||
<>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelConfig.config_path} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelConfig.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelConfig.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
{modelConfig.type === 'ip_adapter' && modelConfig.format === 'invokeai' && (
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelConfig.image_encoder_model_id} />
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
|
||||
<MainModelDefaultSettings modelConfig={modelConfig} />
|
||||
<Flex flexDir="column" h="full" gap={4}>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
||||
{data.type === 'main' && <ModelAttrView label={t('modelManager.variant')} value={data.variant} />}
|
||||
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
||||
)}
|
||||
{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
|
||||
<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
{(modelConfig.type === 'main' || modelConfig.type === 'lora') && <TriggerPhrases modelConfig={modelConfig} />}
|
||||
</Box>
|
||||
</Flex>
|
||||
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && <MainModelDefaultSettings />}
|
||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && <ControlNetOrT2IAdapterDefaultSettings />}
|
||||
{(data.type === 'main' || data.type === 'lora') && <TriggerPhrases />}
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelView.displayName = 'ModelView';
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Switch, typedMemo } from '@invoke-ai/ui-library';
|
||||
import { Switch } from '@invoke-ai/ui-library';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
@@ -6,7 +6,7 @@ import { useController } from 'react-hook-form';
|
||||
|
||||
import type { FormField } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
|
||||
export const SettingToggle = typedMemo(<T, F extends Record<string, FormField<T>>>(props: UseControllerProps<F>) => {
|
||||
export function SettingToggle<T, F extends Record<string, FormField<T>>>(props: UseControllerProps<F>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const value = useMemo(() => {
|
||||
@@ -25,6 +25,4 @@ export const SettingToggle = typedMemo(<T, F extends Record<string, FormField<T>
|
||||
);
|
||||
|
||||
return <Switch size="sm" isChecked={value} onChange={onChange} />;
|
||||
});
|
||||
|
||||
SettingToggle.displayName = 'SettingToggle';
|
||||
}
|
||||
|
||||
@@ -9,19 +9,19 @@ import {
|
||||
TagCloseButton,
|
||||
TagLabel,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { LoRAModelConfig, MainModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import { isLoRAModelConfig, isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
modelConfig: MainModelConfig | LoRAModelConfig;
|
||||
};
|
||||
|
||||
export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
export const TriggerPhrases = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { currentData: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const [phrase, setPhrase] = useState('');
|
||||
|
||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||
@@ -31,6 +31,9 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
}, []);
|
||||
|
||||
const triggerPhrases = useMemo(() => {
|
||||
if (!modelConfig || (!isNonRefinerMainModelConfig(modelConfig) && !isLoRAModelConfig(modelConfig))) {
|
||||
return [];
|
||||
}
|
||||
return modelConfig?.trigger_phrases || [];
|
||||
}, [modelConfig]);
|
||||
|
||||
@@ -45,6 +48,10 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
}, [phrase, triggerPhrases]);
|
||||
|
||||
const addTriggerPhrase = useCallback(async () => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
||||
return;
|
||||
}
|
||||
@@ -52,18 +59,22 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
setPhrase('');
|
||||
|
||||
await updateModel({
|
||||
key: modelConfig.key,
|
||||
key: selectedModelKey,
|
||||
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
||||
}).unwrap();
|
||||
}, [phrase, triggerPhrases, updateModel, modelConfig.key]);
|
||||
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);
|
||||
|
||||
const removeTriggerPhrase = useCallback(
|
||||
async (phraseToRemove: string) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
||||
|
||||
await updateModel({ key: modelConfig.key, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||
await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||
},
|
||||
[triggerPhrases, updateModel, modelConfig]
|
||||
[updateModel, selectedModelKey, triggerPhrases]
|
||||
);
|
||||
|
||||
const onTriggerPhraseAddFormSubmit = useCallback(
|
||||
@@ -92,9 +103,7 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
{t('common.add')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{errors.map((error) => (
|
||||
<FormErrorMessage key={error}>{error}</FormErrorMessage>
|
||||
))}
|
||||
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</form>
|
||||
@@ -109,6 +118,4 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TriggerPhrases.displayName = 'TriggerPhrases';
|
||||
};
|
||||
|
||||
@@ -59,19 +59,17 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
||||
for (const edge of copiedEdges) {
|
||||
if (edge.source === node.id) {
|
||||
edge.source = id;
|
||||
} else if (edge.target === node.id) {
|
||||
edge.id = edge.id.replace(node.data.id, id);
|
||||
}
|
||||
if (edge.target === node.id) {
|
||||
edge.target = id;
|
||||
edge.id = edge.id.replace(node.data.id, id);
|
||||
}
|
||||
}
|
||||
node.id = id;
|
||||
node.data.id = id;
|
||||
});
|
||||
|
||||
copiedEdges.forEach((edge) => {
|
||||
// Copied edges need a fresh id too
|
||||
edge.id = uuidv4();
|
||||
});
|
||||
|
||||
const nodeChanges: NodeChange[] = [];
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
// Deselect existing nodes
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { creativityChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -26,9 +25,7 @@ const ParamCreativity = () => {
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="creativity">
|
||||
<FormLabel>{t('upscaling.creativity')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<FormLabel>{t('upscaling.creativity')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={creativity}
|
||||
defaultValue={initial}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
||||
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -38,9 +37,7 @@ const ParamSpandrelModel = () => {
|
||||
|
||||
return (
|
||||
<FormControl orientation="vertical">
|
||||
<InformationalPopover feature="upscaleModel">
|
||||
<FormLabel>{t('upscaling.upscaleModel')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<FormLabel>{t('upscaling.upscaleModel')}</FormLabel>
|
||||
<Tooltip label={tooltipLabel}>
|
||||
<Box w="full">
|
||||
<Combobox
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { structureChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -26,9 +25,7 @@ const ParamStructure = () => {
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="structure">
|
||||
<FormLabel>{t('upscaling.structure')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<FormLabel>{t('upscaling.structure')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={structure}
|
||||
defaultValue={initial}
|
||||
|
||||
@@ -64,7 +64,7 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
const badges = useAppSelector(selectBadges);
|
||||
const { t } = useTranslation();
|
||||
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
||||
id: `'advanced-settings-${activeTabName}`,
|
||||
id: 'advanced-settings',
|
||||
defaultIsOpen: false,
|
||||
});
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param
|
||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { filter } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -27,7 +26,6 @@ const formLabelProps: FormLabelProps = {
|
||||
export const GenerationSettingsAccordion = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const modelConfig = useSelectedModelConfig();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectLoraSlice, (lora) => {
|
||||
@@ -44,8 +42,8 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
defaultIsOpen: false,
|
||||
});
|
||||
const { isOpen: isOpenAccordion, onToggle: onToggleAccordion } = useStandaloneAccordionToggle({
|
||||
id: `generation-settings-${activeTabName}`,
|
||||
defaultIsOpen: activeTabName !== 'upscaling',
|
||||
id: 'generation-settings',
|
||||
defaultIsOpen: true,
|
||||
});
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { scaleChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -23,9 +22,7 @@ export const UpscaleScaleSlider = memo(() => {
|
||||
|
||||
return (
|
||||
<FormControl orientation="vertical" gap={0}>
|
||||
<InformationalPopover feature="scale">
|
||||
<FormLabel m={0}>{t('upscaling.scale')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<FormLabel m={0}>{t('upscaling.scale')}</FormLabel>
|
||||
<Flex w="full" gap={4}>
|
||||
<CompositeSlider
|
||||
min={2}
|
||||
|
||||
@@ -18,6 +18,5 @@ export const useStandaloneAccordionToggle = (arg: UseStandaloneAccordionToggleAr
|
||||
const onToggle = useCallback(() => {
|
||||
dispatch(accordionStateChanged({ id: arg.id, isOpen: !isOpen }));
|
||||
}, [arg.id, dispatch, isOpen]);
|
||||
|
||||
return { isOpen, onToggle };
|
||||
};
|
||||
|
||||
@@ -242,6 +242,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
}
|
||||
return tags;
|
||||
},
|
||||
keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite)
|
||||
transformResponse: (response: GetModelConfigsResponse) => {
|
||||
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
|
||||
},
|
||||
|
||||
@@ -54,7 +54,7 @@ export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||
export type AnyModelConfig =
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.2.7"
|
||||
__version__ = "4.2.7rc1"
|
||||
|
||||
Reference in New Issue
Block a user