mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-20 13:38:00 -05:00
Compare commits
37 Commits
v5.4.0
...
maryhipp/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9bd1f4a4f4 | ||
|
|
28864f6d7f | ||
|
|
c63fe5e9bb | ||
|
|
674f530501 | ||
|
|
a01d44f813 | ||
|
|
63fb3a15e9 | ||
|
|
4d0837541b | ||
|
|
999809b4c7 | ||
|
|
c452edfb9f | ||
|
|
ad2cdbd8a2 | ||
|
|
f15c24bfa7 | ||
|
|
d1f653f28c | ||
|
|
244465d3a6 | ||
|
|
c6236ab70c | ||
|
|
644d5cb411 | ||
|
|
bb0a630416 | ||
|
|
2148ae9287 | ||
|
|
42d242609c | ||
|
|
fd0a52392b | ||
|
|
e64415d59a | ||
|
|
1871e0bdbf | ||
|
|
3ae9a965c2 | ||
|
|
85932e35a7 | ||
|
|
41b07a56cc | ||
|
|
54064c0cb8 | ||
|
|
68284b37fa | ||
|
|
ae5bc6f5d6 | ||
|
|
6dc16c9f54 | ||
|
|
faa9ac4e15 | ||
|
|
d0460849b0 | ||
|
|
bed3c2dd77 | ||
|
|
916ddd17d7 | ||
|
|
accfa7407f | ||
|
|
908db31e48 | ||
|
|
b70f632b26 | ||
|
|
d07a6385ab | ||
|
|
68df612fa1 |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -19,3 +19,4 @@
|
||||
- [ ] _The PR has a short but descriptive title, suitable for a changelog_
|
||||
- [ ] _Tests added / updated (if applicable)_
|
||||
- [ ] _Documentation added / updated (if applicable)_
|
||||
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
|
||||
|
||||
@@ -40,6 +40,8 @@ class AppVersion(BaseModel):
|
||||
|
||||
version: str = Field(description="App version")
|
||||
|
||||
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
|
||||
|
||||
|
||||
class AppDependencyVersions(BaseModel):
|
||||
"""App depencency Versions Response"""
|
||||
|
||||
@@ -751,7 +751,7 @@ async def convert_model(
|
||||
|
||||
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
|
||||
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
|
||||
converted_model = loader.load_model(model_config)
|
||||
converted_model = loader.load_model(model_config, queue_id="default")
|
||||
# write the converted file to the convert path
|
||||
raw_model = converted_model.model
|
||||
assert hasattr(raw_model, "save_pretrained")
|
||||
|
||||
@@ -31,6 +31,7 @@ from invokeai.app.services.events.events_common import (
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadEventBase,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
@@ -53,6 +54,13 @@ class BulkDownloadSubscriptionEvent(BaseModel):
|
||||
bulk_download_id: str
|
||||
|
||||
|
||||
class ModelLoadSubscriptionEvent(BaseModel):
|
||||
"""Event data for subscribing to the socket.io model loading room.
|
||||
This is a pydantic model to ensure the data is in the correct format."""
|
||||
|
||||
queue_id: str
|
||||
|
||||
|
||||
QUEUE_EVENTS = {
|
||||
InvocationStartedEvent,
|
||||
InvocationProgressEvent,
|
||||
@@ -69,8 +77,6 @@ MODEL_EVENTS = {
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
ModelLoadStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
@@ -79,6 +85,11 @@ MODEL_EVENTS = {
|
||||
ModelInstallErrorEvent,
|
||||
}
|
||||
|
||||
MODEL_LOAD_EVENTS = {
|
||||
ModelLoadStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
}
|
||||
|
||||
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
|
||||
|
||||
|
||||
@@ -101,6 +112,7 @@ class SocketIO:
|
||||
|
||||
register_events(QUEUE_EVENTS, self._handle_queue_event)
|
||||
register_events(MODEL_EVENTS, self._handle_model_event)
|
||||
register_events(MODEL_LOAD_EVENTS, self._handle_model_load_event)
|
||||
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
|
||||
|
||||
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
|
||||
@@ -115,9 +127,18 @@ class SocketIO:
|
||||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_sub_model_load(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
|
||||
|
||||
async def _handle_unsub_model_load(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
|
||||
|
||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
|
||||
|
||||
async def _handle_model_load_event(self, event: FastAPIEvent[ModelLoadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
|
||||
|
||||
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -63,12 +63,12 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer, queue_id=context.util.get_queue_id())
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder, queue_id=context.util.get_queue_id())
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(lora.lora, queue_id=context.util.get_queue_id())
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
@@ -137,8 +137,8 @@ class SDXLPromptInvocationBase:
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer, queue_id=context.util.get_queue_id())
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder, queue_id=context.util.get_queue_id())
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
|
||||
@@ -649,7 +649,9 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size],
|
||||
queue_id=self._context.util.get_queue_id(),
|
||||
loader=load_depth_anything,
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
@@ -60,7 +60,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if image_tensor is not None:
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
|
||||
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
|
||||
@@ -124,7 +124,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
assert isinstance(main_model_config, MainConfigBase)
|
||||
if main_model_config.variant is ModelVariantType.Inpaint:
|
||||
mask = blur_tensor
|
||||
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||
vae_info: LoadedModel = context.models.load(self.vae.vae, context.util.get_queue_id())
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
|
||||
@@ -88,7 +88,7 @@ def get_scheduler(
|
||||
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
|
||||
# possible.
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
orig_scheduler_info = context.models.load(scheduler_info, context.util.get_queue_id())
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@@ -435,7 +435,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
controlnet_data: list[ControlNetData] = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
control_model = exit_stack.enter_context(
|
||||
context.models.load(control_info.control_model, context.util.get_queue_id())
|
||||
)
|
||||
assert isinstance(control_model, ControlNetModel)
|
||||
|
||||
control_image_field = control_info.image
|
||||
@@ -492,7 +494,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||
|
||||
for control_info in control_list:
|
||||
model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
model = exit_stack.enter_context(
|
||||
context.models.load(control_info.control_model, context.util.get_queue_id())
|
||||
)
|
||||
ext_manager.add_extension(
|
||||
ControlNetExt(
|
||||
model=model,
|
||||
@@ -545,9 +549,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
|
||||
image_prompts = []
|
||||
for single_ip_adapter in ip_adapters:
|
||||
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
|
||||
with context.models.load(
|
||||
single_ip_adapter.ip_adapter_model, context.util.get_queue_id()
|
||||
) as ip_adapter_model:
|
||||
assert isinstance(ip_adapter_model, IPAdapter)
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
image_encoder_model_info = context.models.load(
|
||||
single_ip_adapter.image_encoder_model, context.util.get_queue_id()
|
||||
)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
@@ -581,7 +589,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
|
||||
ip_adapters, image_prompts, strict=True
|
||||
):
|
||||
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
|
||||
ip_adapter_model = exit_stack.enter_context(
|
||||
context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id())
|
||||
)
|
||||
|
||||
mask_field = single_ip_adapter.mask
|
||||
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
|
||||
@@ -621,7 +631,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||
t2i_adapter_loaded_model = context.models.load(
|
||||
t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id()
|
||||
)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
@@ -926,7 +938,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
@@ -989,13 +1001,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
|
||||
@@ -35,7 +35,9 @@ class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithB
|
||||
model_url = DEPTH_ANYTHING_MODELS[self.model_size]
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
loaded_model = context.models.load_remote_model(model_url, DepthAnythingPipeline.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model
|
||||
)
|
||||
|
||||
with loaded_model as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
|
||||
@@ -29,10 +29,10 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
|
||||
|
||||
loaded_session_det = context.models.load_local_model(
|
||||
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
onnx_det_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
loaded_session_pose = context.models.load_local_model(
|
||||
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
onnx_pose_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
|
||||
with loaded_session_det as session_det, loaded_session_pose as session_pose:
|
||||
|
||||
@@ -183,7 +183,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
@@ -468,7 +468,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# minimize peak memory.
|
||||
|
||||
# First, load the ControlNet models so that we can determine the ControlNet types.
|
||||
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
|
||||
controlnet_models = [
|
||||
context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets
|
||||
]
|
||||
|
||||
# Calculate the controlnet conditioning tensors.
|
||||
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
|
||||
@@ -479,7 +481,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if isinstance(controlnet_model.model, InstantXControlNetFlux):
|
||||
if self.controlnet_vae is None:
|
||||
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
vae_info = context.models.load(self.controlnet_vae.vae, context.util.get_queue_id())
|
||||
controlnet_conds.append(
|
||||
InstantXControlNetExtension.prepare_controlnet_cond(
|
||||
controlnet_image=image,
|
||||
@@ -590,7 +592,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
pos_images.append(pos_image)
|
||||
neg_images.append(neg_image)
|
||||
|
||||
with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
|
||||
with context.models.load(
|
||||
ip_adapter_field.image_encoder_model, context.util.get_queue_id()
|
||||
) as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
|
||||
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
|
||||
@@ -620,7 +624,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
|
||||
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
|
||||
):
|
||||
ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
|
||||
ip_adapter_model = exit_stack.enter_context(
|
||||
context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id())
|
||||
)
|
||||
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
|
||||
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
|
||||
if ip_adapter_field.mask is not None:
|
||||
@@ -649,7 +655,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.transformer.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
@@ -57,8 +57,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
@@ -77,8 +77,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer, context.util.get_queue_id())
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder, context.util.get_queue_id())
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
@@ -118,7 +118,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
@@ -52,7 +52,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
|
||||
image = self._vae_decode(vae_info=vae_info, latents=latents)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
@@ -54,7 +54,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
|
||||
@@ -94,7 +94,9 @@ class GroundingDinoInvocation(BaseInvocation):
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
with context.models.load_remote_model(
|
||||
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
|
||||
source=GROUNDING_DINO_MODEL_IDS[self.model],
|
||||
queue_id=context.util.get_queue_id(),
|
||||
loader=GroundingDinoInvocation._load_grounding_dino,
|
||||
) as detector:
|
||||
assert isinstance(detector, GroundingDinoPipeline)
|
||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||
|
||||
@@ -22,7 +22,9 @@ class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, ControlNetHED_Apache2)
|
||||
|
||||
@@ -111,7 +111,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
|
||||
@@ -36,7 +36,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
@abstractmethod
|
||||
def infill(self, image: Image.Image) -> Image.Image:
|
||||
def infill(self, image: Image.Image, queue_id: str) -> Image.Image:
|
||||
"""Infill the image with the specified method"""
|
||||
pass
|
||||
|
||||
@@ -56,7 +56,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(context.images.get_dto(self.image.image_name))
|
||||
|
||||
# Perform Infill action
|
||||
infilled_image = self.infill(input_image)
|
||||
infilled_image = self.infill(input_image, context.util.get_queue_id())
|
||||
|
||||
# Create ImageDTO for Infilled Image
|
||||
infilled_image_dto = context.images.save(image=infilled_image)
|
||||
@@ -74,7 +74,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, queue_id: str):
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
@@ -93,7 +93,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, queue_id: str):
|
||||
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
|
||||
return output.infilled
|
||||
|
||||
@@ -107,7 +107,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, queue_id: str):
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
width = int(image.width / self.downscale)
|
||||
@@ -131,9 +131,10 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||
class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, queue_id: str):
|
||||
with self._context.models.load_remote_model(
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
queue_id=queue_id,
|
||||
loader=LaMA.load_jit_model,
|
||||
) as model:
|
||||
lama = LaMA(model)
|
||||
@@ -144,7 +145,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
class CV2InfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, queue_id: str):
|
||||
return cv2_inpaint(image)
|
||||
|
||||
|
||||
@@ -166,5 +167,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
|
||||
description="The max threshold for color",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, queue_id: str):
|
||||
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
|
||||
|
||||
@@ -57,7 +57,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
|
||||
@@ -23,7 +23,9 @@ class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
model_url = LineartEdgeDetector.get_model_url(self.coarse)
|
||||
loaded_model = context.models.load_remote_model(model_url, LineartEdgeDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, Generator)
|
||||
|
||||
@@ -20,7 +20,9 @@ class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoar
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
model_url = LineartAnimeEdgeDetector.get_model_url()
|
||||
loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, UnetGenerator)
|
||||
|
||||
@@ -28,7 +28,9 @@ class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(), MLSDDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
MLSDDetector.get_model_url(), context.util.get_queue_id(), MLSDDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, MobileV2_MLSD_Large)
|
||||
|
||||
@@ -20,7 +20,9 @@ class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, NNET)
|
||||
|
||||
@@ -22,7 +22,9 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(), PIDINetDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
PIDINetDetector.get_model_url(), context.util.get_queue_id(), PIDINetDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, PiDiNet)
|
||||
|
||||
@@ -147,7 +147,7 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
|
||||
|
||||
# Load/process the conditioning data.
|
||||
# TODO(ryand): Make CFG optional.
|
||||
|
||||
@@ -44,7 +44,7 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL))
|
||||
|
||||
@@ -86,8 +86,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
|
||||
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
|
||||
assert self.t5_encoder is not None
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
@@ -127,8 +127,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
def _clip_encode(
|
||||
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
|
||||
clip_tokenizer_info = context.models.load(clip_model.tokenizer, context.util.get_queue_id())
|
||||
clip_text_encoder_info = context.models.load(clip_model.text_encoder, context.util.get_queue_id())
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
@@ -193,7 +193,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
self, context: InvocationContext, clip_model: CLIPField
|
||||
) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_model.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
@@ -125,7 +125,9 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
|
||||
with (
|
||||
context.models.load_remote_model(
|
||||
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
|
||||
source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
|
||||
queue_id=context.util.get_queue_id(),
|
||||
loader=SegmentAnythingInvocation._load_sam_model,
|
||||
) as sam_pipeline,
|
||||
):
|
||||
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
|
||||
|
||||
@@ -158,7 +158,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
|
||||
|
||||
def step_callback(step: int, total_steps: int) -> None:
|
||||
context.util.signal_progress(
|
||||
@@ -207,7 +207,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
|
||||
|
||||
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
|
||||
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
|
||||
|
||||
@@ -196,13 +196,13 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
|
||||
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
|
||||
@@ -90,7 +90,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
raise ValueError(msg)
|
||||
|
||||
loadnet = context.models.load_remote_model(
|
||||
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||
source=ESRGAN_MODEL_URLS[self.model_name], queue_id=context.util.get_queue_id()
|
||||
)
|
||||
|
||||
with loadnet as loadnet_model:
|
||||
|
||||
@@ -131,15 +131,17 @@ class EventServiceBase:
|
||||
|
||||
# region Model loading
|
||||
|
||||
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
|
||||
def emit_model_load_started(
|
||||
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
|
||||
) -> None:
|
||||
"""Emitted when a model load is started."""
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, queue_id, submodel_type))
|
||||
|
||||
def emit_model_load_complete(
|
||||
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
|
||||
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
|
||||
) -> None:
|
||||
"""Emitted when a model load is complete."""
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, queue_id, submodel_type))
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -383,12 +383,14 @@ class DownloadErrorEvent(DownloadEventBase):
|
||||
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
"""Base class for events associated with a model"""
|
||||
class ModelLoadEventBase(EventBase):
|
||||
"""Base class for queue events"""
|
||||
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelLoadStartedEvent(ModelEventBase):
|
||||
class ModelLoadStartedEvent(ModelLoadEventBase):
|
||||
"""Event model for model_load_started"""
|
||||
|
||||
__event_name__ = "model_load_started"
|
||||
@@ -397,12 +399,14 @@ class ModelLoadStartedEvent(ModelEventBase):
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
def build(
|
||||
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
) -> "ModelLoadStartedEvent":
|
||||
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelLoadCompleteEvent(ModelEventBase):
|
||||
class ModelLoadCompleteEvent(ModelLoadEventBase):
|
||||
"""Event model for model_load_complete"""
|
||||
|
||||
__event_name__ = "model_load_complete"
|
||||
@@ -411,8 +415,14 @@ class ModelLoadCompleteEvent(ModelEventBase):
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
def build(
|
||||
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
) -> "ModelLoadCompleteEvent":
|
||||
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
"""Base class for model events"""
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
|
||||
@@ -14,7 +14,9 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(
|
||||
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
@@ -29,7 +31,7 @@ class ModelLoadServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Load the model file or directory located at the indicated Path.
|
||||
|
||||
@@ -49,7 +49,9 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(
|
||||
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
@@ -60,7 +62,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
# We don't have an invoker during testing
|
||||
# TODO(psyche): Mock this method on the invoker in the tests
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
|
||||
self._invoker.services.events.emit_model_load_started(model_config, queue_id, submodel_type)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model: LoadedModel = implementation(
|
||||
@@ -70,12 +72,12 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
|
||||
self._invoker.services.events.emit_model_load_complete(model_config, queue_id, submodel_type)
|
||||
|
||||
return loaded_model
|
||||
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
|
||||
@@ -351,7 +351,10 @@ class ModelsInterface(InvocationContextInterface):
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
|
||||
def load(
|
||||
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
||||
self,
|
||||
identifier: Union[str, "ModelIdentifierField"],
|
||||
queue_id: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> LoadedModel:
|
||||
"""Load a model.
|
||||
|
||||
@@ -368,14 +371,19 @@ class ModelsInterface(InvocationContextInterface):
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
return self._services.model_manager.load.load_model(model, queue_id, submodel_type)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type)
|
||||
return self._services.model_manager.load.load_model(model, queue_id, _submodel_type)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
self,
|
||||
name: str,
|
||||
base: BaseModelType,
|
||||
type: ModelType,
|
||||
queue_id: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> LoadedModel:
|
||||
"""Load a model by its attributes.
|
||||
|
||||
@@ -397,7 +405,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
return self._services.model_manager.load.load_model(configs[0], queue_id, submodel_type)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Get a model's config.
|
||||
@@ -472,6 +480,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
def load_local_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
queue_id: str,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
@@ -489,11 +498,14 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
return self._services.model_manager.load.load_model_from_path(
|
||||
model_path=model_path, queue_id=queue_id, loader=loader
|
||||
)
|
||||
|
||||
def load_remote_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
queue_id: str,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
@@ -514,7 +526,9 @@ class ModelsInterface(InvocationContextInterface):
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
return self._services.model_manager.load.load_model_from_path(
|
||||
model_path=model_path, queue_id=queue_id, loader=loader
|
||||
)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
@@ -535,6 +549,14 @@ class UtilInterface(InvocationContextInterface):
|
||||
super().__init__(services, data)
|
||||
self._is_canceled = is_canceled
|
||||
|
||||
def get_queue_id(self) -> str:
|
||||
"""Checks if the current session has been canceled.
|
||||
|
||||
Returns:
|
||||
True if the current session has been canceled, False if not.
|
||||
"""
|
||||
return self._data.queue_item.queue_id
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current session has been canceled.
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def generate_ti_list(
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name_or_key = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.models.load(name_or_key)
|
||||
loaded_model = context.models.load(name_or_key, queue_id=context.util.get_queue_id())
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
@@ -30,7 +30,7 @@ def generate_ti_list(
|
||||
except UnknownModelException:
|
||||
try:
|
||||
loaded_model = context.models.load_by_attrs(
|
||||
name=name_or_key, base=base, type=ModelType.TextualInversion
|
||||
name=name_or_key, base=base, type=ModelType.TextualInversion, queue_id=context.util.get_queue_id()
|
||||
)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
|
||||
@@ -169,17 +169,20 @@ def convert_bundle_to_flux_transformer_checkpoint(
|
||||
|
||||
|
||||
def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
|
||||
path = Path(location)
|
||||
config_path = path / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
with open(config_path) as file:
|
||||
clip_conf = json.load(file)
|
||||
hidden_size = clip_conf.get("hidden_size", -1)
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return None
|
||||
try:
|
||||
path = Path(location)
|
||||
config_path = path / "config.json"
|
||||
if not config_path.exists():
|
||||
return ClipVariantType.L
|
||||
with open(config_path) as file:
|
||||
clip_conf = json.load(file)
|
||||
hidden_size = clip_conf.get("hidden_size", -1)
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return ClipVariantType.L
|
||||
except Exception:
|
||||
return ClipVariantType.L
|
||||
|
||||
@@ -29,7 +29,7 @@ class LoRAExt(ExtensionBase):
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
lora_model = self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()).model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
LoRAPatcher.apply_lora_patch(
|
||||
model=unet,
|
||||
|
||||
@@ -54,7 +54,7 @@ class T2IAdapterExt(ExtensionBase):
|
||||
@callback(ExtensionCallbackType.SETUP)
|
||||
def setup(self, ctx: DenoiseContext):
|
||||
t2i_model: T2IAdapter
|
||||
with self._node_context.models.load(self._model_id) as t2i_model:
|
||||
with self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()) as t2i_model:
|
||||
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
|
||||
|
||||
self._adapter_state = self._run_model(
|
||||
|
||||
@@ -9,6 +9,7 @@ const config: KnipConfig = {
|
||||
'src/services/api/schema.ts',
|
||||
'src/features/nodes/types/v1/**',
|
||||
'src/features/nodes/types/v2/**',
|
||||
'src/features/parameters/types/parameterSchemas.ts',
|
||||
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
|
||||
'src/features/controlLayers/konva/util.ts',
|
||||
// TODO(psyche): restore HRF functionality?
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 895 KiB |
@@ -997,6 +997,7 @@
|
||||
"controlNetControlMode": "Control Mode",
|
||||
"copyImage": "Copy Image",
|
||||
"denoisingStrength": "Denoising Strength",
|
||||
"noRasterLayers": "No Raster Layers",
|
||||
"downloadImage": "Download Image",
|
||||
"general": "General",
|
||||
"guidance": "Guidance",
|
||||
@@ -1412,8 +1413,9 @@
|
||||
"paramDenoisingStrength": {
|
||||
"heading": "Denoising Strength",
|
||||
"paragraphs": [
|
||||
"How much noise is added to the input image.",
|
||||
"0 will result in an identical image, while 1 will result in a completely new image."
|
||||
"Controls how much the generated image varies from the raster layer(s).",
|
||||
"Lower strength stays closer to the combined visible raster layers. Higher strength relies more on the global prompt.",
|
||||
"When there are no raster layers with visible content, this setting is ignored."
|
||||
]
|
||||
},
|
||||
"paramHeight": {
|
||||
@@ -1662,6 +1664,7 @@
|
||||
"mergeDown": "Merge Down",
|
||||
"mergeVisibleOk": "Merged layers",
|
||||
"mergeVisibleError": "Error merging layers",
|
||||
"mergingLayers": "Merging layers",
|
||||
"clearHistory": "Clear History",
|
||||
"bboxOverlay": "Show Bbox Overlay",
|
||||
"resetCanvas": "Reset Canvas",
|
||||
@@ -1774,9 +1777,10 @@
|
||||
"newCanvasSession": "New Canvas Session",
|
||||
"newCanvasSessionDesc": "This will clear the canvas and all settings except for your model selection. Generations will be staged on the canvas.",
|
||||
"replaceCurrent": "Replace Current",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or draw on the canvas to get started.",
|
||||
"controlMode": {
|
||||
"controlMode": "Control Mode",
|
||||
"balanced": "Balanced",
|
||||
"balanced": "Balanced (recommended)",
|
||||
"prompt": "Prompt",
|
||||
"control": "Control",
|
||||
"megaControl": "Mega Control"
|
||||
@@ -1815,6 +1819,9 @@
|
||||
"process": "Process",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel",
|
||||
"advanced": "Advanced",
|
||||
"processingLayerWith": "Processing layer with the {{type}} filter.",
|
||||
"forMoreControl": "For more control, click Advanced below.",
|
||||
"spandrel_filter": {
|
||||
"label": "Image-to-Image Model",
|
||||
"description": "Run an image-to-image model on the selected layer.",
|
||||
@@ -2095,9 +2102,8 @@
|
||||
},
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "What's New in Invoke",
|
||||
"line1": "<ItalicComponent>Select Object</ItalicComponent> tool for precise object selection and editing",
|
||||
"line2": "Expanded Flux support, now with Global Reference Images",
|
||||
"line3": "Improved tooltips and context menus",
|
||||
"line1": "<StrongComponent>Layer Merging</StrongComponent>: New <StrongComponent>Merge Down</StrongComponent> and improved <StrongComponent>Merge Visible</StrongComponent> for all layers, with special handling for Regional Guidance and Control Layers.",
|
||||
"line2": "<StrongComponent>HF Token Support</StrongComponent>: Upload models that require Hugging Face authentication.",
|
||||
"readReleaseNotes": "Read Release Notes",
|
||||
"watchRecentReleaseVideos": "Watch Recent Release Videos",
|
||||
"watchUiUpdatesOverview": "Watch UI Updates Overview"
|
||||
|
||||
@@ -2,7 +2,7 @@ import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
@@ -23,7 +23,7 @@ import type {
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
@@ -163,11 +163,10 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
const state = getState();
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
const defaultControlAdapter = selectDefaultControlAdapter(state);
|
||||
const overrides: Partial<CanvasControlLayerState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
controlAdapter: defaultControlAdapter,
|
||||
controlAdapter: deepClone(initialControlNet),
|
||||
};
|
||||
dispatch(controlLayerAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
|
||||
@@ -4,8 +4,10 @@ import { atom } from 'nanostores';
|
||||
/**
|
||||
* A fallback non-writable atom that always returns `false`, used when a nanostores atom is only conditionally available
|
||||
* in a hook or component.
|
||||
*
|
||||
* @knipignore
|
||||
*/
|
||||
// export const $false: ReadableAtom<boolean> = atom(false);
|
||||
export const $false: ReadableAtom<boolean> = atom(false);
|
||||
/**
|
||||
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
|
||||
* in a hook or component.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { PopoverProps } from '@invoke-ai/ui-library';
|
||||
import commercialLicenseBg from 'public/assets/images/commercial-license-bg.png';
|
||||
import denoisingStrength from 'public/assets/images/denoising-strength.png';
|
||||
|
||||
export type Feature =
|
||||
| 'clipSkip'
|
||||
@@ -125,7 +126,7 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000158838-compositing-settings',
|
||||
},
|
||||
infillMethod: {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000158841-infill-and-scaling',
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000158838-compositing-settings',
|
||||
},
|
||||
scaleBeforeProcessing: {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000158841',
|
||||
@@ -138,6 +139,7 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
|
||||
},
|
||||
paramDenoisingStrength: {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000094998-image-to-image',
|
||||
image: denoisingStrength,
|
||||
},
|
||||
paramHrf: {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000096700-how-can-i-get-larger-images-what-does-upscaling-do-',
|
||||
|
||||
57
invokeai/frontend/web/src/common/components/WavyLine.tsx
Normal file
57
invokeai/frontend/web/src/common/components/WavyLine.tsx
Normal file
@@ -0,0 +1,57 @@
|
||||
type Props = {
|
||||
/**
|
||||
* The amplitude of the wave. 0 is a straight line, higher values create more pronounced waves.
|
||||
*/
|
||||
amplitude: number;
|
||||
/**
|
||||
* The number of segments in the line. More segments create a smoother wave.
|
||||
*/
|
||||
segments?: number;
|
||||
/**
|
||||
* The color of the wave.
|
||||
*/
|
||||
stroke: string;
|
||||
/**
|
||||
* The width of the wave.
|
||||
*/
|
||||
strokeWidth: number;
|
||||
/**
|
||||
* The width of the SVG.
|
||||
*/
|
||||
width: number;
|
||||
/**
|
||||
* The height of the SVG.
|
||||
*/
|
||||
height: number;
|
||||
};
|
||||
|
||||
const WavyLine = ({ amplitude, stroke, strokeWidth, width, height, segments = 5 }: Props) => {
|
||||
// Calculate the path dynamically based on waviness
|
||||
const generatePath = () => {
|
||||
if (amplitude === 0) {
|
||||
// If waviness is 0, return a straight line
|
||||
return `M0,${height / 2} L${width},${height / 2}`;
|
||||
}
|
||||
|
||||
const clampedAmplitude = Math.min(height / 2, amplitude); // Cap amplitude to half the height
|
||||
const segmentWidth = width / segments;
|
||||
let path = `M0,${height / 2}`; // Start in the middle of the left edge
|
||||
|
||||
// Loop through each segment and alternate the y position to create waves
|
||||
for (let i = 1; i <= segments; i++) {
|
||||
const x = i * segmentWidth;
|
||||
const y = height / 2 + (i % 2 === 0 ? clampedAmplitude : -clampedAmplitude);
|
||||
path += ` Q${x - segmentWidth / 2},${y} ${x},${height / 2}`;
|
||||
}
|
||||
|
||||
return path;
|
||||
};
|
||||
|
||||
return (
|
||||
<svg width={width} height={height} viewBox={`0 0 ${width} ${height}`} xmlns="http://www.w3.org/2000/svg">
|
||||
<path d={generatePath()} fill="none" stroke={stroke} strokeWidth={strokeWidth} />
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export default WavyLine;
|
||||
@@ -7,6 +7,8 @@ import { EntityListSelectedEntityActionBar } from 'features/controlLayers/compon
|
||||
import { selectHasEntities } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useRef } from 'react';
|
||||
|
||||
import { ParamDenoisingStrength } from './ParamDenoisingStrength';
|
||||
|
||||
export const CanvasLayersPanelContent = memo(() => {
|
||||
const hasEntities = useAppSelector(selectHasEntities);
|
||||
const layersPanelFocusRef = useRef<HTMLDivElement>(null);
|
||||
@@ -16,6 +18,8 @@ export const CanvasLayersPanelContent = memo(() => {
|
||||
<Flex ref={layersPanelFocusRef} flexDir="column" gap={2} w="full" h="full">
|
||||
<EntityListSelectedEntityActionBar />
|
||||
<Divider py={0} />
|
||||
<ParamDenoisingStrength />
|
||||
<Divider py={0} />
|
||||
{!hasEntities && <CanvasAddEntityButtons />}
|
||||
{hasEntities && <CanvasEntityList />}
|
||||
</Flex>
|
||||
|
||||
@@ -7,7 +7,7 @@ import { CanvasEntityPreviewImage } from 'features/controlLayers/components/comm
|
||||
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
import { ControlLayerBadges } from 'features/controlLayers/components/ControlLayer/ControlLayerBadges';
|
||||
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
|
||||
import { ControlLayerSettings } from 'features/controlLayers/components/ControlLayer/ControlLayerSettings';
|
||||
import { ControlLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
@@ -41,7 +41,7 @@ export const ControlLayer = memo(({ id }: Props) => {
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
</CanvasEntityHeader>
|
||||
<CanvasEntitySettingsWrapper>
|
||||
<ControlLayerControlAdapter />
|
||||
<ControlLayerSettings />
|
||||
</CanvasEntitySettingsWrapper>
|
||||
<IAIDroppable data={dropData} dropLabel={t('controlLayers.replaceLayer')} />
|
||||
</CanvasEntityContainer>
|
||||
|
||||
@@ -6,6 +6,7 @@ import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginE
|
||||
import { Weight } from 'features/controlLayers/components/common/Weight';
|
||||
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
|
||||
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
|
||||
import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
@@ -16,6 +17,7 @@ import {
|
||||
controlLayerModelChanged,
|
||||
controlLayerWeightChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { getFilterForModel } from 'features/controlLayers/store/filters';
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
|
||||
@@ -44,6 +46,7 @@ export const ControlLayerControlAdapter = memo(() => {
|
||||
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
|
||||
const filter = useEntityFilter(entityIdentifier);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const adapter = useEntityAdapterContext('control_layer');
|
||||
|
||||
const onChangeBeginEndStepPct = useCallback(
|
||||
(beginEndStepPct: [number, number]) => {
|
||||
@@ -69,8 +72,43 @@ export const ControlLayerControlAdapter = memo(() => {
|
||||
const onChangeModel = useCallback(
|
||||
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
|
||||
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
|
||||
// When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the
|
||||
// filter config.
|
||||
const isFiltering = adapter.filterer.$isFiltering.get();
|
||||
const isSimple = adapter.filterer.$simple.get();
|
||||
// If we are filtering and _not_ in simple mode, that means the user has clicked Advanced. They want to be in control
|
||||
// of the settings. Bail early without doing anything else.
|
||||
if (isFiltering && !isSimple) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Else, we are in simple mode and will take care of some things for the user.
|
||||
|
||||
// First, check if the newly-selected model has a default filter. It may not - for example, Tile controlnet models
|
||||
// don't have a default filter.
|
||||
const defaultFilterForNewModel = getFilterForModel(modelConfig);
|
||||
|
||||
if (!defaultFilterForNewModel) {
|
||||
// The user has chosen a model that doesn't have a default filter - cancel any in-progress filtering and bail.
|
||||
if (isFiltering) {
|
||||
adapter.filterer.cancel();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// At this point, we know the user has selected a model that has a default filter. We need to either start filtering
|
||||
// with that default filter, or update the existing filter config to match the new model's default filter.
|
||||
const filterConfig = defaultFilterForNewModel.buildDefaults();
|
||||
if (isFiltering) {
|
||||
adapter.filterer.$filterConfig.set(filterConfig);
|
||||
} else {
|
||||
adapter.filterer.start(filterConfig);
|
||||
}
|
||||
// The user may have disabled auto-processing, so we should process the filter manually. This is essentially a
|
||||
// no-op if auto-processing is already enabled, because the process method is debounced.
|
||||
adapter.filterer.process();
|
||||
},
|
||||
[dispatch, entityIdentifier]
|
||||
[adapter.filterer, dispatch, entityIdentifier]
|
||||
);
|
||||
|
||||
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
|
||||
import { ControlLayerSettingsEmptyState } from 'features/controlLayers/components/ControlLayer/ControlLayerSettingsEmptyState';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const ControlLayerSettings = memo(() => {
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const isEmpty = useEntityIsEmpty(entityIdentifier);
|
||||
|
||||
if (isEmpty) {
|
||||
return <ControlLayerSettingsEmptyState />;
|
||||
}
|
||||
|
||||
return <ControlLayerControlAdapter />;
|
||||
});
|
||||
|
||||
ControlLayerSettings.displayName = 'ControlLayerSettings';
|
||||
@@ -0,0 +1,50 @@
|
||||
import { Button, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans } from 'react-i18next';
|
||||
import type { PostUploadAction } from 'services/api/types';
|
||||
|
||||
export const ControlLayerSettingsEmptyState = memo(() => {
|
||||
const entityIdentifier = useEntityIdentifierContext('control_layer');
|
||||
const dispatch = useAppDispatch();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const postUploadAction = useMemo<PostUploadAction>(
|
||||
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
|
||||
[entityIdentifier]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ postUploadAction });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={3} position="relative" w="full" p={4}>
|
||||
<Text textAlign="center" color="base.300">
|
||||
<Trans
|
||||
i18nKey="controlLayers.controlLayerEmptyState"
|
||||
components={{
|
||||
UploadButton: (
|
||||
<Button
|
||||
isDisabled={isBusy}
|
||||
size="sm"
|
||||
variant="link"
|
||||
color="base.300"
|
||||
{...uploadApi.getUploadButtonProps()}
|
||||
/>
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</Text>
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ControlLayerSettingsEmptyState.displayName = 'ControlLayerSettingsEmptyState';
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
MenuList,
|
||||
Spacer,
|
||||
Spinner,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -28,13 +29,10 @@ import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
const FilterContent = memo(
|
||||
const FilterContentAdvanced = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useFocusRegion('canvas', ref, { focusOnMount: true });
|
||||
const config = useStore(adapter.filterer.$filterConfig);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const isProcessing = useStore(adapter.filterer.$isProcessing);
|
||||
const hasImageState = useStore(adapter.filterer.$hasImageState);
|
||||
const autoProcess = useAppSelector(selectAutoProcess);
|
||||
@@ -73,36 +71,8 @@ const FilterContent = memo(
|
||||
adapter.filterer.saveAs('control_layer');
|
||||
}, [adapter.filterer]);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'applyFilter',
|
||||
category: 'canvas',
|
||||
callback: adapter.filterer.apply,
|
||||
options: { enabled: !isProcessing && isCanvasFocused },
|
||||
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'cancelFilter',
|
||||
category: 'canvas',
|
||||
callback: adapter.filterer.cancel,
|
||||
options: { enabled: !isProcessing && isCanvasFocused },
|
||||
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
ref={ref}
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
p={4}
|
||||
flexDir="column"
|
||||
gap={4}
|
||||
w={420}
|
||||
h="auto"
|
||||
shadow="dark-lg"
|
||||
transitionProperty="height"
|
||||
transitionDuration="normal"
|
||||
>
|
||||
<>
|
||||
<Flex w="full" gap={4}>
|
||||
<Heading size="md" color="base.300" userSelect="none">
|
||||
{t('controlLayers.filter.filter')}
|
||||
@@ -169,12 +139,67 @@ const FilterContent = memo(
|
||||
{t('controlLayers.filter.cancel')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
FilterContent.displayName = 'FilterContent';
|
||||
FilterContentAdvanced.displayName = 'FilterContentAdvanced';
|
||||
|
||||
const FilterContentSimple = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const config = useStore(adapter.filterer.$filterConfig);
|
||||
const isProcessing = useStore(adapter.filterer.$isProcessing);
|
||||
const hasImageState = useStore(adapter.filterer.$hasImageState);
|
||||
|
||||
const isValid = useMemo(() => {
|
||||
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
|
||||
}, [config]);
|
||||
|
||||
const onClickAdvanced = useCallback(() => {
|
||||
adapter.filterer.$simple.set(false);
|
||||
}, [adapter.filterer.$simple]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex w="full" gap={4}>
|
||||
<Heading size="md" color="base.300" userSelect="none">
|
||||
{t('controlLayers.filter.filter')}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
</Flex>
|
||||
<Flex flexDir="column" w="full" gap={2} pb={2}>
|
||||
<Text color="base.500" textAlign="center">
|
||||
{t('controlLayers.filter.processingLayerWith', { type: t(`controlLayers.filter.${config.type}.label`) })}
|
||||
</Text>
|
||||
<Text color="base.500" textAlign="center">
|
||||
{t('controlLayers.filter.forMoreControl')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<ButtonGroup isAttached={false} size="sm" w="full">
|
||||
<Button variant="ghost" onClick={onClickAdvanced}>
|
||||
{t('controlLayers.filter.advanced')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
<Button
|
||||
onClick={adapter.filterer.apply}
|
||||
loadingText={t('controlLayers.filter.apply')}
|
||||
variant="ghost"
|
||||
isDisabled={isProcessing || !isValid || !hasImageState}
|
||||
>
|
||||
{t('controlLayers.filter.apply')}
|
||||
</Button>
|
||||
<Button variant="ghost" onClick={adapter.filterer.cancel} loadingText={t('controlLayers.filter.cancel')}>
|
||||
{t('controlLayers.filter.cancel')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
FilterContentSimple.displayName = 'FilterContentSimple';
|
||||
|
||||
export const Filter = () => {
|
||||
const canvasManager = useCanvasManager();
|
||||
@@ -182,8 +207,54 @@ export const Filter = () => {
|
||||
if (!adapter) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <FilterContent adapter={adapter} />;
|
||||
};
|
||||
|
||||
Filter.displayName = 'Filter';
|
||||
|
||||
const FilterContent = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const simplified = useStore(adapter.filterer.$simple);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const isProcessing = useStore(adapter.filterer.$isProcessing);
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useFocusRegion('canvas', ref, { focusOnMount: true });
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'applyFilter',
|
||||
category: 'canvas',
|
||||
callback: adapter.filterer.apply,
|
||||
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
|
||||
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'cancelFilter',
|
||||
category: 'canvas',
|
||||
callback: adapter.filterer.cancel,
|
||||
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
|
||||
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
ref={ref}
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
p={4}
|
||||
flexDir="column"
|
||||
gap={4}
|
||||
w={420}
|
||||
h="auto"
|
||||
shadow="dark-lg"
|
||||
transitionProperty="height"
|
||||
transitionDuration="normal"
|
||||
>
|
||||
{simplified && <FilterContentSimple adapter={adapter} />}
|
||||
{!simplified && <FilterContentAdvanced adapter={adapter} />}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
FilterContent.displayName = 'FilterContent';
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
import {
|
||||
Badge,
|
||||
CompositeNumberInput,
|
||||
CompositeSlider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
useToken,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import WavyLine from 'common/components/WavyLine';
|
||||
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
|
||||
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectIsEnabled = createSelector(selectActiveRasterLayerEntities, (entities) => entities.length > 0);
|
||||
|
||||
export const ParamDenoisingStrength = memo(() => {
|
||||
const img2imgStrength = useAppSelector(selectImg2imgStrength);
|
||||
const dispatch = useAppDispatch();
|
||||
const isEnabled = useAppSelector(selectIsEnabled);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setImg2imgStrength(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const config = useAppSelector(selectImg2imgStrengthConfig);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!isEnabled} p={1} justifyContent="space-between" h={8}>
|
||||
<Flex gap={3} alignItems="center">
|
||||
<InformationalPopover feature="paramDenoisingStrength">
|
||||
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
|
||||
</InformationalPopover>
|
||||
{isEnabled && (
|
||||
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
|
||||
)}
|
||||
</Flex>
|
||||
{isEnabled ? (
|
||||
<>
|
||||
<CompositeSlider
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
min={config.sliderMin}
|
||||
max={config.sliderMax}
|
||||
defaultValue={config.initial}
|
||||
onChange={onChange}
|
||||
value={img2imgStrength}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
min={config.numberInputMin}
|
||||
max={config.numberInputMax}
|
||||
defaultValue={config.initial}
|
||||
onChange={onChange}
|
||||
value={img2imgStrength}
|
||||
variant="outline"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<Flex alignItems="center">
|
||||
<Badge opacity="0.6">
|
||||
{t('common.disabled')} - {t('parameters.noRasterLayers')}
|
||||
</Badge>
|
||||
</Flex>
|
||||
)}
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
ParamDenoisingStrength.displayName = 'ParamDenoisingStrength';
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { useEntityIsLocked } from 'features/controlLayers/hooks/useEntityIsLocked';
|
||||
import {
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
rasterLayerConvertedToInpaintMask,
|
||||
rasterLayerConvertedToRegionalGuidance,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { initialControlNet } from 'features/controlLayers/store/util';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiSwapBold } from 'react-icons/pi';
|
||||
@@ -20,7 +21,6 @@ export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext('raster_layer');
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const isLocked = useEntityIsLocked(entityIdentifier);
|
||||
|
||||
@@ -37,10 +37,10 @@ export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
|
||||
rasterLayerConvertedToControlLayer({
|
||||
entityIdentifier,
|
||||
replace: true,
|
||||
overrides: { controlAdapter: defaultControlAdapter },
|
||||
overrides: { controlAdapter: deepClone(initialControlNet) },
|
||||
})
|
||||
);
|
||||
}, [defaultControlAdapter, dispatch, entityIdentifier]);
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />} isDisabled={isBusy || isLocked}>
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import {
|
||||
rasterLayerConvertedToControlLayer,
|
||||
rasterLayerConvertedToInpaintMask,
|
||||
rasterLayerConvertedToRegionalGuidance,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { initialControlNet } from 'features/controlLayers/store/util';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCopyBold } from 'react-icons/pi';
|
||||
@@ -20,7 +21,6 @@ export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext('raster_layer');
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
const copyToInpaintMask = useCallback(() => {
|
||||
@@ -35,10 +35,10 @@ export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
|
||||
dispatch(
|
||||
rasterLayerConvertedToControlLayer({
|
||||
entityIdentifier,
|
||||
overrides: { controlAdapter: defaultControlAdapter },
|
||||
overrides: { controlAdapter: deepClone(initialControlNet) },
|
||||
})
|
||||
);
|
||||
}, [defaultControlAdapter, dispatch, entityIdentifier]);
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />} isDisabled={isBusy}>
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useBoolean } from 'common/hooks/useBoolean';
|
||||
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
|
||||
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
|
||||
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
|
||||
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
|
||||
@@ -78,7 +79,7 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props
|
||||
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
|
||||
<CanvasEntityAddOfTypeButton type={type} />
|
||||
</Flex>
|
||||
<Collapse in={collapse.isTrue}>
|
||||
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>
|
||||
<Flex flexDir="column" gap={2} pt={2}>
|
||||
{children}
|
||||
</Flex>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Box, chakra, Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, chakra, Flex, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { rgbColorToString } from 'common/util/colorCodeTransformers';
|
||||
@@ -86,13 +86,63 @@ export const CanvasEntityPreviewImage = memo(() => {
|
||||
|
||||
useEffect(updatePreview, [updatePreview, canvasCache, nodeRect, pixelRect]);
|
||||
|
||||
return (
|
||||
<Tooltip label={<TooltipContent canvasRef={canvasRef} />} p={2} closeOnScroll>
|
||||
<Flex
|
||||
position="relative"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
w={CONTAINER_WIDTH_PX}
|
||||
h={CONTAINER_WIDTH_PX}
|
||||
borderRadius="sm"
|
||||
borderWidth={1}
|
||||
bg="base.900"
|
||||
flexShrink={0}
|
||||
>
|
||||
<Box
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL}
|
||||
bgSize="5px"
|
||||
/>
|
||||
<ChakraCanvas position="relative" ref={canvasRef} objectFit="contain" maxW="full" maxH="full" />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityPreviewImage.displayName = 'CanvasEntityPreviewImage';
|
||||
|
||||
const TooltipContent = ({ canvasRef }: { canvasRef: React.RefObject<HTMLCanvasElement> }) => {
|
||||
const canvasRef2 = useRef<HTMLCanvasElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!canvasRef2.current || !canvasRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const ctx = canvasRef2.current.getContext('2d');
|
||||
|
||||
if (!ctx) {
|
||||
return;
|
||||
}
|
||||
|
||||
canvasRef2.current.width = canvasRef.current.width;
|
||||
canvasRef2.current.height = canvasRef.current.height;
|
||||
ctx.clearRect(0, 0, canvasRef2.current.width, canvasRef2.current.height);
|
||||
ctx.drawImage(canvasRef.current, 0, 0);
|
||||
}, [canvasRef]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
w={CONTAINER_WIDTH_PX}
|
||||
h={CONTAINER_WIDTH_PX}
|
||||
w={150}
|
||||
h={150}
|
||||
borderRadius="sm"
|
||||
borderWidth={1}
|
||||
bg="base.900"
|
||||
@@ -105,11 +155,9 @@ export const CanvasEntityPreviewImage = memo(() => {
|
||||
bottom={0}
|
||||
left={0}
|
||||
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL}
|
||||
bgSize="5px"
|
||||
bgSize="8px"
|
||||
/>
|
||||
<ChakraCanvas position="relative" ref={canvasRef} objectFit="contain" maxW="full" maxH="full" />
|
||||
<ChakraCanvas position="relative" ref={canvasRef2} objectFit="contain" maxW="full" maxH="full" />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityPreviewImage.displayName = 'CanvasEntityPreviewImage';
|
||||
};
|
||||
|
||||
@@ -4,9 +4,10 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
|
||||
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import type { CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||
import type { CanvasEntityIdentifier, CanvasRenderableEntityType } from 'features/controlLayers/store/types';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useMemo, useSyncExternalStore } from 'react';
|
||||
import { createContext, memo, useContext, useMemo, useSyncExternalStore } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const EntityAdapterContext = createContext<
|
||||
@@ -95,6 +96,17 @@ export const RegionalGuidanceAdapterGate = memo(({ children }: PropsWithChildren
|
||||
return <EntityAdapterContext.Provider value={adapter}>{children}</EntityAdapterContext.Provider>;
|
||||
});
|
||||
|
||||
export const useEntityAdapterContext = <T extends CanvasRenderableEntityType | undefined = CanvasRenderableEntityType>(
|
||||
type?: T
|
||||
): CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T> => {
|
||||
const adapter = useContext(EntityAdapterContext);
|
||||
assert(adapter, 'useEntityIdentifier must be used within a EntityIdentifierProvider');
|
||||
if (type) {
|
||||
assert(adapter.entityIdentifier.type === type, 'useEntityIdentifier must be used with the correct type');
|
||||
}
|
||||
return adapter as CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T>;
|
||||
};
|
||||
|
||||
RegionalGuidanceAdapterGate.displayName = 'RegionalGuidanceAdapterGate';
|
||||
|
||||
export const useEntityAdapterSafe = (
|
||||
|
||||
@@ -49,6 +49,7 @@ import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'ser
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/** @knipignore */
|
||||
export const selectDefaultControlAdapter = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectBase,
|
||||
@@ -92,11 +93,10 @@ export const selectDefaultIPAdapter = createSelector(
|
||||
|
||||
export const useAddControlLayer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
const func = useCallback(() => {
|
||||
const overrides = { controlAdapter: defaultControlAdapter };
|
||||
const overrides = { controlAdapter: deepClone(initialControlNet) };
|
||||
dispatch(controlLayerAdded({ isSelected: true, overrides }));
|
||||
}, [defaultControlAdapter, dispatch]);
|
||||
}, [dispatch]);
|
||||
|
||||
return func;
|
||||
};
|
||||
|
||||
@@ -4,7 +4,7 @@ import type { SerializableObject } from 'common/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
@@ -25,7 +25,7 @@ import type {
|
||||
Rect,
|
||||
RegionalGuidanceReferenceImageState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -229,13 +229,12 @@ export const useNewRasterLayerFromBbox = () => {
|
||||
export const useNewControlLayerFromBbox = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
|
||||
const arg = useMemo<UseSaveCanvasArg>(() => {
|
||||
const onSave = (imageDTO: ImageDTO, rect: Rect) => {
|
||||
const overrides: Partial<CanvasControlLayerState> = {
|
||||
objects: [imageDTOToImageObject(imageDTO)],
|
||||
controlAdapter: deepClone(defaultControlAdapter),
|
||||
controlAdapter: deepClone(initialControlNet),
|
||||
position: { x: rect.x, y: rect.y },
|
||||
};
|
||||
dispatch(controlLayerAdded({ overrides, isSelected: true }));
|
||||
@@ -248,7 +247,7 @@ export const useNewControlLayerFromBbox = () => {
|
||||
toastOk: t('controlLayers.newControlLayerOk'),
|
||||
toastError: t('controlLayers.newControlLayerError'),
|
||||
};
|
||||
}, [defaultControlAdapter, dispatch, t]);
|
||||
}, [dispatch, t]);
|
||||
const func = useSaveCanvas(arg);
|
||||
return func;
|
||||
};
|
||||
|
||||
@@ -329,6 +329,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
entityIdentifiers: T[],
|
||||
deleteMergedEntities: boolean
|
||||
): Promise<ImageDTO | null> => {
|
||||
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergingLayers'), withCount: false });
|
||||
if (entityIdentifiers.length <= 1) {
|
||||
this.log.warn({ entityIdentifiers }, 'Cannot merge less than 2 entities');
|
||||
return null;
|
||||
@@ -351,7 +352,12 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
|
||||
if (result.isErr()) {
|
||||
this.log.error({ error: serializeError(result.error) }, 'Failed to merge selected entities');
|
||||
toast({ title: t('controlLayers.mergeVisibleError'), status: 'error' });
|
||||
toast({
|
||||
id: 'MERGE_LAYERS_TOAST',
|
||||
title: t('controlLayers.mergeVisibleError'),
|
||||
status: 'error',
|
||||
withCount: false,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -383,7 +389,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
assert<Equals<typeof type, never>>(false, 'Unsupported type for merge');
|
||||
}
|
||||
|
||||
toast({ title: t('controlLayers.mergeVisibleOk') });
|
||||
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergeVisibleOk'), status: 'success', withCount: false });
|
||||
|
||||
return result.value;
|
||||
};
|
||||
|
||||
@@ -83,6 +83,13 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
* Whether the module has an image state. This is a computed value based on $imageState.
|
||||
*/
|
||||
$hasImageState = computed(this.$imageState, (imageState) => imageState !== null);
|
||||
|
||||
/**
|
||||
* Whether the filter is in simple mode. In simple mode, the filter is started with a default filter config and the
|
||||
* user is not presented with filter settings.
|
||||
*/
|
||||
$simple = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* The filtered image object module, if it exists.
|
||||
*/
|
||||
@@ -147,7 +154,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
|
||||
/**
|
||||
* Starts the filter module.
|
||||
* @param config The filter config to start with. If omitted, the default filter config is used.
|
||||
* @param config The filter config to use. If omitted, the default filter config is used.
|
||||
*/
|
||||
start = (config?: FilterConfig) => {
|
||||
const filteringAdapter = this.manager.stateApi.$filteringAdapter.get();
|
||||
@@ -174,12 +181,14 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
// If a config is provided, use it
|
||||
this.$filterConfig.set(config);
|
||||
this.$initialFilterConfig.set(config);
|
||||
this.$simple.set(true);
|
||||
} else {
|
||||
this.$filterConfig.set(this.createInitialFilterConfig());
|
||||
const initialConfig = this.createInitialFilterConfig();
|
||||
this.$filterConfig.set(initialConfig);
|
||||
this.$initialFilterConfig.set(initialConfig);
|
||||
this.$simple.set(false);
|
||||
}
|
||||
|
||||
this.$initialFilterConfig.set(this.$filterConfig.get());
|
||||
|
||||
this.subscribe();
|
||||
|
||||
this.manager.stateApi.$filteringAdapter.set(this.parent);
|
||||
@@ -198,7 +207,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
);
|
||||
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
|
||||
// This always returns a filter
|
||||
const filter = getFilterForModel(modelConfig);
|
||||
const filter = getFilterForModel(modelConfig) ?? IMAGE_FILTERS.canny_edge_detection;
|
||||
return filter.buildDefaults();
|
||||
} else {
|
||||
// Otherwise, used the default filter
|
||||
@@ -404,7 +413,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
this.imageModule.destroy();
|
||||
this.imageModule = null;
|
||||
}
|
||||
const initialFilterConfig = this.$initialFilterConfig.get() ?? this.createInitialFilterConfig();
|
||||
const initialFilterConfig = deepClone(this.$initialFilterConfig.get() ?? this.createInitialFilterConfig());
|
||||
this.$filterConfig.set(initialFilterConfig);
|
||||
this.$imageState.set(null);
|
||||
this.$lastProcessedHash.set('');
|
||||
|
||||
@@ -456,14 +456,14 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
|
||||
*/
|
||||
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
|
||||
if (!modelConfig) {
|
||||
// No model, use the default filter
|
||||
return IMAGE_FILTERS.canny_edge_detection;
|
||||
// No model
|
||||
return null;
|
||||
}
|
||||
|
||||
const preprocessor = modelConfig?.default_settings?.preprocessor;
|
||||
if (!preprocessor) {
|
||||
// No preprocessor, use the default filter
|
||||
return IMAGE_FILTERS.canny_edge_detection;
|
||||
// No preprocessor
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isFilterType(preprocessor)) {
|
||||
@@ -473,8 +473,8 @@ export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapte
|
||||
|
||||
const filterName = PROCESSOR_TO_FILTER_MAP[preprocessor];
|
||||
if (!filterName) {
|
||||
// No filter found, use the default filter
|
||||
return IMAGE_FILTERS.canny_edge_detection;
|
||||
// No filter found
|
||||
return null;
|
||||
}
|
||||
|
||||
// Found a filter, use it
|
||||
|
||||
@@ -78,8 +78,8 @@ export const initialT2IAdapter: T2IAdapterConfig = {
|
||||
export const initialControlNet: ControlNetConfig = {
|
||||
type: 'controlnet',
|
||||
model: null,
|
||||
weight: 1,
|
||||
beginEndStepPct: [0, 1],
|
||||
weight: 0.75,
|
||||
beginEndStepPct: [0, 0.75],
|
||||
controlMode: 'balanced',
|
||||
};
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ export const DeleteImageButton = memo((props: DeleteImageButtonProps) => {
|
||||
aria-label={labelMessage}
|
||||
isDisabled={isDisabled || !isConnected}
|
||||
colorScheme="error"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -22,7 +22,7 @@ import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/us
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiArrowsOutBold, PiQuestion, PiSwapBold } from 'react-icons/pi';
|
||||
import { PiArrowsLeftRightBold, PiArrowsOutBold, PiQuestion } from 'react-icons/pi';
|
||||
|
||||
export const CompareToolbar = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
@@ -60,14 +60,16 @@ export const CompareToolbar = memo(() => {
|
||||
useRegisteredHotkeys({ id: 'nextComparisonMode', category: 'viewer', callback: nextMode, dependencies: [nextMode] });
|
||||
|
||||
return (
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex w="full" px={2} gap={2} bg="base.750" borderTopRadius="base" h={12}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<Flex marginInlineEnd="auto" alignItems="center">
|
||||
<IconButton
|
||||
icon={<PiSwapBold />}
|
||||
icon={<PiArrowsLeftRightBold />}
|
||||
aria-label={`${t('gallery.swapImages')} (C)`}
|
||||
tooltip={`${t('gallery.swapImages')} (C)`}
|
||||
onClick={swapImages}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
/>
|
||||
{comparisonMode !== 'side-by-side' && (
|
||||
<IconButton
|
||||
@@ -75,14 +77,15 @@ export const CompareToolbar = memo(() => {
|
||||
tooltip={t('gallery.stretchToFit')}
|
||||
onClick={toggleComparisonFit}
|
||||
colorScheme={comparisonFit === 'fill' ? 'invokeBlue' : 'base'}
|
||||
variant="outline"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
icon={<PiArrowsOutBold />}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={4} justifyContent="center">
|
||||
<ButtonGroup variant="outline">
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<ButtonGroup variant="outline" alignItems="center">
|
||||
<Button
|
||||
flexShrink={0}
|
||||
onClick={setComparisonModeSlider}
|
||||
@@ -110,11 +113,13 @@ export const CompareToolbar = memo(() => {
|
||||
<Flex gap={2} marginInlineStart="auto" alignItems="center">
|
||||
<Tooltip label={<CompareHelp />}>
|
||||
<Flex alignItems="center">
|
||||
<Icon boxSize={6} color="base.500" as={PiQuestion} lineHeight={0} />
|
||||
<Icon boxSize={6} color="base.300" as={PiQuestion} lineHeight={0} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
<Button
|
||||
variant="ghost"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
px={2}
|
||||
aria-label={`${t('gallery.exitCompare')} (Esc)`}
|
||||
tooltip={`${t('gallery.exitCompare')} (Esc)`}
|
||||
onClick={exitCompare}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
|
||||
import { Divider, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -46,73 +46,81 @@ const CurrentImageButtonsContent = memo(({ imageDTO }: { imageDTO: ImageDTO }) =
|
||||
|
||||
return (
|
||||
<>
|
||||
<ButtonGroup>
|
||||
<Menu isLazy>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
aria-label={t('parameters.imageActions')}
|
||||
tooltip={t('parameters.imageActions')}
|
||||
isDisabled={!imageDTO}
|
||||
icon={<PiDotsThreeOutlineFill />}
|
||||
/>
|
||||
<MenuList>{imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}</MenuList>
|
||||
</Menu>
|
||||
</ButtonGroup>
|
||||
<Menu isLazy>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
aria-label={t('parameters.imageActions')}
|
||||
tooltip={t('parameters.imageActions')}
|
||||
isDisabled={!imageDTO}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
icon={<PiDotsThreeOutlineFill />}
|
||||
/>
|
||||
<MenuList>{imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}</MenuList>
|
||||
</Menu>
|
||||
|
||||
<ButtonGroup>
|
||||
<IconButton
|
||||
icon={<PiFlowArrowBold />}
|
||||
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
||||
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
||||
isDisabled={!imageActions.hasWorkflow || !hasTemplates}
|
||||
onClick={imageActions.loadWorkflow}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiArrowsCounterClockwiseBold />}
|
||||
tooltip={`${t('parameters.remixImage')} (R)`}
|
||||
aria-label={`${t('parameters.remixImage')} (R)`}
|
||||
isDisabled={!imageActions.hasMetadata}
|
||||
onClick={imageActions.remix}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiQuotesBold />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
isDisabled={!imageActions.hasPrompts}
|
||||
onClick={imageActions.recallPrompts}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiPlantBold />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
isDisabled={!imageActions.hasSeed}
|
||||
onClick={imageActions.recallSeed}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiRulerBold />}
|
||||
tooltip={`${t('parameters.useSize')} (D)`}
|
||||
aria-label={`${t('parameters.useSize')} (D)`}
|
||||
onClick={imageActions.recallSize}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiAsteriskBold />}
|
||||
tooltip={`${t('parameters.useAll')} (A)`}
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
isDisabled={!imageActions.hasMetadata}
|
||||
onClick={imageActions.recallAll}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<Divider orientation="vertical" h={8} mx={2} />
|
||||
|
||||
{isUpscalingEnabled && (
|
||||
<ButtonGroup>
|
||||
<PostProcessingPopover imageDTO={imageDTO} />
|
||||
</ButtonGroup>
|
||||
)}
|
||||
<IconButton
|
||||
icon={<PiFlowArrowBold />}
|
||||
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
||||
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
||||
isDisabled={!imageActions.hasWorkflow || !hasTemplates}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.loadWorkflow}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiArrowsCounterClockwiseBold />}
|
||||
tooltip={`${t('parameters.remixImage')} (R)`}
|
||||
aria-label={`${t('parameters.remixImage')} (R)`}
|
||||
isDisabled={!imageActions.hasMetadata}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.remix}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiQuotesBold />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
isDisabled={!imageActions.hasPrompts}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.recallPrompts}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiPlantBold />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
isDisabled={!imageActions.hasSeed}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.recallSeed}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiRulerBold />}
|
||||
tooltip={`${t('parameters.useSize')} (D)`}
|
||||
aria-label={`${t('parameters.useSize')} (D)`}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.recallSize}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiAsteriskBold />}
|
||||
tooltip={`${t('parameters.useAll')} (A)`}
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
isDisabled={!imageActions.hasMetadata}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.recallAll}
|
||||
/>
|
||||
|
||||
<ButtonGroup>
|
||||
<DeleteImageButton onClick={imageActions.delete} />
|
||||
</ButtonGroup>
|
||||
{isUpscalingEnabled && <PostProcessingPopover imageDTO={imageDTO} />}
|
||||
|
||||
<Divider orientation="vertical" h={8} mx={2} />
|
||||
|
||||
<DeleteImageButton onClick={imageActions.delete} />
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -37,7 +37,6 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
|
||||
ref={ref}
|
||||
tabIndex={-1}
|
||||
layerStyle="first"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
position="absolute"
|
||||
flexDirection="column"
|
||||
@@ -51,7 +50,7 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
|
||||
>
|
||||
{hasImageToCompare && <CompareToolbar />}
|
||||
{!hasImageToCompare && <ViewerToolbar closeButton={closeButton} />}
|
||||
<Box ref={containerRef} w="full" h="full">
|
||||
<Box ref={containerRef} w="full" h="full" p={2}>
|
||||
{!hasImageToCompare && <CurrentImagePreview />}
|
||||
{hasImageToCompare && <ImageComparison containerDims={containerDims} />}
|
||||
</Box>
|
||||
@@ -84,7 +83,8 @@ const ImageViewerCloseButton = memo(() => {
|
||||
tooltip={t('gallery.closeViewer')}
|
||||
aria-label={t('gallery.closeViewer')}
|
||||
icon={<PiXBold />}
|
||||
variant="ghost"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageViewer.close}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -38,7 +38,8 @@ export const ToggleMetadataViewerButton = memo(() => {
|
||||
aria-label={`${t('parameters.info')} (I)`}
|
||||
onClick={toggleMetadataViewer}
|
||||
isDisabled={!imageDTO}
|
||||
variant="outline"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
colorScheme={shouldShowImageDetails ? 'invokeBlue' : 'base'}
|
||||
data-testid="toggle-show-metadata-button"
|
||||
/>
|
||||
|
||||
@@ -21,7 +21,8 @@ export const ToggleProgressButton = memo(() => {
|
||||
tooltip={t('settings.displayInProgress')}
|
||||
icon={<PiHourglassHighBold />}
|
||||
onClick={onClick}
|
||||
variant="outline"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
colorScheme={shouldShowProgressInViewer ? 'invokeBlue' : 'base'}
|
||||
data-testid="toggle-show-progress-button"
|
||||
/>
|
||||
|
||||
@@ -12,18 +12,18 @@ type Props = {
|
||||
|
||||
export const ViewerToolbar = memo(({ closeButton }: Props) => {
|
||||
return (
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex w="full" px={2} gap={2} bg="base.750" borderTopRadius="base" h={12}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<Flex marginInlineEnd="auto" alignItems="center">
|
||||
<ToggleProgressButton />
|
||||
<ToggleMetadataViewerButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={2} justifyContent="center">
|
||||
<Flex flex={1} justifyContent="center" alignItems="center">
|
||||
<CurrentImageButtons />
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto">
|
||||
<Flex marginInlineStart="auto" alignItems="center">
|
||||
{closeButton}
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const marks = [0, 0.5, 1];
|
||||
|
||||
export const ParamDenoisingStrength = memo(() => {
|
||||
const img2imgStrength = useAppSelector(selectImg2imgStrength);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setImg2imgStrength(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const config = useAppSelector(selectImg2imgStrengthConfig);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="paramDenoisingStrength">
|
||||
<FormLabel>{`${t('parameters.denoisingStrength')}`}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<CompositeSlider
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
min={config.sliderMin}
|
||||
max={config.sliderMax}
|
||||
defaultValue={config.initial}
|
||||
onChange={onChange}
|
||||
value={img2imgStrength}
|
||||
marks={marks}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
min={config.numberInputMin}
|
||||
max={config.numberInputMax}
|
||||
defaultValue={config.initial}
|
||||
onChange={onChange}
|
||||
value={img2imgStrength}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
ParamDenoisingStrength.displayName = 'ParamDenoisingStrength';
|
||||
@@ -47,6 +47,8 @@ export const PostProcessingPopover = memo((props: Props) => {
|
||||
onClick={onOpen}
|
||||
icon={<PiFrameCornersBold />}
|
||||
aria-label={t('parameters.postProcessing')}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
/>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
|
||||
@@ -15,88 +15,87 @@ import { z } from 'zod';
|
||||
* simply be the zod schema's safeParse function
|
||||
*/
|
||||
|
||||
/**
|
||||
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.
|
||||
* @param schema The zod schema to create a type guard from.
|
||||
* @returns A type guard function for the schema.
|
||||
*/
|
||||
const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
|
||||
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper to create a zod schema and a type guard from it.
|
||||
* @param schema The zod schema to create a type guard from.
|
||||
* @returns A tuple containing the zod schema and the type guard function.
|
||||
*/
|
||||
const buildParameter = <T extends z.ZodTypeAny>(schema: T) => [schema, buildTypeGuard(schema)] as const;
|
||||
|
||||
// #region Positive prompt
|
||||
export const zParameterPositivePrompt = z.string();
|
||||
export const [zParameterPositivePrompt, isParameterPositivePrompt] = buildParameter(z.string());
|
||||
export type ParameterPositivePrompt = z.infer<typeof zParameterPositivePrompt>;
|
||||
export const isParameterPositivePrompt = (val: unknown): val is ParameterPositivePrompt =>
|
||||
zParameterPositivePrompt.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Negative prompt
|
||||
export const zParameterNegativePrompt = z.string();
|
||||
export const [zParameterNegativePrompt, isParameterNegativePrompt] = buildParameter(z.string());
|
||||
export type ParameterNegativePrompt = z.infer<typeof zParameterNegativePrompt>;
|
||||
export const isParameterNegativePrompt = (val: unknown): val is ParameterNegativePrompt =>
|
||||
zParameterNegativePrompt.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Positive style prompt (SDXL)
|
||||
const zParameterPositiveStylePromptSDXL = z.string();
|
||||
export const [zParameterPositiveStylePromptSDXL, isParameterPositiveStylePromptSDXL] = buildParameter(z.string());
|
||||
export type ParameterPositiveStylePromptSDXL = z.infer<typeof zParameterPositiveStylePromptSDXL>;
|
||||
export const isParameterPositiveStylePromptSDXL = (val: unknown): val is ParameterPositiveStylePromptSDXL =>
|
||||
zParameterPositiveStylePromptSDXL.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Positive style prompt (SDXL)
|
||||
const zParameterNegativeStylePromptSDXL = z.string();
|
||||
export const [zParameterNegativeStylePromptSDXL, isParameterNegativeStylePromptSDXL] = buildParameter(z.string());
|
||||
export type ParameterNegativeStylePromptSDXL = z.infer<typeof zParameterNegativeStylePromptSDXL>;
|
||||
export const isParameterNegativeStylePromptSDXL = (val: unknown): val is ParameterNegativeStylePromptSDXL =>
|
||||
zParameterNegativeStylePromptSDXL.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Steps
|
||||
const zParameterSteps = z.number().int().min(1);
|
||||
export const [zParameterSteps, isParameterSteps] = buildParameter(z.number().int().min(1));
|
||||
export type ParameterSteps = z.infer<typeof zParameterSteps>;
|
||||
export const isParameterSteps = (val: unknown): val is ParameterSteps => zParameterSteps.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region CFG scale parameter
|
||||
const zParameterCFGScale = z.number().min(1);
|
||||
export const [zParameterCFGScale, isParameterCFGScale] = buildParameter(z.number().min(1));
|
||||
export type ParameterCFGScale = z.infer<typeof zParameterCFGScale>;
|
||||
export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
|
||||
zParameterCFGScale.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Guidance parameter
|
||||
const zParameterGuidance = z.number().min(1);
|
||||
export const [zParameterGuidance, isParameterGuidance] = buildParameter(z.number().min(1));
|
||||
export type ParameterGuidance = z.infer<typeof zParameterGuidance>;
|
||||
export const isParameterGuidance = (val: unknown): val is ParameterGuidance =>
|
||||
zParameterGuidance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region CFG Rescale Multiplier
|
||||
const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
|
||||
export const [zParameterCFGRescaleMultiplier, isParameterCFGRescaleMultiplier] = buildParameter(
|
||||
z.number().gte(0).lt(1)
|
||||
);
|
||||
export type ParameterCFGRescaleMultiplier = z.infer<typeof zParameterCFGRescaleMultiplier>;
|
||||
export const isParameterCFGRescaleMultiplier = (val: unknown): val is ParameterCFGRescaleMultiplier =>
|
||||
zParameterCFGRescaleMultiplier.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Scheduler
|
||||
const zParameterScheduler = zSchedulerField;
|
||||
export const [zParameterScheduler, isParameterScheduler] = buildParameter(zSchedulerField);
|
||||
export type ParameterScheduler = z.infer<typeof zParameterScheduler>;
|
||||
export const isParameterScheduler = (val: unknown): val is ParameterScheduler =>
|
||||
zParameterScheduler.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region seed
|
||||
const zParameterSeed = z.number().int().min(0).max(NUMPY_RAND_MAX);
|
||||
export const [zParameterSeed, isParameterSeed] = buildParameter(z.number().int().min(0).max(NUMPY_RAND_MAX));
|
||||
export type ParameterSeed = z.infer<typeof zParameterSeed>;
|
||||
export const isParameterSeed = (val: unknown): val is ParameterSeed => zParameterSeed.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Width
|
||||
export const zParameterImageDimension = z
|
||||
.number()
|
||||
.min(64)
|
||||
.transform((val) => roundToMultiple(val, 8));
|
||||
export const [zParameterImageDimension, isParameterImageDimension] = buildParameter(
|
||||
z
|
||||
.number()
|
||||
.min(64)
|
||||
.transform((val) => roundToMultiple(val, 8))
|
||||
);
|
||||
export type ParameterWidth = z.infer<typeof zParameterImageDimension>;
|
||||
export const isParameterWidth = (val: unknown): val is ParameterWidth =>
|
||||
zParameterImageDimension.safeParse(val).success;
|
||||
// #endregion
|
||||
export const isParameterWidth = isParameterImageDimension;
|
||||
|
||||
// #region Height
|
||||
export type ParameterHeight = z.infer<typeof zParameterImageDimension>;
|
||||
export const isParameterHeight = (val: unknown): val is ParameterHeight =>
|
||||
zParameterImageDimension.safeParse(val).success;
|
||||
export const isParameterHeight = isParameterImageDimension;
|
||||
// #endregion
|
||||
|
||||
// #region Model
|
||||
@@ -135,70 +134,50 @@ export type ParameterSpandrelImageToImageModel = z.infer<typeof zParameterSpandr
|
||||
// #endregion
|
||||
|
||||
// #region Strength (l2l strength)
|
||||
const zParameterStrength = z.number().min(0).max(1);
|
||||
export const [zParameterStrength, isParameterStrength] = buildParameter(z.number().min(0).max(1));
|
||||
export type ParameterStrength = z.infer<typeof zParameterStrength>;
|
||||
export const isParameterStrength = (val: unknown): val is ParameterStrength =>
|
||||
zParameterStrength.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SeamlessX
|
||||
const zParameterSeamlessX = z.boolean();
|
||||
export const [zParameterSeamlessX, isParameterSeamlessX] = buildParameter(z.boolean());
|
||||
export type ParameterSeamlessX = z.infer<typeof zParameterSeamlessX>;
|
||||
export const isParameterSeamlessX = (val: unknown): val is ParameterSeamlessX =>
|
||||
zParameterSeamlessX.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SeamlessY
|
||||
const zParameterSeamlessY = z.boolean();
|
||||
export const [zParameterSeamlessY, isParameterSeamlessY] = buildParameter(z.boolean());
|
||||
export type ParameterSeamlessY = z.infer<typeof zParameterSeamlessY>;
|
||||
export const isParameterSeamlessY = (val: unknown): val is ParameterSeamlessY =>
|
||||
zParameterSeamlessY.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Precision
|
||||
const zParameterPrecision = z.enum(['fp16', 'fp32']);
|
||||
export const [zParameterPrecision, isParameterPrecision] = buildParameter(z.enum(['fp16', 'fp32']));
|
||||
export type ParameterPrecision = z.infer<typeof zParameterPrecision>;
|
||||
export const isParameterPrecision = (val: unknown): val is ParameterPrecision =>
|
||||
zParameterPrecision.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region HRF Method
|
||||
const zParameterHRFMethod = z.enum(['ESRGAN', 'bilinear']);
|
||||
export const [zParameterHRFMethod, isParameterHRFMethod] = buildParameter(z.enum(['ESRGAN', 'bilinear']));
|
||||
export type ParameterHRFMethod = z.infer<typeof zParameterHRFMethod>;
|
||||
export const isParameterHRFMethod = (val: unknown): val is ParameterHRFMethod =>
|
||||
zParameterHRFMethod.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region HRF Enabled
|
||||
const zParameterHRFEnabled = z.boolean();
|
||||
export const [zParameterHRFEnabled, isParameterHRFEnabled] = buildParameter(z.boolean());
|
||||
export type ParameterHRFEnabled = z.infer<typeof zParameterHRFEnabled>;
|
||||
export const isParameterHRFEnabled = (val: unknown): val is boolean =>
|
||||
zParameterHRFEnabled.safeParse(val).success && val !== null && val !== undefined;
|
||||
// #endregion
|
||||
|
||||
// #region SDXL Refiner Positive Aesthetic Score
|
||||
const zParameterSDXLRefinerPositiveAestheticScore = z.number().min(1).max(10);
|
||||
export const [zParameterSDXLRefinerPositiveAestheticScore, isParameterSDXLRefinerPositiveAestheticScore] =
|
||||
buildParameter(z.number().min(1).max(10));
|
||||
export type ParameterSDXLRefinerPositiveAestheticScore = z.infer<typeof zParameterSDXLRefinerPositiveAestheticScore>;
|
||||
export const isParameterSDXLRefinerPositiveAestheticScore = (
|
||||
val: unknown
|
||||
): val is ParameterSDXLRefinerPositiveAestheticScore =>
|
||||
zParameterSDXLRefinerPositiveAestheticScore.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SDXL Refiner Negative Aesthetic Score
|
||||
const zParameterSDXLRefinerNegativeAestheticScore = zParameterSDXLRefinerPositiveAestheticScore;
|
||||
export const [zParameterSDXLRefinerNegativeAestheticScore, isParameterSDXLRefinerNegativeAestheticScore] =
|
||||
buildParameter(zParameterSDXLRefinerPositiveAestheticScore);
|
||||
export type ParameterSDXLRefinerNegativeAestheticScore = z.infer<typeof zParameterSDXLRefinerNegativeAestheticScore>;
|
||||
export const isParameterSDXLRefinerNegativeAestheticScore = (
|
||||
val: unknown
|
||||
): val is ParameterSDXLRefinerNegativeAestheticScore =>
|
||||
zParameterSDXLRefinerNegativeAestheticScore.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SDXL Refiner Start
|
||||
const zParameterSDXLRefinerStart = z.number().min(0).max(1);
|
||||
export const [zParameterSDXLRefinerStart, isParameterSDXLRefinerStart] = buildParameter(z.number().min(0).max(1));
|
||||
export type ParameterSDXLRefinerStart = z.infer<typeof zParameterSDXLRefinerStart>;
|
||||
export const isParameterSDXLRefinerStart = (val: unknown): val is ParameterSDXLRefinerStart =>
|
||||
zParameterSDXLRefinerStart.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Mask Blur Method
|
||||
@@ -207,14 +186,13 @@ export type ParameterMaskBlurMethod = z.infer<typeof zParameterMaskBlurMethod>;
|
||||
// #endregion
|
||||
|
||||
// #region Canvas Coherence Mode
|
||||
const zParameterCanvasCoherenceMode = z.enum(['Gaussian Blur', 'Box Blur', 'Staged']);
|
||||
export const [zParameterCanvasCoherenceMode, isParameterCanvasCoherenceMode] = buildParameter(
|
||||
z.enum(['Gaussian Blur', 'Box Blur', 'Staged'])
|
||||
);
|
||||
export type ParameterCanvasCoherenceMode = z.infer<typeof zParameterCanvasCoherenceMode>;
|
||||
export const isParameterCanvasCoherenceMode = (val: unknown): val is ParameterCanvasCoherenceMode =>
|
||||
zParameterCanvasCoherenceMode.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region LoRA weight
|
||||
const zLoRAWeight = z.number();
|
||||
type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
|
||||
export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
|
||||
export const [zLoRAWeight, isParameterLoRAWeight] = buildParameter(z.number());
|
||||
export type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
|
||||
// #endregion
|
||||
|
||||
@@ -10,7 +10,6 @@ import BboxScaledHeight from 'features/parameters/components/Bbox/BboxScaledHeig
|
||||
import BboxScaledWidth from 'features/parameters/components/Bbox/BboxScaledWidth';
|
||||
import BboxScaleMethod from 'features/parameters/components/Bbox/BboxScaleMethod';
|
||||
import { BboxSettings } from 'features/parameters/components/Bbox/BboxSettings';
|
||||
import { ParamDenoisingStrength } from 'features/parameters/components/Core/ParamDenoisingStrength';
|
||||
import { ParamSeedNumberInput } from 'features/parameters/components/Seed/ParamSeedNumberInput';
|
||||
import { ParamSeedRandomize } from 'features/parameters/components/Seed/ParamSeedRandomize';
|
||||
import { ParamSeedShuffle } from 'features/parameters/components/Seed/ParamSeedShuffle';
|
||||
@@ -76,7 +75,6 @@ export const ImageSettingsAccordion = memo(() => {
|
||||
<ParamSeedShuffle />
|
||||
<ParamSeedRandomize />
|
||||
</Flex>
|
||||
<ParamDenoisingStrength />
|
||||
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} pb={4} flexDir="column">
|
||||
{isFLUX && <ParamOptimizedDenoisingToggle />}
|
||||
|
||||
@@ -2,27 +2,44 @@ import { ExternalLink, Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { useGetAppVersionQuery } from 'services/api/endpoints/appInfo';
|
||||
|
||||
const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
|
||||
|
||||
export const WhatsNew = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data } = useGetAppVersionQuery();
|
||||
const isLocal = useAppSelector(selectIsLocal);
|
||||
|
||||
const highlights = useMemo(() => (data?.highlights ? data.highlights : []), [data]);
|
||||
|
||||
return (
|
||||
<Flex gap={4} flexDir="column">
|
||||
<UnorderedList fontSize="sm">
|
||||
<ListItem>
|
||||
<Trans
|
||||
i18nKey="whatsNew.line1"
|
||||
components={{
|
||||
ItalicComponent: <Text as="span" color="white" fontSize="sm" fontStyle="italic" />,
|
||||
}}
|
||||
/>
|
||||
</ListItem>
|
||||
<ListItem>{t('whatsNew.line2')}</ListItem>
|
||||
<ListItem>{t('whatsNew.line3')}</ListItem>
|
||||
{highlights.length ? (
|
||||
highlights.map((highlight, index) => <ListItem key={index}>{highlight}</ListItem>)
|
||||
) : (
|
||||
<>
|
||||
<ListItem>
|
||||
<Trans
|
||||
i18nKey="whatsNew.line1"
|
||||
components={{
|
||||
StrongComponent: <Text as="span" color="white" fontSize="sm" fontWeight="semibold" />,
|
||||
}}
|
||||
/>
|
||||
</ListItem>
|
||||
<ListItem>
|
||||
<Trans
|
||||
i18nKey="whatsNew.line2"
|
||||
components={{
|
||||
StrongComponent: <Text as="span" color="white" fontSize="sm" fontWeight="semibold" />,
|
||||
}}
|
||||
/>
|
||||
</ListItem>
|
||||
</>
|
||||
)}
|
||||
</UnorderedList>
|
||||
<Flex flexDir="column" gap={1}>
|
||||
<ExternalLink
|
||||
@@ -31,7 +48,7 @@ export const WhatsNew = () => {
|
||||
label={t('whatsNew.readReleaseNotes')}
|
||||
href={
|
||||
isLocal
|
||||
? 'https://github.com/invoke-ai/InvokeAI/releases/tag/v5.0.0'
|
||||
? `https://github.com/invoke-ai/InvokeAI/releases/tag/v${data?.version}`
|
||||
: 'https://support.invoke.ai/support/solutions/articles/151000178246'
|
||||
}
|
||||
/>
|
||||
|
||||
@@ -1687,6 +1687,11 @@ export type components = {
|
||||
* @description App version
|
||||
*/
|
||||
version: string;
|
||||
/**
|
||||
* Highlights
|
||||
* @description Highlights of release
|
||||
*/
|
||||
highlights?: string[] | null;
|
||||
};
|
||||
/**
|
||||
* Apply Tensor Mask to Image
|
||||
|
||||
@@ -43,14 +43,14 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
|
||||
downloaded_path = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
|
||||
loaded_model_1 = mock_context.models.load_local_model(downloaded_path, mock_context.util.get_queue_id())
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
|
||||
loaded_model_2 = mock_context.models.load_local_model(downloaded_path, mock_context.util.get_queue_id())
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model
|
||||
|
||||
loaded_model_3 = mock_context.models.load_local_model(embedding_file)
|
||||
loaded_model_3 = mock_context.models.load_local_model(embedding_file, mock_context.util.get_queue_id())
|
||||
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is not loaded_model_3.model
|
||||
assert isinstance(loaded_model_1.model, dict)
|
||||
@@ -60,16 +60,20 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
|
||||
|
||||
@pytest.mark.skip(reason="This requires a test model to load")
|
||||
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
|
||||
loaded_model = mock_context.models.load_local_model(vae_directory)
|
||||
loaded_model = mock_context.models.load_local_model(vae_directory, mock_context.util.get_queue_id())
|
||||
assert isinstance(loaded_model, LoadedModelWithoutConfig)
|
||||
assert isinstance(loaded_model.model, AutoencoderTiny)
|
||||
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||
loaded_model_1 = mock_context.models.load_remote_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id()
|
||||
)
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||
loaded_model_2 = mock_context.models.load_remote_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id()
|
||||
)
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||
|
||||
|
||||
Reference in New Issue
Block a user