Compare commits

..

3 Commits

Author SHA1 Message Date
Jonathan
b9ebce9bdd Merge branch 'main' into feat/nodes/invocation-cache 2023-09-18 19:54:14 -05:00
Jonathan
56254b74e2 Prevent a duplicate image from appearing in the gallery
Set final l2i to intermediate so the save image node is the only one outputting a final image.
2023-09-18 12:30:51 -05:00
psychedelicious
593604bbba feat(nodes): add invocation cache
The invocation cache provides simple node memoization functionality. Nodes that use the cache are memoized and not re-executed if their inputs haven't changed. Instead, the stored output is returned.

## Results

This feature provides anywhere from a significant to a massive performance improvement.

The improvement is most marked on large batches of generations where you only change a couple of things between iterations (e.g. a different seed or prompt), and on low-VRAM systems, where skipping an extraneous model load is a big deal.

## Overview

A new `invocation_cache` service is added to handle the caching. There's not much to it.

All nodes now inherit a boolean `use_cache` field from `BaseInvocation`. This is a node field and not a class attribute, because specific instances of nodes may want to opt in or out of caching.

The recently-added `invoke_internal()` method on `BaseInvocation` is used as an entrypoint for the cache logic.
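
Roughly, the cache check wraps the call to `invoke()` as follows (a condensed paraphrase of the `invoke_internal()` hunk further down in this diff, with logging omitted):

```python
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
    # ...existing input validation runs first...
    if not self.use_cache:
        return self.invoke(context)  # caching disabled for this node instance

    key = context.services.invocation_cache.create_key(self)
    cached_value = context.services.invocation_cache.get(key)
    if cached_value is not None:
        return cached_value          # cache hit: skip re-execution entirely

    output = self.invoke(context)    # cache miss: execute and store the output
    context.services.invocation_cache.save(key, output)
    return output
```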

To create a cache key, the invocation is first serialized using pydantic's provided `json()` method, skipping the unique `id` field. Then Python's fast built-in `hash()` is used to create an integer key. All implementations of `InvocationCacheBase` must provide a class method `create_key()` which accepts an invocation and outputs a string or integer key.
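
As a small self-contained illustration of the keying scheme (a toy pydantic v1 model stands in for a real invocation; only the `id` field is excluded from the hash):

```python
from pydantic import BaseModel  # pydantic v1 API, matching the codebase at this time


class ToyInvocation(BaseModel):
    id: str
    prompt: str
    use_cache: bool = True


def create_key(value: ToyInvocation) -> int:
    # Serialize without the unique `id`, then hash the resulting JSON string.
    return hash(value.json(exclude={"id"}))


a = ToyInvocation(id="aaa", prompt="a cat")
b = ToyInvocation(id="bbb", prompt="a cat")
c = ToyInvocation(id="ccc", prompt="a dog")
assert create_key(a) == create_key(b)  # same inputs, different id: cache hit
assert create_key(a) != create_key(c)  # different input: different key
```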

## In-Memory Implementation

An in-memory implementation is provided. In this implementation, the node outputs are stored in memory as Python objects. The in-memory cache does not persist across application restarts.

Max node cache size is added as `node_cache_size` under the `Generation` config category.

It defaults to 512 - this number is up for discussion, but given that these are relatively lightweight pydantic models, I think it's safe to up this even higher.

Note that the cache isn't storing the big stuff - tensors and images are stored on disk, and outputs include only references to them.
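
For reference, a minimal usage sketch of the in-memory service (based on the `MemoryInvocationCache` class shown further down in this diff; `node` and `context` are stand-ins for a real invocation and its invocation context):

```python
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache

cache = MemoryInvocationCache(max_cache_size=512)  # matches the node_cache_size default

key = MemoryInvocationCache.create_key(node)  # integer hash of the node's JSON, excluding `id`
if cache.get(key) is None:
    cache.save(key, node.invoke(context))     # execute once and memoize the output

output = cache.get(key)                       # later lookups return the stored output object
cache.delete(key)                             # entries can also be evicted explicitly
```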

## Node Definition

The default for all nodes is to use the cache. The `@invocation` decorator now accepts an optional `use_cache: bool` argument to override the default of `True`.

Non-deterministic nodes, however, should set this to `False`. Currently, all nodes that produce random output, including `dynamic_prompt`, are set to `False`.
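
For example, the random-integer node now opts out like this (mirroring the `rand_int` hunk further down in this diff; the class body is elided):

```python
@invocation(
    "rand_int",
    title="Random Integer",
    tags=["math", "random"],
    category="math",
    version="1.0.0",
    use_cache=False,  # non-deterministic output: never return a memoized value
)
class RandomIntInvocation(BaseInvocation):
    """Outputs a single random integer."""
    ...

# Individual instances can still override the class default, since `use_cache`
# is an ordinary node field, e.g. RandomIntInvocation(id="1", use_cache=True).
```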

The field name `use_cache` is now effectively reserved, which is potentially a breaking change if any community nodes already use it as a field name. In hindsight, all our reserved field names should have been prefixed with underscores or something.

## One Gotcha

Leaf nodes probably want to opt out of the cache, because if their output comes from the cache, it is not saved again.

If you run the same graph multiple times, you only end up with a single image output, because the image storage side-effects are in the `invoke()` method, which is bypassed if we have a cache hit.

## Linear UI

The linear graphs _almost_ just work, but due to the gotcha, we need to be careful about the final image-outputting node. To resolve this, a `SaveImageInvocation` node is added and used in the linear graphs.

This node is similar to `ImagePrimitive`, except it saves a copy of its input image, and has `use_cache` set to `False` by default.

This is now the leaf node in all linear graphs, and is the only node in those graphs with `use_cache == False` _and_ the only node with `is_intermediate == False`.

## Workflow Editor

All nodes now have a footer with a new `Use Cache [ ]` checkbox. It defaults to the value set by the invocation in its python definition, but can be changed by the user.

The workflow/node validation logic has been updated to migrate old workflows to use the new default values for `use_cache`. Users may still want to review the settings that have been chosen. In the event of catastrophic failure when running this migration, the default value of `True` is applied, as this is correct for most nodes.

Users should consider re-saving their workflows after loading them and letting the migration update them.

## Future Enhancements - Callback

A future enhancement would be to let `use_cache` accept a callback that is run as the node is executed and determines, based on the node's own internal state, whether the cache should be used.

This would be useful for `DynamicPromptInvocation`, where whether the node is deterministic depends on its `combinatorial: bool` field.
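
A rough sketch of what that could look like (entirely hypothetical - neither the `UseCache` type alias nor the predicate below exists in the codebase):

```python
from typing import Callable, Union

# Hypothetical: `use_cache` could accept either a bool or a predicate that receives
# the node instance at execution time and decides whether caching is safe.
UseCache = Union[bool, Callable[["BaseInvocation"], bool]]


def dynamic_prompt_should_cache(node: "DynamicPromptInvocation") -> bool:
    # Combinatorial generation is deterministic for a given prompt; random sampling is not.
    return node.combinatorial


# e.g. @invocation("dynamic_prompt", ..., use_cache=dynamic_prompt_should_cache)
```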

## Future Enhancements - Persisted Cache

Similar to how the latents storage is backed by disk, the invocation cache could be persisted to the database or disk. We'd need to be very careful about deserializing outputs, but it's perhaps worth exploring in the future.
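
As a very rough illustration of the idea (purely hypothetical - the table layout, the `output_registry` mapping, and the class below do not exist in the codebase), cached outputs could be persisted as JSON rows keyed by the invocation hash:

```python
import sqlite3
from typing import Optional, Union


class SqliteInvocationCache:
    """Hypothetical disk-backed cache storing outputs as (key, output_type, output_json) rows."""

    def __init__(self, db_path: str = "invocation_cache.db") -> None:
        self._conn = sqlite3.connect(db_path)
        self._conn.execute(
            "CREATE TABLE IF NOT EXISTS invocation_cache"
            " (key TEXT PRIMARY KEY, output_type TEXT, output_json TEXT)"
        )

    def save(self, key: Union[int, str], value) -> None:
        self._conn.execute(
            "INSERT OR REPLACE INTO invocation_cache VALUES (?, ?, ?)",
            (str(key), type(value).__name__, value.json()),
        )
        self._conn.commit()

    def get(self, key: Union[int, str], output_registry: dict) -> Optional[object]:
        row = self._conn.execute(
            "SELECT output_type, output_json FROM invocation_cache WHERE key = ?",
            (str(key),),
        ).fetchone()
        if row is None:
            return None
        # The careful-deserialization problem mentioned above: we must know which
        # pydantic output class to rehydrate, so we assume a caller-supplied
        # mapping of type name -> class and use pydantic v1's parse_raw().
        return output_registry[row[0]].parse_raw(row[1])
```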
2023-09-18 13:41:19 +10:00
111 changed files with 2577 additions and 4235 deletions

View File

@@ -9,6 +9,7 @@ from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
@@ -126,6 +127,7 @@ class ApiDependencies:
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
)
create_system_graphs(services.graph_library)

View File

@@ -1,5 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
@@ -309,6 +311,7 @@ def invoke_cli():
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
configuration=config,
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
)
system_graphs = create_system_graphs(services.graph_library)

View File

@@ -67,7 +67,6 @@ class FieldDescriptions:
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
@@ -156,7 +155,6 @@ class UIType(str, Enum):
VaeModel = "VaeModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
UNet = "UNetField"
Vae = "VaeField"
CLIP = "ClipField"
@@ -570,7 +568,24 @@ class BaseInvocation(ABC, BaseModel):
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
elif _input == Input.Any:
raise MissingInputException(self.__fields__["type"].default, field_name)
return self.invoke(context)
output: BaseInvocationOutput
if self.use_cache:
key = context.services.invocation_cache.create_key(self)
cached_value = context.services.invocation_cache.get(key)
if cached_value is None:
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
output = self.invoke(context)
context.services.invocation_cache.save(key, output)
return output
else:
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
return cached_value
else:
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
return self.invoke(context)
def get_type(self) -> str:
return self.__fields__["type"].default
id: str = Field(
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
@@ -583,6 +598,7 @@ class BaseInvocation(ABC, BaseModel):
description="The workflow to save with the image",
ui_type=UIType.WorkflowField,
)
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
@validator("workflow", pre=True)
def validate_workflow_is_json(cls, v):
@@ -606,6 +622,7 @@ def invocation(
tags: Optional[list[str]] = None,
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
"""
Adds metadata to an invocation.
@@ -638,6 +655,8 @@ def invocation(
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
if use_cache is not None:
cls.__fields__["use_cache"].default = use_cache
# Add the invocation type to the pydantic model of the invocation
invocation_type_annotation = Literal[invocation_type] # type: ignore

View File

@@ -56,6 +56,7 @@ class RangeOfSizeInvocation(BaseInvocation):
tags=["range", "integer", "random", "collection"],
category="collections",
version="1.0.0",
use_cache=False,
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@@ -7,14 +7,14 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
BasicConditioningInfo,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
@@ -99,15 +99,14 @@ class CompelInvocation(BaseInvocation):
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with (
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
):
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, self.clip.skipped_layers
), text_encoder_info as text_encoder:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@@ -123,7 +122,7 @@ class CompelInvocation(BaseInvocation):
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
ec = ExtraConditioningInfo(
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@@ -214,15 +213,14 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with (
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
):
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, clip_field.skipped_layers
), text_encoder_info as text_encoder:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@@ -246,7 +244,7 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
ec = ExtraConditioningInfo(
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@@ -438,11 +436,9 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
text_fragments = [
(
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
)
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
for x in parsed_prompt.children
]
text = " ".join(text_fragments)

View File

@@ -965,3 +965,42 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
width=image_dto.width,
height=image_dto.height,
)
@invocation(
"save_image",
title="Save Image",
tags=["primitives", "image"],
category="primitives",
version="1.0.0",
use_cache=False,
)
class SaveImageInvocation(BaseInvocation):
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
image: ImageField = InputField(description="The image to load")
metadata: CoreMetadata = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -1,105 +0,0 @@
import os
from builtins import float
from typing import List, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
class CLIPVisionModelField(BaseModel):
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
class IPAdapterField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
# weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
# Outputs
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
)
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
weight: Union[float, List[float]] = InputField(
default=1, ge=0, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_manager.model_info(
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
)
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
# is currently messy due to differences between how the model info is generated when installing a model from
# disk vs. downloading the model.
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
)
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = CLIPVisionModelField(
model_name=image_encoder_model_name,
base_model=BaseModelType.Any,
)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=image_encoder_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)

View File

@@ -8,7 +8,6 @@ import numpy as np
import torch
import torchvision.transforms as T
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -20,7 +19,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
DenoiseMaskField,
@@ -33,17 +31,15 @@ from invokeai.app.invocations.primitives import (
)
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
ControlNetData,
IPAdapterData,
StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor,
)
@@ -72,6 +68,7 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@@ -194,7 +191,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.1.0",
version="1.0.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@@ -222,12 +219,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[IPAdapterField] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
)
@validator("cfg_scale")
@@ -329,6 +323,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data(
self,
context: InvocationContext,
# really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int],
exit_stack: ExitStack,
@@ -348,107 +344,57 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
control_list = None
if control_list is None:
return None
# After above handling, any control that is not None should now be of type list[ControlField].
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
control_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
control_data = []
control_models = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
)
)
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return controlnet_data
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[IPAdapterField],
conditioning_data: ConditioningData,
unet: UNet2DConditionModel,
exit_stack: ExitStack,
) -> Optional[IPAdapterData]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None:
return None
image_encoder_model_info = context.services.model_manager.get_model(
model_name=ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=ip_adapter.image_encoder_model.base_model,
context=context,
)
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=ip_adapter.ip_adapter_model.base_model,
context=context,
)
)
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
return IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=ip_adapter.weight,
begin_step_percent=ip_adapter.begin_step_percent,
end_step_percent=ip_adapter.end_step_percent,
)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@@ -542,12 +488,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
**self.unet.unet.dict(),
context=context,
)
with (
ExitStack() as exit_stack,
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
set_seamless(unet_info.context.model, **self.unet.seamless.dict()),
unet_info as unet,
):
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
@@ -566,7 +509,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
controlnet_data = self.prep_control_data(
control_data = self.prep_control_data(
model=pipeline,
context=context,
control_input=self.control,
latents_shape=latents.shape,
@@ -575,14 +519,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
exit_stack=exit_stack,
)
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
unet=unet,
exit_stack=exit_stack,
)
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
scheduler,
device=unet.device,
@@ -601,8 +537,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=controlnet_data, # list[ControlNetData],
ip_adapter_data=ip_adapter_data, # IPAdapterData,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
@@ -648,7 +583,7 @@ class LatentsToImageInvocation(BaseInvocation):
context=context,
)
with set_seamless(vae_info.context.model, **self.vae.seamless.dict()), vae_info as vae:
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)

View File

@@ -54,7 +54,14 @@ class DivideInvocation(BaseInvocation):
return IntegerOutput(value=int(self.a / self.b))
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
@invocation(
"rand_int",
title="Random Integer",
tags=["math", "random"],
category="math",
version="1.0.0",
use_cache=False,
)
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""

View File

@@ -18,13 +18,6 @@ from .baseinvocation import (
)
class SeamlessSettings(BaseModel):
axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless")
skipped_layers: int = Field(description="How much down layers skip when applying seamless")
skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless")
skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless")
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
@@ -40,7 +33,7 @@ class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class ClipField(BaseModel):
@@ -53,7 +46,7 @@ class ClipField(BaseModel):
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@invocation_output("model_loader_output")
@@ -395,11 +388,6 @@ class SeamlessModeInvocation(BaseInvocation):
)
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
skipped_layers: int = InputField(default=0, input=Input.Any, description="How much model's down layers to skip")
skip_second_resnet: bool = InputField(
default=True, input=Input.Any, description="Skip or not second resnet in down layers"
)
skip_conv2: bool = InputField(default=True, input=Input.Any, description="Skip or not conv2 in down layers")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
@@ -414,18 +402,8 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_axes_list.append("y")
if unet is not None:
unet.seamless = SeamlessSettings(
axes=seamless_axes_list,
skipped_layers=self.skipped_layers,
skip_second_resnet=self.skip_second_resnet,
skip_conv2=self.skip_conv2,
)
unet.seamless_axes = seamless_axes_list
if vae is not None:
vae.seamless = SeamlessSettings(
axes=seamless_axes_list,
skipped_layers=self.skipped_layers,
skip_second_resnet=self.skip_second_resnet,
skip_conv2=self.skip_conv2,
)
vae.seamless_axes = seamless_axes_list
return SeamlessModeOutput(unet=unet, vae=vae)

View File

@@ -95,10 +95,9 @@ class ONNXPromptInvocation(BaseInvocation):
print(f'Warn: trigger: "{trigger}" not found')
if loras or ti_list:
text_encoder.release_session()
with (
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
):
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
orig_tokenizer, text_encoder, ti_list
) as (tokenizer, ti_manager):
text_encoder.create_session()
# copy from

View File

@@ -10,7 +10,14 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
@invocation(
"dynamic_prompt",
title="Dynamic Prompt",
tags=["prompt", "collection"],
category="prompt",
version="1.0.0",
use_cache=False,
)
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""

View File

@@ -253,6 +253,7 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Generation", )
# NODES
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")

View File

@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
class InvocationCacheBase(ABC):
"""Base class for invocation caches."""
@abstractmethod
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
"""Retrieves and invocation output from the cache"""
pass
@abstractmethod
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
"""Stores an invocation output in the cache"""
pass
@abstractmethod
def delete(self, key: Union[int, str]) -> None:
"""Deleted an invocation output from the cache"""
pass
@classmethod
@abstractmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
"""Creates the cache key for an invocation"""
pass

View File

@@ -0,0 +1,34 @@
from queue import Queue
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], BaseInvocationOutput]
__max_cache_size: int
__cache_ids: Queue
def __init__(self, max_cache_size: int = 512) -> None:
self.__cache = dict()
self.__max_cache_size = max_cache_size
self.__cache_ids = Queue()
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
return self.__cache.get(key, None)
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
if key not in self.__cache:
self.__cache[key] = value
self.__cache_ids.put(key)
if self.__cache_ids.qsize() > self.__max_cache_size:
self.__cache.pop(self.__cache_ids.get())
def delete(self, key: Union[int, str]) -> None:
if key in self.__cache:
del self.__cache[key]
@classmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
return hash(value.json(exclude={"id"}))

View File

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.images import ImageServiceABC
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
from invokeai.app.services.invoker import InvocationProcessorABC
@@ -37,6 +38,7 @@ class InvocationServices:
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
invocation_cache: "InvocationCacheBase"
def __init__(
self,
@@ -53,6 +55,7 @@ class InvocationServices:
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
invocation_cache: "InvocationCacheBase",
):
self.board_images = board_images
self.boards = boards
@@ -68,3 +71,4 @@ class InvocationServices:
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
self.invocation_cache = invocation_cache

View File

@@ -326,16 +326,6 @@ class ModelInstall(object):
elif f"learned_embeds.{suffix}" in files:
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
break
elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files: # IP-Adapter
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
break
elif f"model.{suffix}" in files and "config.json" in files:
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
# by InvokeAI for use with IP-Adapters.
files = ["config.json", f"model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
break
if not location:
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {}
@@ -544,17 +534,14 @@ def hf_download_with_resume(
logger.info(f"{model_name}: Downloading...")
try:
with (
open(model_dest, open_mode) as file,
tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar,
):
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar:
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)

View File

@@ -1,45 +0,0 @@
# IP-Adapter Model Formats
The official IP-Adapter models are released here: [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)
This official model repo does not integrate well with InvokeAI's current approach to model management, so we have defined a new file structure for IP-Adapter models. The InvokeAI format is described below.
## CLIP Vision Models
CLIP Vision models are organized in `diffusers` format. The expected directory structure is:
```bash
ip_adapter_sd_image_encoder/
├── config.json
└── model.safetensors
```
## IP-Adapter Models
IP-Adapter models are stored in a directory containing two files:
- `image_encoder.txt`: A text file containing the model identifier for the CLIP Vision encoder that is intended to be used with this IP-Adapter model.
- `ip_adapter.bin`: The IP-Adapter weights.
Sample directory structure:
```bash
ip_adapter_sd15/
├── image_encoder.txt
└── ip_adapter.bin
```
### Why not save the weights in a `.safetensors` file?
The weights in `ip_adapter.bin` are stored in a nested dict, which is not supported by `safetensors`. This could be solved by splitting `ip_adapter.bin` into multiple files, but for now we have decided to maintain consistency with the checkpoint structure used in the official [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repo.
## InvokeAI Hosted IP-Adapters
Image Encoders:
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)
IP-Adapters:
- [InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- Not yet supported: [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)

View File

@@ -1,162 +0,0 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
# tencent-ailab comment:
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapter for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states

View File

@@ -1,217 +0,0 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from contextlib import contextmanager
from typing import Optional, Union
import torch
from diffusers.models import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler
class ImageProjModel(torch.nn.Module):
"""Image Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
"""Initialize an ImageProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
Args:
state_dict (dict[torch.Tensor]): The state_dict of model weights.
clip_extra_context_tokens (int, optional): Defaults to 4.
Returns:
ImageProjModel
"""
cross_attention_dim = state_dict["norm.weight"].shape[0]
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
class IPAdapter:
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
def __init__(
self,
state_dict: dict[torch.Tensor],
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
):
self.device = device
self.dtype = dtype
self._num_tokens = num_tokens
self._clip_image_processor = CLIPImageProcessor()
self._state_dict = state_dict
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
# The _attn_processors will be initialized later when we have access to the UNet.
self._attn_processors = None
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
if dtype is not None:
self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype)
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
attention weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
initialized.
"""
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=self.dtype)
ip_layers = torch.nn.ModuleList(attn_procs.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._attn_processors = attn_procs
self._state_dict = None
# @genomancer: pushed scaling back out into its own method (like original Tencent implementation)
# which makes implementing begin_step_percent and end_step_percent easier
# but based on self._attn_processors (ala @Ryan) instead of original Tencent unet.attn_processors,
# which should make it easier to implement multiple IPAdapters
def set_scale(self, scale):
if self._attn_processors is not None:
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor2_0):
attn_processor.scale = scale
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
"""A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
Yields:
None
"""
if self._attn_processors is None:
# We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
# used on any UNet model (with the same dimensions).
self._prepare_attention_processors(unet)
# Set scale
self.set_scale(scale)
# for attn_processor in self._attn_processors.values():
# if isinstance(attn_processor, IPAttnProcessor2_0):
# attn_processor.scale = scale
orig_attn_processors = unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
# actually pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
try:
unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def _init_image_proj_model(self, state_dict):
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
dim_head=64,
heads=12,
num_queries=self._num_tokens,
ff_mult=4,
).to(self.device, dtype=self.dtype)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=self.dtype)
clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
-2
]
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
# contains.
is_plus = "proj.weight" not in state_dict["image_proj"]
if is_plus:
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
else:
return IPAdapter(state_dict, device=device, dtype=dtype)

View File

@@ -1,158 +0,0 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# tencent ailab comment: modified from
# https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
"""A convenience method that initializes a Resampler from a state_dict.
Some of the shape parameters (e.g. dim, num_queries, embedding_dim, output_dim) are inferred from the state_dict.
The remaining shape parameters are not currently inferred, but could be if the need arises.
Args:
state_dict (dict[str, torch.Tensor]): The state_dict to load.
depth (int, optional): Number of (PerceiverAttention, FeedForward) layer pairs.
dim_head (int, optional): Dimension of each attention head.
heads (int, optional): Number of attention heads.
ff_mult (int, optional): Feed-forward inner-dimension multiplier.
Returns:
Resampler
"""
dim = state_dict["latents"].shape[2]
num_queries = state_dict["latents"].shape[1]
embedding_dim = state_dict["proj_in.weight"].shape[-1]
output_dim = state_dict["norm_out.weight"].shape[0]
model = cls(
dim=dim,
depth=depth,
dim_head=dim_head,
heads=heads,
num_queries=num_queries,
embedding_dim=embedding_dim,
output_dim=output_dim,
ff_mult=ff_mult,
)
model.load_state_dict(state_dict)
return model
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
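To make the shape inference in `from_state_dict` concrete, here is a small sketch with made-up tensor shapes (the sizes are illustrative, not taken from a real checkpoint):
import torch

# Illustrative shapes only -- a real checkpoint defines these tensors itself.
state_dict = {
    "latents": torch.randn(1, 16, 768),        # num_queries=16, dim=768
    "proj_in.weight": torch.randn(768, 1280),  # embedding_dim=1280
    "norm_out.weight": torch.randn(768),       # output_dim=768
}
dim = state_dict["latents"].shape[2]
num_queries = state_dict["latents"].shape[1]
embedding_dim = state_dict["proj_in.weight"].shape[-1]
output_dim = state_dict["norm_out.weight"].shape[0]
print(dim, num_queries, embedding_dim, output_dim)  # 768 16 1280 768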

View File

@@ -25,7 +25,6 @@ Models are described using four attributes:
ModelType.Lora -- a LoRA or LyCORIS fine-tune
ModelType.TextualInversion -- a textual inversion embedding
ModelType.ControlNet -- a ControlNet model
ModelType.IPAdapter -- an IPAdapter model
3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
BaseModelType.StableDiffusion1
@@ -1001,8 +1000,8 @@ class ModelManager(object):
new_models_found = True
except DuplicateModelException as e:
self.logger.warning(e)
except InvalidModelException as e:
self.logger.warning(f"Not a valid model: {model_path}. {e}")
except InvalidModelException:
self.logger.warning(f"Not a valid model: {model_path}")
except NotImplementedError as e:
self.logger.warning(e)

View File

@@ -8,8 +8,6 @@ import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from .models import (
BaseModelType,
InvalidModelException,
@@ -54,7 +52,6 @@ class ModelProbe(object):
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
}
@classmethod
@@ -121,18 +118,14 @@ class ModelProbe(object):
and prediction_type == SchedulerPredictionType.VPrediction
),
format=format,
image_size=(
1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else (
768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512
)
),
image_size=1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else 768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512,
)
except Exception:
raise
@@ -184,10 +177,9 @@ class ModelProbe(object):
return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion
if (folder_path / "pytorch_lora_weights.bin").exists():
return ModelType.Lora
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
@@ -196,12 +188,7 @@ class ModelProbe(object):
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
class_name = conf["_class_name"]
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
@@ -379,16 +366,6 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
@@ -508,13 +485,11 @@ class ControlNetFolderProbe(FolderProbeBase):
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
else BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
@@ -534,47 +509,15 @@ class LoRAFolderProbe(FolderProbeBase):
return LoRACheckpointProbe(model_file, None).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return IPAdapterModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
model_file = self.folder_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelException("Unknown IP-Adapter model format.")
state_dict = torch.load(model_file, map_location="cpu")
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@@ -79,7 +79,7 @@ class ModelSearch(ABC):
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(f"Failed to process '{path}': {e}")
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
@@ -90,7 +90,7 @@ class ModelSearch(ABC):
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(f"Failed to process '{path}': {e}")
self.logger.warning(str(e))
class FindModels(ModelSearch):

View File

@@ -18,9 +18,7 @@ from .base import ( # noqa: F401
SilenceWarnings,
SubModelType,
)
from .clip_vision import CLIPVisionModel
from .controlnet import ControlNetModel # TODO:
from .ip_adapter import IPAdapterModel
from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
@@ -36,8 +34,6 @@ MODEL_CLASSES = {
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model,
@@ -46,8 +42,6 @@ MODEL_CLASSES = {
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel,
@@ -57,8 +51,6 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
@@ -68,19 +60,6 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.Any: {
ModelType.CLIPVision: CLIPVisionModel,
# The following model types are not expected to be used with BaseModelType.Any.
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
},
# BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,

View File

@@ -36,7 +36,6 @@ class ModelNotFoundException(Exception):
class BaseModelType(str, Enum):
Any = "any" # For models that are not associated with any particular base model.
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
@@ -51,8 +50,6 @@ class ModelType(str, Enum):
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
class SubModelType(str, Enum):

View File

@@ -1,82 +0,0 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class CLIPVisionModelFormat(str, Enum):
Diffusers = "diffusers"
class CLIPVisionModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[CLIPVisionModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.CLIPVision
super().__init__(model_path, base_model, model_type)
self.model_size = calc_model_size_by_fs(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
return CLIPVisionModelFormat.Diffusers
raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> CLIPVisionModelWithProjection:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
# Calculate a more accurate model size.
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == CLIPVisionModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@@ -1,92 +0,0 @@
import os
import typing
from enum import Enum
from typing import Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
classproperty,
)
class IPAdapterModelFormat(str, Enum):
# The custom IP-Adapter model format defined by InvokeAI.
InvokeAI = "invokeai"
class IPAdapterModel(ModelBase):
class InvokeAIConfig(ModelConfigBase):
model_format: Literal[IPAdapterModelFormat.InvokeAI]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.IPAdapter
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
if os.path.isdir(path):
model_file = os.path.join(path, "ip_adapter.bin")
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
return IPAdapterModelFormat.InvokeAI
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> typing.Union[IPAdapter, IPAdapterPlus]:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
return build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
)
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == IPAdapterModelFormat.InvokeAI:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")
def get_ip_adapter_image_encoder_model_id(model_path: str):
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
with open(image_encoder_config_file, "r") as f:
image_encoder_model = f.readline().strip()
return image_encoder_model
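For clarity, the InvokeAI IP-Adapter folder format referenced above contains two files; the layout below is a sketch (the folder name and encoder ID are placeholders):
# my-ip-adapter/
#   ip_adapter.bin      <- IP-Adapter weights, loaded by build_ip_adapter()
#   image_encoder.txt   <- first line holds the ID of the CLIP image encoder to use,
#                          e.g. "some-org/clip-vision-encoder" (placeholder)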

View File

@@ -25,55 +25,71 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager
def set_seamless(
model: Union[UNet2DConditionModel, AutoencoderKL],
axes: List[str],
skipped_layers: int,
skip_second_resnet: bool,
skip_conv2: bool,
):
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
try:
to_restore = []
for m_name, m in model.named_modules():
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
continue
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
# down_blocks.1.resnets.1.conv1
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
block_num = int(block_num)
resnet_num = int(resnet_num)
# if block_num >= seamless_down_blocks:
if block_num >= len(model.down_blocks) - skipped_layers:
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
continue
if resnet_num > 0 and skip_second_resnet:
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
continue
if submodule_name == "conv2" and skip_conv2:
if False and ".downsamplers." in m_name:
continue
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield
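For intuition, the per-axis asymmetric padding configured above effectively makes each convolution wrap around on the seamless axes; a standalone sketch of the same idea (independent of the diff, using plain torch):
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 3, kernel_size=3, padding=0)
x = torch.randn(1, 3, 8, 8)

# Seamless along x only: circular padding on width, zero padding on height.
x = F.pad(x, (1, 1, 0, 0), mode="circular")  # wrap left/right edges
x = F.pad(x, (0, 0, 1, 1), mode="constant")  # zero-pad top/bottom
out = conv(x)  # 8x8 output whose left/right edges tile seamlessly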

View File

@@ -1,6 +1,15 @@
"""
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
from .diffusers_pipeline import ( # noqa: F401
ConditioningData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
BasicConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
)

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import math
from contextlib import nullcontext
from dataclasses import dataclass
import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Union
import einops
@@ -23,11 +23,9 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings
@dataclass
@@ -97,7 +95,7 @@ class AddsMaskGuidance:
# Mask anything that has the same shape as prev_sample, return others as-is.
return output_class(
{
k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
for k, v in step_output.items()
}
)
@@ -164,13 +162,39 @@ class ControlNetData:
@dataclass
class IPAdapterData:
ip_adapter_model: IPAdapter = Field(default=None)
# TODO: change to polymorphic so that different weights can be used per step (once implemented...)
weight: Union[float, List[float]] = Field(default=1.0)
# weight: float = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
@property
def dtype(self):
return self.text_embeddings.dtype
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)
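The `guidance_scale` field above is applied as standard classifier-free guidance; roughly (a sketch, with illustrative variable names):
def apply_cfg(noise_pred_uncond, noise_pred_cond, guidance_scale):
    # Classifier-free guidance: push the unconditioned prediction toward the
    # conditioned one by a factor of guidance_scale (w in the Imagen paper).
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)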
@dataclass
@@ -253,7 +277,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
self.use_ip_adapter = False
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
@@ -326,7 +349,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
@@ -378,7 +400,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
callback=callback,
)
finally:
@@ -398,7 +419,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
@@ -411,26 +431,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents, attention_map_saver
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=conditioning_data.extra,
step_count=len(self.scheduler.timesteps),
)
self.use_ip_adapter = False
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
weight = ip_adapter_data.weight[0] if isinstance(ip_adapter_data.weight, List) else ip_adapter_data.weight
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
unet=self.invokeai_diffuser.model,
scale=weight,
)
self.use_ip_adapter = True
else:
attn_ctx = nullcontext()
with attn_ctx:
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
if callback is not None:
callback(
PipelineIntermediateState(
@@ -453,7 +459,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
)
latents = step_output.prev_sample
@@ -499,7 +504,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@@ -510,24 +514,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# handle IP-Adapter
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
first_adapter_step = math.floor(ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(ip_adapter_data.end_step_percent * total_step_count)
weight = (
ip_adapter_data.weight[step_index]
if isinstance(ip_adapter_data.weight, List)
else ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# only apply IP-Adapter if current step is within the IP-Adapter's begin/end step range
# ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
ip_adapter_data.ip_adapter_model.set_scale(weight)
else:
# otherwise, set IP-Adapter scale to 0, so it has no effect
ip_adapter_data.ip_adapter_model.set_scale(0.0)
# handle ControlNet(s)
# default is no controlnet, so set controlnet processing output to None
controlnet_down_block_samples, controlnet_mid_block_sample = None, None
if control_data is not None:

View File

@@ -3,4 +3,9 @@ Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401
from .shared_invokeai_diffusion import ( # noqa: F401
BasicConditioningInfo,
InvokeAIDiffuserComponent,
PostprocessingSettings,
SDXLConditioningInfo,
)

View File

@@ -1,101 +0,0 @@
import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import torch
from .cross_attention_control import Arguments
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
# should only be stored in one place.
extra_conditioning: Optional[ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoder conditioning embeddings.
Shape: (batch_size, num_tokens, encoding_dim).
"""
uncond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoding embeddings to use for unconditional generation.
Shape: (batch_size, num_tokens, encoding_dim).
"""
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
@property
def dtype(self):
return self.text_embeddings.dtype
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)
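A minimal standalone sketch of the signature-filtering trick used by `add_scheduler_args_if_applicable` above (the scheduler and keyword names are illustrative):
import inspect

def filter_step_kwargs(scheduler, **kwargs):
    # Keep only the kwargs that this scheduler's step() method actually accepts.
    step_sig = inspect.signature(scheduler.step)
    accepted = {}
    for name, value in kwargs.items():
        try:
            step_sig.bind_partial(**{name: value})
        except TypeError:
            continue  # step() does not take this argument; drop it
        accepted[name] = value
    return accepted

# e.g. accepted = filter_step_kwargs(scheduler, generator=generator, eta=0.0)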

View File

@@ -376,11 +376,11 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
# non-fatal error but .swap() won't work.
logger.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
"failed or some assumption has changed about the structure of the model itself. Please fix the "
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
"attention map display will not work properly until it is fixed."
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+ "work properly until it is fixed."
)
return attention_module_tuples
@@ -577,7 +577,6 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
**kwargs,
):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import math
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
@@ -9,14 +10,9 @@ from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
ExtraConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
)
from .cross_attention_control import (
Arguments,
Context,
CrossAttentionType,
SwapCrossAttnContext,
@@ -35,6 +31,37 @@ ModelForwardCallback: TypeAlias = Union[
]
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
class InvokeAIDiffuserComponent:
"""
The aim of this component is to provide a single place for code that can be applied identically to
@@ -48,6 +75,15 @@ class InvokeAIDiffuserComponent:
debug_thresholding = False
sequential_guidance = False
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(
self,
model,
@@ -67,26 +103,30 @@ class InvokeAIDiffuserComponent:
@contextmanager
def custom_attention_context(
self,
unet: UNet2DConditionModel,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
):
old_attn_processors = unet.attn_processors
old_attn_processors = None
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
try:
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
yield None
finally:
self.cross_attention_control_context = None
unet.set_attn_processor(old_attn_processors)
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
@@ -336,24 +376,11 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": torch.cat(
[
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
]
)
}
added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
@@ -381,7 +408,6 @@ class InvokeAIDiffuserComponent:
x_twice,
sigma_twice,
both_conditionings,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
@@ -393,12 +419,9 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data: ConditioningData,
conditioning_data,
**kwargs,
):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""
# low-memory sequential path
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
@@ -414,13 +437,6 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
# Run unconditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
}
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
@@ -433,21 +449,12 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
# Run conditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
}
added_cond_kwargs = None
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
@@ -458,7 +465,6 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.text_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
added_cond_kwargs=added_cond_kwargs,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import{v as m,hj as Je,u as y,Y as Xa,hk as Ja,a7 as ua,ab as d,hl as b,hm as o,hn as Qa,ho as h,hp as fa,hq as Za,hr as eo,aE as ro,hs as ao,a4 as oo,ht as to}from"./index-f6c3f475.js";import{s as ha,n as t,t as io,o as ma,p as no,q as ga,v as ya,w as pa,x as lo,y as Sa,z as xa,A as xr,B as so,D as co,E as bo,F as $a,G as ka,H as _a,J as vo,K as wa,L as uo,M as fo,N as ho,O as mo,Q as za,R as go,S as yo,T as po,U as So,V as xo,W as $o,e as ko,X as _o}from"./menu-c9cc8c3d.js";var Ca=String.raw,Aa=Ca`
import{v as m,h5 as Je,u as y,Y as Xa,h6 as Ja,a7 as ua,ab as d,h7 as b,h8 as o,h9 as Qa,ha as h,hb as fa,hc as Za,hd as eo,aE as ro,he as ao,a4 as oo,hf as to}from"./index-f83c2c5c.js";import{s as ha,n as t,t as io,o as ma,p as no,q as ga,v as ya,w as pa,x as lo,y as Sa,z as xa,A as xr,B as so,D as co,E as bo,F as $a,G as ka,H as _a,J as vo,K as wa,L as uo,M as fo,N as ho,O as mo,Q as za,R as go,S as yo,T as po,U as So,V as xo,W as $o,e as ko,X as _o}from"./menu-31376327.js";var Ca=String.raw,Aa=Ca`
:root,
:host {
--chakra-vh: 100vh;

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-f6c3f475.js"></script>
<script type="module" crossorigin src="./assets/index-f83c2c5c.js"></script>
</head>
<body dir="ltr">

File diff suppressed because it is too large

View File

@@ -49,7 +49,6 @@
"close": "Close",
"communityLabel": "Community",
"controlNet": "Controlnet",
"ipAdapter": "IP Adapter",
"darkMode": "Dark Mode",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
@@ -192,11 +191,7 @@
"showAdvanced": "Show Advanced",
"toggleControlNet": "Toggle this ControlNet",
"w": "W",
"weight": "Weight",
"enableIPAdapter": "Enable IP Adapter",
"ipAdapterModel": "Adapter Model",
"resetIPAdapterImage": "Reset IP Adapter Image",
"ipAdapterImageFallback": "No IP Adapter Image Selected"
"weight": "Weight"
},
"embedding": {
"addEmbedding": "Add Embedding",
@@ -1041,7 +1036,6 @@
"serverError": "Server Error",
"setCanvasInitialImage": "Set as canvas initial image",
"setControlImage": "Set as control image",
"setIPAdapterImage": "Set as IP Adapter Image",
"setInitialImage": "Set as initial image",
"setNodeField": "Set as node field",
"tempFoldersEmptied": "Temp Folder Emptied",

View File

@@ -1,8 +1,5 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlNetReset,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
@@ -21,7 +18,6 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
let wasCanvasReset = false;
let wasNodeEditorReset = false;
let wasControlNetReset = false;
let wasIPAdapterReset = false;
const state = getState();
deleted_images.forEach((image_name) => {
@@ -46,11 +42,6 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
dispatch(controlNetReset());
wasControlNetReset = true;
}
if (imageUsage.isIPAdapterImage && !wasIPAdapterReset) {
dispatch(ipAdapterStateReset());
wasIPAdapterReset = true;
}
});
},
});

View File

@@ -3,7 +3,6 @@ import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetProcessedImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
@@ -111,14 +110,6 @@ export const addRequestedSingleImageDeletionListener = () => {
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage?.image_name ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
@@ -236,14 +227,6 @@ export const addRequestedMultipleImageDeletionListener = () => {
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage?.image_name ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {

View File

@@ -1,11 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
@@ -18,6 +14,7 @@ import {
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '../';
import { parseify } from 'common/util/serialize';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;
@@ -102,18 +99,6 @@ export const addImageDroppedListener = () => {
return;
}
/**
* Image dropped on IP Adapter image
*/
if (
overData.actionType === 'SET_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO));
return;
}
/**
* Image dropped on Canvas
*/

View File

@@ -19,7 +19,6 @@ export const addImageToDeleteSelectedListener = () => {
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isInitialImage) ||
imagesUsage.some((i) => i.isControlNetImage) ||
imagesUsage.some((i) => i.isIPAdapterImage) ||
imagesUsage.some((i) => i.isNodesImage);
if (shouldConfirmOnDelete || isImageInUse) {

View File

@@ -1,18 +1,15 @@
import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { startAppListening } from '..';
import { imagesApi } from '../../../../../services/api/endpoints/images';
import { t } from 'i18next';
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
title: t('toast.imageUploaded'),
@@ -102,17 +99,6 @@ export const addImageUploadedFulfilledListener = () => {
return;
}
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
dispatch(ipAdapterImageChanged(imageDTO));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setIPAdapterImage'),
})
);
return;
}
if (postUploadAction?.type === 'SET_INITIAL_IMAGE') {
dispatch(initialImageChanged(imageDTO));
dispatch(

View File

@@ -1,9 +1,6 @@
import { logger } from 'app/logging/logger';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import {
controlNetRemoved,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import {
@@ -59,7 +56,6 @@ export const addModelSelectedListener = () => {
modelsCleared += 1;
}
// handle incompatible controlnets
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
@@ -68,16 +64,6 @@ export const addModelSelectedListener = () => {
}
});
// handle incompatible IP-Adapter
const { ipAdapterInfo } = state.controlNet;
if (
ipAdapterInfo.model &&
ipAdapterInfo.model.base_model !== base_model
) {
dispatch(ipAdapterStateReset());
modelsCleared += 1;
}
if (modelsCleared > 0) {
dispatch(
addToast(

View File

@@ -0,0 +1,5 @@
import { Store } from '@reduxjs/toolkit';
import { atom } from 'nanostores';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const $store = atom<Store<any> | undefined>();

View File

@@ -31,6 +31,7 @@ import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { $store } from './nanostores/store';
const allReducers = {
canvas: canvasReducer,
@@ -86,10 +87,7 @@ export const store = configureStore({
.concat(autoBatchEnhancer());
},
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: false,
immutableCheck: false,
})
getDefaultMiddleware({ immutableCheck: false })
.concat(api.middleware)
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
@@ -124,3 +122,4 @@ export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch;
export const stateSelector = (state: RootState) => state;
$store.set(store);

View File

@@ -18,7 +18,6 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useTranslation } from 'react-i18next';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
@@ -29,6 +28,7 @@ import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
import { useTranslation } from 'react-i18next';
type ControlNetProps = {
controlNet: ControlNetConfig;

View File

@@ -1,35 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParamIPAdapterBeginEnd from './ParamIPAdapterBeginEnd';
import ParamIPAdapterFeatureToggle from './ParamIPAdapterFeatureToggle';
import ParamIPAdapterImage from './ParamIPAdapterImage';
import ParamIPAdapterModelSelect from './ParamIPAdapterModelSelect';
import ParamIPAdapterWeight from './ParamIPAdapterWeight';
const IPAdapterPanel = () => {
return (
<Flex
sx={{
flexDir: 'column',
gap: 3,
paddingInline: 3,
paddingBlock: 2,
paddingBottom: 5,
borderRadius: 'base',
position: 'relative',
bg: 'base.250',
_dark: {
bg: 'base.750',
},
}}
>
<ParamIPAdapterFeatureToggle />
<ParamIPAdapterImage />
<ParamIPAdapterModelSelect />
<ParamIPAdapterWeight />
<ParamIPAdapterBeginEnd />
</Flex>
);
};
export default memo(IPAdapterPanel);

View File

@@ -1,100 +0,0 @@
import {
FormControl,
FormLabel,
HStack,
RangeSlider,
RangeSliderFilledTrack,
RangeSliderMark,
RangeSliderThumb,
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
ipAdapterBeginStepPctChanged,
ipAdapterEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamIPAdapterBeginEnd = () => {
const isEnabled = useAppSelector(
(state: RootState) => state.controlNet.isIPAdapterEnabled
);
const beginStepPct = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.beginStepPct
);
const endStepPct = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.endStepPct
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleStepPctChanged = useCallback(
(v: number[]) => {
dispatch(ipAdapterBeginStepPctChanged(v[0] as number));
dispatch(ipAdapterEndStepPctChanged(v[1] as number));
},
[dispatch]
);
return (
<FormControl isDisabled={!isEnabled}>
<FormLabel>{t('controlnet.beginEndStepPercent')}</FormLabel>
<HStack w="100%" gap={2} alignItems="center">
<RangeSlider
aria-label={['Begin Step %', 'End Step %']}
value={[beginStepPct, endStepPct]}
onChange={handleStepPctChanged}
min={0}
max={1}
step={0.01}
minStepsBetweenThumbs={5}
isDisabled={!isEnabled}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={formatPct(beginStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
}}
>
0%
</RangeSliderMark>
<RangeSliderMark
value={0.5}
sx={{
insetInlineStart: '50% !important',
transform: 'translateX(-50%)',
}}
>
50%
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
}}
>
100%
</RangeSliderMark>
</RangeSlider>
</HStack>
</FormControl>
);
};
export default memo(ParamIPAdapterBeginEnd);

View File

@@ -1,41 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { isIPAdapterEnableToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
(state) => {
const { isIPAdapterEnabled } = state.controlNet;
return { isIPAdapterEnabled };
},
defaultSelectorOptions
);
const ParamIPAdapterFeatureToggle = () => {
const { isIPAdapterEnabled } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(() => {
dispatch(isIPAdapterEnableToggled());
}, [dispatch]);
return (
<IAISwitch
label={t('controlnet.enableIPAdapter')}
isChecked={isIPAdapterEnabled}
onChange={handleChange}
formControlProps={{
width: '100%',
}}
/>
);
};
export default memo(ParamIPAdapterFeatureToggle);

View File

@@ -1,93 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { ipAdapterImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
const ParamIPAdapterImage = () => {
const ipAdapterInfo = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo
);
const isIPAdapterEnabled = useAppSelector(
(state: RootState) => state.controlNet.isIPAdapterEnabled
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { currentData: imageDTO } = useGetImageDTOQuery(
ipAdapterInfo.adapterImage?.image_name ?? skipToken
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (imageDTO) {
return {
id: 'ip-adapter-image',
payloadType: 'IMAGE_DTO',
payload: { imageDTO },
};
}
}, [imageDTO]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: 'ip-adapter-image',
actionType: 'SET_IP_ADAPTER_IMAGE',
}),
[]
);
const postUploadAction = useMemo<PostUploadAction>(
() => ({
type: 'SET_IP_ADAPTER_IMAGE',
}),
[]
);
return (
<Flex
sx={{
position: 'relative',
w: 'full',
alignItems: 'center',
justifyContent: 'center',
}}
>
<IAIDndImage
imageDTO={imageDTO}
droppableData={droppableData}
draggableData={draggableData}
postUploadAction={postUploadAction}
isUploadDisabled={!isIPAdapterEnabled}
isDropDisabled={!isIPAdapterEnabled}
dropLabel={t('toast.setIPAdapterImage')}
noContentFallback={
<IAINoContentFallback
label={t('controlnet.ipAdapterImageFallback')}
/>
}
/>
<IAIDndImageIcon
onClick={() => dispatch(ipAdapterImageChanged(null))}
icon={ipAdapterInfo.adapterImage ? <FaUndo /> : undefined}
tooltip={t('controlnet.resetIPAdapterImage')}
/>
</Flex>
);
};
export default memo(ParamIPAdapterImage);

View File

@@ -1,97 +0,0 @@
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { ipAdapterModelChanged } from 'features/controlNet/store/controlNetSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
const ParamIPAdapterModelSelect = () => {
const ipAdapterModel = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.model
);
const model = useAppSelector((state: RootState) => state.generation.model);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
ipAdapterModels?.entities[
`${ipAdapterModel?.base_model}/ip_adapter/${ipAdapterModel?.model_name}`
] ?? null,
[
ipAdapterModel?.base_model,
ipAdapterModel?.model_name,
ipAdapterModels?.entities,
]
);
const data = useMemo(() => {
if (!ipAdapterModels) {
return [];
}
const data: SelectItem[] = [];
forEach(ipAdapterModels.entities, (ipAdapterModel, id) => {
if (!ipAdapterModel) {
return;
}
const disabled = model?.base_model !== ipAdapterModel.base_model;
data.push({
value: id,
label: ipAdapterModel.model_name,
group: MODEL_TYPE_MAP[ipAdapterModel.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${ipAdapterModel.base_model}`
: undefined,
});
});
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [ipAdapterModels, model?.base_model]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newIPAdapterModel = modelIdToIPAdapterModelParam(v);
if (!newIPAdapterModel) {
return;
}
dispatch(ipAdapterModelChanged(newIPAdapterModel));
},
[dispatch]
);
return (
<IAIMantineSelect
label={t('controlnet.ipAdapterModel')}
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};
export default memo(ParamIPAdapterModelSelect);

View File

@@ -1,46 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { ipAdapterWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const ParamIPAdapterWeight = () => {
const isIpAdapterEnabled = useAppSelector(
(state: RootState) => state.controlNet.isIPAdapterEnabled
);
const ipAdapterWeight = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.weight
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(ipAdapterWeightChanged(weight));
},
[dispatch]
);
const handleWeightReset = useCallback(() => {
dispatch(ipAdapterWeightChanged(1));
}, [dispatch]);
return (
<IAISlider
isDisabled={!isIpAdapterEnabled}
label={t('controlnet.weight')}
value={ipAdapterWeight}
onChange={handleWeightChanged}
min={0}
max={2}
step={0.01}
withSliderMarks
sliderMarks={[0, 1, 2]}
withReset
handleReset={handleWeightReset}
/>
);
};
export default memo(ParamIPAdapterWeight);

View File

@@ -1,13 +1,9 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { ImageDTO } from 'services/api/types';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
@@ -60,36 +56,16 @@ export type ControlNetConfig = {
shouldAutoConfig: boolean;
};
export type IPAdapterConfig = {
adapterImage: ImageDTO | null;
model: IPAdapterModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
};
export type ControlNetState = {
controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean;
pendingControlImages: string[];
isIPAdapterEnabled: boolean;
ipAdapterInfo: IPAdapterConfig;
};
export const initialIPAdapterState: IPAdapterConfig = {
adapterImage: null,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
};
export const initialControlNetState: ControlNetState = {
controlNets: {},
isEnabled: false,
pendingControlImages: [],
isIPAdapterEnabled: false,
ipAdapterInfo: { ...initialIPAdapterState },
};
export const controlNetSlice = createSlice({
@@ -377,31 +353,6 @@ export const controlNetSlice = createSlice({
controlNetReset: () => {
return { ...initialControlNetState };
},
isIPAdapterEnableToggled: (state) => {
state.isIPAdapterEnabled = !state.isIPAdapterEnabled;
},
ipAdapterImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.ipAdapterInfo.adapterImage = action.payload;
},
ipAdapterWeightChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.weight = action.payload;
},
ipAdapterModelChanged: (
state,
action: PayloadAction<IPAdapterModelParam | null>
) => {
state.ipAdapterInfo.model = action.payload;
},
ipAdapterBeginStepPctChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.beginStepPct = action.payload;
},
ipAdapterEndStepPctChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.endStepPct = action.payload;
},
ipAdapterStateReset: (state) => {
state.isIPAdapterEnabled = false;
state.ipAdapterInfo = { ...initialIPAdapterState };
},
},
extraReducers: (builder) => {
builder.addCase(controlNetImageProcessed, (state, action) => {
@@ -461,13 +412,6 @@ export const {
controlNetProcessorTypeChanged,
controlNetReset,
controlNetAutoConfigToggled,
isIPAdapterEnableToggled,
ipAdapterImageChanged,
ipAdapterWeightChanged,
ipAdapterModelChanged,
ipAdapterBeginStepPctChanged,
ipAdapterEndStepPctChanged,
ipAdapterStateReset,
} = controlNetSlice.actions;
export default controlNetSlice.reducer;

View File

@@ -10,20 +10,20 @@ import {
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
import { stateSelector } from 'app/store/store';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { imageDeletionConfirmed } from '../store/actions';
import { getImageUsage, selectImageUsage } from '../store/selectors';
import { imageDeletionCanceled, isModalOpenChanged } from '../store/slice';
import { ImageUsage } from '../store/types';
import ImageUsageMessage from './ImageUsageMessage';
import { ImageUsage } from '../store/types';
const selector = createSelector(
[stateSelector, selectImageUsage],
@@ -42,7 +42,6 @@ const selector = createSelector(
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
isIPAdapterImage: some(allImageUsage, (i) => i.isIPAdapterImage),
};
return {

View File

@@ -1,8 +1,8 @@
import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
import { some } from 'lodash-es';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ImageUsage } from '../store/types';
import { useTranslation } from 'react-i18next';
type Props = {
imageUsage?: ImageUsage;
@@ -38,9 +38,6 @@ const ImageUsageMessage = (props: Props) => {
{imageUsage.isControlNetImage && (
<ListItem>{t('common.controlNet')}</ListItem>
)}
{imageUsage.isIPAdapterImage && (
<ListItem>{t('common.ipAdapter')}</ListItem>
)}
{imageUsage.isNodesImage && (
<ListItem>{t('common.nodeEditor')}</ListItem>
)}

View File

@@ -1,9 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isInvocationNode } from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
import { isInvocationNode } from 'features/nodes/types/types';
export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state;
@@ -27,15 +27,11 @@ export const getImageUsage = (state: RootState, image_name: string) => {
c.controlImage === image_name || c.processedControlImage === image_name
);
const isIPAdapterImage =
controlNet.ipAdapterInfo.adapterImage?.image_name === image_name;
const imageUsage: ImageUsage = {
isInitialImage,
isCanvasImage,
isNodesImage,
isControlNetImage,
isIPAdapterImage,
};
return imageUsage;

View File

@@ -10,5 +10,4 @@ export type ImageUsage = {
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
isIPAdapterImage: boolean;
};

View File

@@ -35,10 +35,6 @@ export type ControlNetDropData = BaseDropData & {
};
};
export type IPAdapterImageDropData = BaseDropData & {
actionType: 'SET_IP_ADAPTER_IMAGE';
};
export type CanvasInitialImageDropData = BaseDropData & {
actionType: 'SET_CANVAS_INITIAL_IMAGE';
};
@@ -77,7 +73,6 @@ export type TypesafeDroppableData =
| CurrentImageDropData
| InitialImageDropData
| ControlNetDropData
| IPAdapterImageDropData
| CanvasInitialImageDropData
| NodesImageDropData
| AddToBatchDropData

View File

@@ -24,8 +24,6 @@ export const isValidDrop = (
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_IP_ADAPTER_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':

View File

@@ -53,7 +53,6 @@ const DeleteBoardModal = (props: Props) => {
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
isIPAdapterImage: some(allImageUsage, (i) => i.isIPAdapterImage),
};
return { imageUsageSummary };
}),

View File

@@ -27,7 +27,7 @@ const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Embed Workflow</FormLabel>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Workflow</FormLabel>
<Checkbox
className="nopan"
size="sm"

View File

@@ -1,14 +1,13 @@
import { Flex, Grid, GridItem } from '@chakra-ui/react';
import { useAnyOrDirectInputFieldNames } from 'features/nodes/hooks/useAnyOrDirectInputFieldNames';
import { useConnectionInputFieldNames } from 'features/nodes/hooks/useConnectionInputFieldNames';
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
import { memo } from 'react';
import NodeWrapper from '../common/NodeWrapper';
import InvocationNodeFooter from './InvocationNodeFooter';
import InvocationNodeHeader from './InvocationNodeHeader';
import NodeWrapper from '../common/NodeWrapper';
import OutputField from './fields/OutputField';
import InputField from './fields/InputField';
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
import { useWithFooter } from 'features/nodes/hooks/useWithFooter';
import { useConnectionInputFieldNames } from 'features/nodes/hooks/useConnectionInputFieldNames';
import { useAnyOrDirectInputFieldNames } from 'features/nodes/hooks/useAnyOrDirectInputFieldNames';
import OutputField from './fields/OutputField';
type Props = {
nodeId: string;
@@ -22,7 +21,6 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId);
const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId);
const outputFieldNames = useOutputFieldNames(nodeId);
const withFooter = useWithFooter(nodeId);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
@@ -43,7 +41,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
h: 'full',
py: 2,
gap: 1,
borderBottomRadius: withFooter ? 0 : 'base',
borderBottomRadius: 0,
}}
>
<Flex sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}>
@@ -76,7 +74,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
))}
</Flex>
</Flex>
{withFooter && <InvocationNodeFooter nodeId={nodeId} />}
<InvocationNodeFooter nodeId={nodeId} />
</>
)}
</NodeWrapper>

View File

@@ -3,12 +3,15 @@ import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { memo } from 'react';
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
import UseCacheCheckbox from './UseCacheCheckbox';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
type Props = {
nodeId: string;
};
const InvocationNodeFooter = ({ nodeId }: Props) => {
const hasImageOutput = useHasImageOutput(nodeId);
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}
@@ -22,8 +25,9 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
justifyContent: 'space-between',
}}
>
<EmbedWorkflowCheckbox nodeId={nodeId} />
<SaveToGalleryCheckbox nodeId={nodeId} />
{hasImageOutput && <EmbedWorkflowCheckbox nodeId={nodeId} />}
<UseCacheCheckbox nodeId={nodeId} />
{hasImageOutput && <SaveToGalleryCheckbox nodeId={nodeId} />}
</Flex>
);
};

View File

@@ -0,0 +1,35 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useUseCache } from 'features/nodes/hooks/useUseCache';
import { nodeUseCacheChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const UseCacheCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const useCache = useUseCache(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeUseCacheChanged({
nodeId,
useCache: e.target.checked,
})
);
},
[dispatch, nodeId]
);
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Use Cache</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={useCache}
/>
</FormControl>
);
};
export default memo(UseCacheCheckbox);

View File

@@ -15,7 +15,6 @@ import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField';
import VaeModelInputField from './inputs/VaeModelInputField';
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
type InputFieldProps = {
nodeId: string;
@@ -148,19 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (
field?.type === 'IPAdapterModelField' &&
fieldTemplate?.type === 'IPAdapterModelField'
) {
return (
<IPAdapterModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField

View File

@@ -1,17 +0,0 @@
import {
IPAdapterInputFieldTemplate,
IPAdapterInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const IPAdapterInputFieldComponent = (
_props: FieldComponentProps<
IPAdapterInputFieldValue,
IPAdapterInputFieldTemplate
>
) => {
return null;
};
export default memo(IPAdapterInputFieldComponent);

View File

@@ -1,100 +0,0 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
IPAdapterModelInputFieldTemplate,
IPAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
const IPAdapterModelInputFieldComponent = (
props: FieldComponentProps<
IPAdapterModelInputFieldValue,
IPAdapterModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const ipAdapterModel = field.value;
const dispatch = useAppDispatch();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
ipAdapterModels?.entities[
`${ipAdapterModel?.base_model}/ip_adapter/${ipAdapterModel?.model_name}`
] ?? null,
[
ipAdapterModel?.base_model,
ipAdapterModel?.model_name,
ipAdapterModels?.entities,
]
);
const data = useMemo(() => {
if (!ipAdapterModels) {
return [];
}
const data: SelectItem[] = [];
forEach(ipAdapterModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [ipAdapterModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newIPAdapterModel = modelIdToIPAdapterModelParam(v);
if (!newIPAdapterModel) {
return;
}
dispatch(
fieldIPAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value: newIPAdapterModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};
export default memo(IPAdapterModelInputFieldComponent);

View File

@@ -146,6 +146,7 @@ export const useBuildNodeData = () => {
isIntermediate: true,
inputs,
outputs,
useCache: template.useCache,
},
};
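For orientation, the node data built by this hook now carries the template's cache default alongside the existing fields. A rough illustration of the resulting shape — the id, type, and empty input/output maps here are hypothetical, only the field names come from the hunk above:

// illustrative only: approximate shape of the built invocation node data
const data = {
  id: 'noise-1',   // hypothetical node id
  type: 'noise',   // hypothetical invocation type
  isIntermediate: true,
  inputs: {},
  outputs: {},
  useCache: true,  // copied from template.useCache
};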

View File

@@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
export const useUseCache = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
// cast to boolean to support older workflows that didn't have useCache
// TODO: handle this better somehow
return node.data.useCache;
},
defaultSelectorOptions
),
[nodeId]
);
const useCache = useAppSelector(selector);
return useCache;
};
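The comment above notes that older workflows may not carry a `useCache` value. A defensive read could default a missing flag to true, matching the schema backfill that appears later in this compare — a minimal sketch under that assumption, not the committed hook:

// sketch only: treat a missing useCache flag as "cache enabled"
const readUseCache = (data: { useCache?: boolean }): boolean =>
  data.useCache ?? true;

readUseCache({});                  // -> true (legacy node data without the field)
readUseCache({ useCache: false }); // -> false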

View File

@@ -7,7 +7,7 @@ import { useMemo } from 'react';
import { FOOTER_FIELDS } from '../types/constants';
import { isInvocationNode } from '../types/types';
export const useWithFooter = (nodeId: string) => {
export const useHasImageOutputs = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(

View File

@@ -41,7 +41,6 @@ import {
IntegerInputFieldValue,
InvocationNodeData,
InvocationTemplate,
IPAdapterModelInputFieldValue,
isInvocationNode,
isNotesNode,
LoRAModelInputFieldValue,
@@ -261,6 +260,20 @@ const nodesSlice = createSlice({
}
node.data.embedWorkflow = embedWorkflow;
},
nodeUseCacheChanged: (
state,
action: PayloadAction<{ nodeId: string; useCache: boolean }>
) => {
const { nodeId, useCache } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.useCache = useCache;
},
nodeIsIntermediateChanged: (
state,
action: PayloadAction<{ nodeId: string; isIntermediate: boolean }>
@@ -521,12 +534,6 @@ const nodesSlice = createSlice({
) => {
fieldValueReducer(state, action);
},
fieldIPAdapterModelValueChanged: (
state,
action: FieldValueAction<IPAdapterModelInputFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldEnumModelValueChanged: (
state,
action: FieldValueAction<EnumInputFieldValue>
@@ -873,7 +880,6 @@ export const {
fieldLoRAModelValueChanged,
fieldEnumModelValueChanged,
fieldControlNetModelValueChanged,
fieldIPAdapterModelValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
nodeIsOpenChanged,
@@ -912,6 +918,7 @@ export const {
nodeIsIntermediateChanged,
mouseOverNodeChanged,
nodeExclusivelySelected,
nodeUseCacheChanged,
} = nodesSlice.actions;
export default nodesSlice.reducer;
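As a quick illustration of the new reducer, dispatching `nodeUseCacheChanged` flips the flag on the matching invocation node. The throwaway store and node id below are assumptions made for the sketch:

import { configureStore } from '@reduxjs/toolkit';
import nodesReducer, { nodeUseCacheChanged } from 'features/nodes/store/nodesSlice';

// sketch: a minimal store containing only the nodes slice
const store = configureStore({ reducer: { nodes: nodesReducer } });

// 'noise-1' is a hypothetical node id; if no invocation node matches, the reducer is a no-op
store.dispatch(nodeUseCacheChanged({ nodeId: 'noise-1', useCache: false }));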

View File

@@ -41,7 +41,6 @@ export const POLYMORPHIC_TYPES = [
];
export const MODEL_TYPES = [
'IPAdapterModelField',
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
@@ -237,16 +236,6 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: t('nodes.integerPolymorphicDescription'),
title: t('nodes.integerPolymorphic'),
},
IPAdapterField: {
color: 'green.300',
description: 'IP-Adapter info passed between nodes.',
title: 'IP-Adapter',
},
IPAdapterModelField: {
color: 'teal.500',
description: 'IP-Adapter model',
title: 'IP-Adapter Model',
},
LatentsCollection: {
color: 'pink.500',
description: t('nodes.latentsCollectionDescription'),

View File

@@ -1,3 +1,4 @@
import { $store } from 'app/store/nanostores/store';
import {
SchedulerParam,
zBaseModel,
@@ -7,7 +8,8 @@ import {
zSDXLRefinerModel,
zScheduler,
} from 'features/parameters/types/parameterSchemas';
import { keyBy } from 'lodash-es';
import i18n from 'i18next';
import { has, keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow';
@@ -20,7 +22,6 @@ import {
import { O } from 'ts-toolbelt';
import { JsonObject } from 'type-fest';
import { z } from 'zod';
import i18n from 'i18next';
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
@@ -57,6 +58,10 @@ export type InvocationTemplate = {
* The invocation's version.
*/
version?: string;
/**
* Whether or not this node should use the cache
*/
useCache: boolean;
};
export type FieldUIConfig = {
@@ -94,8 +99,6 @@ export const zFieldType = z.enum([
'integer',
'IntegerCollection',
'IntegerPolymorphic',
'IPAdapterField',
'IPAdapterModelField',
'LatentsCollection',
'LatentsField',
'LatentsPolymorphic',
@@ -391,25 +394,6 @@ export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue
>;
export const zIPAdapterModel = zModelIdentifier;
export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zIPAdapterModel,
image_encoder_model: z.string().trim().min(1),
weight: z.number(),
});
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IPAdapterField'),
value: zIPAdapterField.optional(),
});
export type IPAdapterInputFieldValue = z.infer<
typeof zIPAdapterInputFieldValue
>;
export const zModelType = z.enum([
'onnx',
'main',
@@ -559,17 +543,6 @@ export type ControlNetModelInputFieldValue = z.infer<
typeof zControlNetModelInputFieldValue
>;
export const zIPAdapterModelField = zModelIdentifier;
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
export const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IPAdapterModelField'),
value: zIPAdapterModelField.optional(),
});
export type IPAdapterModelInputFieldValue = z.infer<
typeof zIPAdapterModelInputFieldValue
>;
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Collection'),
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
@@ -652,8 +625,6 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerCollectionInputFieldValue,
zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue,
zIPAdapterInputFieldValue,
zIPAdapterModelInputFieldValue,
zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue,
@@ -856,11 +827,6 @@ export type ControlPolymorphicInputFieldTemplate = Omit<
type: 'ControlPolymorphic';
};
export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'IPAdapterField';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'enum';
@@ -898,11 +864,6 @@ export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'ControlNetModelField';
};
export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'IPAdapterModelField';
};
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'Collection';
@@ -974,8 +935,6 @@ export type InputFieldTemplate =
| IntegerCollectionInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate
| IPAdapterInputFieldTemplate
| IPAdapterModelInputFieldTemplate
| LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate
@@ -1022,6 +981,9 @@ export type InvocationSchemaExtra = {
type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
default: AnyInvocationType;
};
use_cache: Omit<OpenAPIV3.SchemaObject, 'default'> & {
default: boolean;
};
};
};
@@ -1185,9 +1147,37 @@ export const zInvocationNodeData = z.object({
version: zSemVer.optional(),
});
export const zInvocationNodeDataV2 = z.preprocess(
(arg) => {
try {
const data = zInvocationNodeData.parse(arg);
if (!has(data, 'useCache')) {
const nodeTemplates = $store.get()?.getState().nodes.nodeTemplates as
| Record<string, InvocationTemplate>
| undefined;
const template = nodeTemplates?.[data.type];
let useCache = true;
if (template) {
useCache = template.useCache;
}
Object.assign(data, { useCache });
}
return data;
} catch {
return arg;
}
},
zInvocationNodeData.extend({
useCache: z.boolean(),
})
);
// Massage this to get better type safety while developing
export type InvocationNodeData = Omit<
z.infer<typeof zInvocationNodeData>,
z.infer<typeof zInvocationNodeDataV2>,
'type'
> & {
type: AnyInvocationType;
@@ -1215,7 +1205,7 @@ const zDimension = z.number().gt(0).nullish();
export const zWorkflowInvocationNode = z.object({
id: z.string().trim().min(1),
type: z.literal('invocation'),
data: zInvocationNodeData,
data: zInvocationNodeDataV2,
width: zDimension,
height: zDimension,
position: zPosition,
@@ -1277,6 +1267,8 @@ export type WorkflowWarning = {
data: JsonObject;
};
const CURRENT_WORKFLOW_VERSION = '1.0.0';
export const zWorkflow = z.object({
name: z.string().default(''),
author: z.string().default(''),
@@ -1292,7 +1284,7 @@ export const zWorkflow = z.object({
.object({
version: zSemVer,
})
.default({ version: '1.0.0' }),
.default({ version: CURRENT_WORKFLOW_VERSION }),
});
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
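The `zInvocationNodeDataV2` preprocess above backfills `useCache` for workflows saved before the field existed. A standalone sketch of the same pattern, using a simplified schema and a hypothetical template map in place of the store lookup:

import { z } from 'zod';

const zNodeData = z.object({ type: z.string(), label: z.string() });

// hypothetical defaults, standing in for nodes.nodeTemplates in the app store
const templateUseCache: Record<string, boolean> = { save_image: false };

const zNodeDataV2 = z.preprocess((arg) => {
  const parsed = zNodeData.safeParse(arg);
  if (!parsed.success) {
    return arg;
  }
  const data: Record<string, unknown> = { ...parsed.data };
  if (!('useCache' in data)) {
    // default to the template's setting, falling back to true
    data.useCache = templateUseCache[parsed.data.type] ?? true;
  }
  return data;
}, zNodeData.extend({ useCache: z.boolean() }));

// a node saved before this change parses with useCache backfilled
zNodeDataV2.parse({ type: 'compel', label: 'Prompt' });
// -> { type: 'compel', label: 'Prompt', useCache: true }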

View File

@@ -60,8 +60,6 @@ import {
ImageField,
LatentsField,
ConditioningField,
IPAdapterInputFieldTemplate,
IPAdapterModelInputFieldTemplate,
} from '../types/types';
import { ControlField } from 'services/api/types';
@@ -437,19 +435,6 @@ const buildControlNetModelInputFieldTemplate = ({
return template;
};
const buildIPAdapterModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IPAdapterModelInputFieldTemplate => {
const template: IPAdapterModelInputFieldTemplate = {
...baseField,
type: 'IPAdapterModelField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
@@ -663,19 +648,6 @@ const buildControlCollectionInputFieldTemplate = ({
return template;
};
const buildIPAdapterInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IPAdapterInputFieldTemplate => {
const template: IPAdapterInputFieldTemplate = {
...baseField,
type: 'IPAdapterField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@@ -879,8 +851,6 @@ const TEMPLATE_BUILDER_MAP = {
integer: buildIntegerInputFieldTemplate,
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
IPAdapterField: buildIPAdapterInputFieldTemplate,
IPAdapterModelField: buildIPAdapterModelInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,

View File

@@ -28,8 +28,6 @@ const FIELD_VALUE_FALLBACK_MAP = {
integer: 0,
IntegerCollection: [],
IntegerPolymorphic: 0,
IPAdapterField: undefined,
IPAdapterModelField: undefined,
LatentsCollection: [],
LatentsField: undefined,
LatentsPolymorphic: undefined,

View File

@@ -1,59 +0,0 @@
import { RootState } from 'app/store/store';
import { IPAdapterInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { IP_ADAPTER } from './constants';
export const addIPAdapterToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
// | MetadataAccumulatorInvocation
// | undefined;
if (isIPAdapterEnabled && ipAdapterInfo.model) {
const ipAdapterNode: IPAdapterInvocation = {
id: IP_ADAPTER,
type: 'ip_adapter',
is_intermediate: true,
weight: ipAdapterInfo.weight,
ip_adapter_model: {
base_model: ipAdapterInfo.model?.base_model,
model_name: ipAdapterInfo.model?.model_name,
},
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
};
if (ipAdapterInfo.adapterImage) {
ipAdapterNode.image = {
image_name: ipAdapterInfo.adapterImage.image_name,
};
} else {
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
// if (metadataAccumulator?.ip_adapters) {
// // metadata accumulator only needs the ip_adapter field - not the whole node
// // extract what we need and add to the accumulator
// const ipAdapterField = omit(ipAdapterNode, [
// 'id',
// 'type',
// ]) as IPAdapterField;
// metadataAccumulator.ip_adapters.push(ipAdapterField);
// }
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: baseNodeId,
field: 'ip_adapter',
},
});
}
};

View File

@@ -1,46 +1,32 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
ImageNSFWBlurInvocation,
LatentsToImageInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NSFW_CHECKER,
} from './constants';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';
export const addNSFWCheckerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate =
activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as
| LatentsToImageInvocation
| undefined;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true;
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate,
is_intermediate: true,
};
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;
@@ -54,17 +40,4 @@ export const addNSFWCheckerToGraph = (
field: 'image',
},
});
if (metadataAccumulator) {
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: NSFW_CHECKER,
field: 'metadata',
},
});
}
};

View File

@@ -0,0 +1,92 @@
import { NonNullableGraph } from 'features/nodes/types/types';
import {
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NSFW_CHECKER,
SAVE_IMAGE,
WATERMARKER,
} from './constants';
import {
MetadataAccumulatorInvocation,
SaveImageInvocation,
} from 'services/api/types';
import { RootState } from 'app/store/store';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
/**
 * Appends a `save_image` node (with `use_cache: false`) to the linear/canvas graph and
 * routes the graph's final image-producing node into it.
 */
export const addSaveImageNode = (
state: RootState,
graph: NonNullableGraph
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate =
activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const saveImageNode: SaveImageInvocation = {
id: SAVE_IMAGE,
type: 'save_image',
is_intermediate,
use_cache: false,
};
graph.nodes[SAVE_IMAGE] = saveImageNode;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (metadataAccumulator) {
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: SAVE_IMAGE,
field: 'metadata',
},
});
}
const destination = {
node_id: SAVE_IMAGE,
field: 'image',
};
if (WATERMARKER in graph.nodes) {
graph.edges.push({
source: {
node_id: WATERMARKER,
field: 'image',
},
destination,
});
} else if (NSFW_CHECKER in graph.nodes) {
graph.edges.push({
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination,
});
} else if (CANVAS_OUTPUT in graph.nodes) {
graph.edges.push({
source: {
node_id: CANVAS_OUTPUT,
field: 'image',
},
destination,
});
} else if (LATENTS_TO_IMAGE in graph.nodes) {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination,
});
}
};
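For a graph that already contains the NSFW checker but no watermarker, the helper above ends up adding roughly the following node and edge. The literal ids are assumptions (they presume NSFW_CHECKER === 'nsfw_checker' and SAVE_IMAGE === 'save_image' in constants.ts, which is not shown in this compare):

// sketch of the expected additions, not captured output
const addedNode = {
  id: 'save_image',
  type: 'save_image',
  is_intermediate: false, // linear (non-canvas) tabs; on canvas this follows shouldAutoSave
  use_cache: false,       // the final image node opts out of the invocation cache
};
const addedEdge = {
  source: { node_id: 'nsfw_checker', field: 'image' },
  destination: { node_id: 'save_image', field: 'image' },
};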

View File

@@ -51,6 +51,7 @@ export const addWatermarkerToGraph = (
// no matter the situation, we want the l2i node to be intermediate
nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true;
if (nsfwCheckerNode) {
// if we are using NSFW checker, we need to "disable" its output by marking it intermediate,

View File

@@ -5,7 +5,6 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic
import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -26,6 +25,7 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Image to Image graph.
@@ -54,14 +54,10 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -93,31 +89,31 @@ export const buildCanvasImageToImageGraph = (
[modelLoaderNodeId]: {
type: 'main_model_loader',
id: modelLoaderNodeId,
is_intermediate: true,
is_intermediate,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate: true,
is_intermediate,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: negativePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate: true,
is_intermediate,
use_cpu,
width: !isUsingScaledDimensions
? width
@@ -129,12 +125,12 @@ export const buildCanvasImageToImageGraph = (
[IMAGE_TO_LATENTS]: {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate: true,
is_intermediate,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
cfg_scale,
scheduler,
steps,
@@ -144,7 +140,7 @@ export const buildCanvasImageToImageGraph = (
[CANVAS_OUTPUT]: {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
},
},
edges: [
@@ -239,7 +235,7 @@ export const buildCanvasImageToImageGraph = (
graph.nodes[IMG2IMG_RESIZE] = {
id: IMG2IMG_RESIZE,
type: 'img_resize',
is_intermediate: true,
is_intermediate,
image: initialImage,
width: scaledBoundingBoxDimensions.width,
height: scaledBoundingBoxDimensions.height,
@@ -247,13 +243,13 @@ export const buildCanvasImageToImageGraph = (
graph.nodes[LATENTS_TO_IMAGE] = {
id: LATENTS_TO_IMAGE,
type: 'l2i',
is_intermediate: true,
is_intermediate,
fp32,
};
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate: !shouldAutoSave,
is_intermediate,
width: width,
height: height,
};
@@ -294,7 +290,7 @@ export const buildCanvasImageToImageGraph = (
graph.nodes[CANVAS_OUTPUT] = {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
fp32,
};
@@ -338,17 +334,6 @@ export const buildCanvasImageToImageGraph = (
init_image: initialImage.image_name,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: CANVAS_OUTPUT,
field: 'metadata',
},
});
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
@@ -367,9 +352,6 @@ export const buildCanvasImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -381,5 +363,7 @@ export const buildCanvasImageToImageGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};

View File

@@ -12,7 +12,6 @@ import {
RangeOfSizeInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -45,6 +44,7 @@ import {
RANGE_OF_SIZE,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Inpaint graph.
@@ -88,12 +88,8 @@ export const buildCanvasInpaintGraph = (
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const is_intermediate = true;
const fp32 = vaePrecision === 'fp32';
const isUsingScaledDimensions = ['auto', 'manual'].includes(
@@ -112,56 +108,56 @@ export const buildCanvasInpaintGraph = (
[modelLoaderNodeId]: {
type: 'main_model_loader',
id: modelLoaderNodeId,
is_intermediate: true,
is_intermediate,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate: true,
is_intermediate,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: negativePrompt,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
is_intermediate: true,
is_intermediate,
radius: maskBlur,
blur_type: maskBlurMethod,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate: true,
is_intermediate,
fp32,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[INPAINT_CREATE_MASK]: {
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate: true,
is_intermediate,
fp32,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -172,18 +168,18 @@ export const buildCanvasInpaintGraph = (
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[CANVAS_COHERENCE_NOISE_INCREMENT]: {
type: 'add',
id: CANVAS_COHERENCE_NOISE_INCREMENT,
b: 1,
is_intermediate: true,
is_intermediate,
},
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: CANVAS_COHERENCE_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
steps: canvasCoherenceSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -193,19 +189,19 @@ export const buildCanvasInpaintGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: true,
is_intermediate,
fp32,
},
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
reference: canvasInitImage,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate: true,
is_intermediate,
// seed - must be connected manually
// start: 0,
size: iterations,
@@ -214,7 +210,7 @@ export const buildCanvasInpaintGraph = (
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate: true,
is_intermediate,
},
},
edges: [
@@ -437,7 +433,7 @@ export const buildCanvasInpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate: true,
is_intermediate,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
@@ -445,7 +441,7 @@ export const buildCanvasInpaintGraph = (
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate: true,
is_intermediate,
width: scaledWidth,
height: scaledHeight,
image: canvasMaskImage,
@@ -453,14 +449,14 @@ export const buildCanvasInpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
@@ -598,7 +594,7 @@ export const buildCanvasInpaintGraph = (
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
is_intermediate: true,
is_intermediate,
fp32,
};
@@ -651,7 +647,7 @@ export const buildCanvasInpaintGraph = (
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
type: 'mask_edge',
id: CANVAS_COHERENCE_MASK_EDGE,
is_intermediate: true,
is_intermediate,
edge_blur: maskBlur,
edge_size: maskBlur * 2,
low_threshold: 100,
@@ -737,9 +733,6 @@ export const buildCanvasInpaintGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -751,5 +744,7 @@ export const buildCanvasInpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};

View File

@@ -11,7 +11,6 @@ import {
RangeOfSizeInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -47,6 +46,7 @@ import {
RANGE_OF_SIZE,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Outpaint graph.
@@ -92,14 +92,10 @@ export const buildCanvasOutpaintGraph = (
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -116,61 +112,61 @@ export const buildCanvasOutpaintGraph = (
[modelLoaderNodeId]: {
type: 'main_model_loader',
id: modelLoaderNodeId,
is_intermediate: true,
is_intermediate,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate: true,
is_intermediate,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: negativePrompt,
},
[MASK_FROM_ALPHA]: {
type: 'tomask',
id: MASK_FROM_ALPHA,
is_intermediate: true,
is_intermediate,
image: canvasInitImage,
},
[MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate: true,
is_intermediate,
mask2: canvasMaskImage,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate: true,
is_intermediate,
fp32,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[INPAINT_CREATE_MASK]: {
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate: true,
is_intermediate,
fp32,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -181,18 +177,18 @@ export const buildCanvasOutpaintGraph = (
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[CANVAS_COHERENCE_NOISE_INCREMENT]: {
type: 'add',
id: CANVAS_COHERENCE_NOISE_INCREMENT,
b: 1,
is_intermediate: true,
is_intermediate,
},
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: CANVAS_COHERENCE_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
steps: canvasCoherenceSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -202,18 +198,18 @@ export const buildCanvasOutpaintGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: true,
is_intermediate,
fp32,
},
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate: true,
is_intermediate,
// seed - must be connected manually
// start: 0,
size: iterations,
@@ -222,7 +218,7 @@ export const buildCanvasOutpaintGraph = (
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate: true,
is_intermediate,
},
},
edges: [
@@ -473,7 +469,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_patchmatch',
id: INPAINT_INFILL,
is_intermediate: true,
is_intermediate,
downscale: infillPatchmatchDownscaleSize,
};
}
@@ -482,7 +478,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_lama',
id: INPAINT_INFILL,
is_intermediate: true,
is_intermediate,
};
}
@@ -490,7 +486,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_cv2',
id: INPAINT_INFILL,
is_intermediate: true,
is_intermediate,
};
}
@@ -498,7 +494,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_tile',
id: INPAINT_INFILL,
is_intermediate: true,
is_intermediate,
tile_size: infillTileSize,
};
}
@@ -512,7 +508,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate: true,
is_intermediate,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
@@ -520,28 +516,28 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate: true,
is_intermediate,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
graph.nodes[INPAINT_INFILL_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_INFILL_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
@@ -700,7 +696,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
is_intermediate: true,
is_intermediate,
fp32,
};
@@ -747,7 +743,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
type: 'mask_edge',
id: CANVAS_COHERENCE_MASK_EDGE,
is_intermediate: true,
is_intermediate,
edge_blur: maskBlur,
edge_size: maskBlur * 2,
low_threshold: 100,
@@ -839,9 +835,6 @@ export const buildCanvasOutpaintGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -853,5 +846,7 @@ export const buildCanvasOutpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};

View File

@@ -5,7 +5,6 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic
import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -28,6 +27,7 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Image to Image graph.
@@ -62,14 +62,10 @@ export const buildCanvasSDXLImageToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -123,7 +119,7 @@ export const buildCanvasSDXLImageToImageGraph = (
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate: true,
is_intermediate,
use_cpu,
width: !isUsingScaledDimensions
? width
@@ -135,13 +131,13 @@ export const buildCanvasSDXLImageToImageGraph = (
[IMAGE_TO_LATENTS]: {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate: true,
is_intermediate,
fp32,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
cfg_scale,
scheduler,
steps,
@@ -252,7 +248,7 @@ export const buildCanvasSDXLImageToImageGraph = (
graph.nodes[IMG2IMG_RESIZE] = {
id: IMG2IMG_RESIZE,
type: 'img_resize',
is_intermediate: true,
is_intermediate,
image: initialImage,
width: scaledBoundingBoxDimensions.width,
height: scaledBoundingBoxDimensions.height,
@@ -260,13 +256,13 @@ export const buildCanvasSDXLImageToImageGraph = (
graph.nodes[LATENTS_TO_IMAGE] = {
id: LATENTS_TO_IMAGE,
type: 'l2i',
is_intermediate: true,
is_intermediate,
fp32,
};
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate: !shouldAutoSave,
is_intermediate,
width: width,
height: height,
};
@@ -307,7 +303,7 @@ export const buildCanvasSDXLImageToImageGraph = (
graph.nodes[CANVAS_OUTPUT] = {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
fp32,
};
@@ -393,9 +389,6 @@ export const buildCanvasSDXLImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -407,5 +400,7 @@ export const buildCanvasSDXLImageToImageGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};

View File

@@ -46,7 +46,7 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Inpaint graph.
@@ -95,14 +95,10 @@ export const buildCanvasSDXLInpaintGraph = (
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -140,32 +136,32 @@ export const buildCanvasSDXLInpaintGraph = (
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
is_intermediate: true,
is_intermediate,
radius: maskBlur,
blur_type: maskBlurMethod,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate: true,
is_intermediate,
fp32,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[INPAINT_CREATE_MASK]: {
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate: true,
is_intermediate,
fp32,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -178,18 +174,18 @@ export const buildCanvasSDXLInpaintGraph = (
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[CANVAS_COHERENCE_NOISE_INCREMENT]: {
type: 'add',
id: CANVAS_COHERENCE_NOISE_INCREMENT,
b: 1,
is_intermediate: true,
is_intermediate,
},
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: CANVAS_COHERENCE_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
steps: canvasCoherenceSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -199,19 +195,19 @@ export const buildCanvasSDXLInpaintGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: true,
is_intermediate,
fp32,
},
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
reference: canvasInitImage,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate: true,
is_intermediate,
// seed - must be connected manually
// start: 0,
size: iterations,
@@ -220,7 +216,7 @@ export const buildCanvasSDXLInpaintGraph = (
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate: true,
is_intermediate,
},
},
edges: [
@@ -452,7 +448,7 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate: true,
is_intermediate,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
@@ -460,7 +456,7 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate: true,
is_intermediate,
width: scaledWidth,
height: scaledHeight,
image: canvasMaskImage,
@@ -468,14 +464,14 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
@@ -613,7 +609,7 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
is_intermediate: true,
is_intermediate,
fp32,
};
@@ -666,7 +662,7 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
type: 'mask_edge',
id: CANVAS_COHERENCE_MASK_EDGE,
is_intermediate: true,
is_intermediate,
edge_blur: maskBlur,
edge_size: maskBlur * 2,
low_threshold: 100,
@@ -766,9 +762,6 @@ export const buildCanvasSDXLInpaintGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -780,5 +773,7 @@ export const buildCanvasSDXLInpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};
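
The same refactor repeats in each file below: the per-node `is_intermediate: true` literals collapse into one shared constant, `shouldAutoSave` no longer gates `CANVAS_OUTPUT`, and every builder now ends with `addSaveImageNode(state, graph)` right before `return graph;`. The helper itself is not shown in this compare view, so the following is only a minimal sketch of what such a helper could look like, wired from `CANVAS_OUTPUT` as in the canvas builders (the linear builders would connect their own final node). The `SAVE_IMAGE` id, the `save_image` node type, the edge field names, the import paths, and the use of `shouldAutoSave` inside the helper are assumptions for illustration, not the PR's actual code.

```ts
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { CANVAS_OUTPUT } from './constants';

// Hypothetical id for the new save node; the real constant may live elsewhere
// under a different name.
const SAVE_IMAGE = 'save_image';

export const addSaveImageNode = (
  state: RootState,
  graph: NonNullableGraph
): void => {
  // Assumption: one plausible home for the shouldAutoSave check that the
  // builders above no longer perform themselves.
  const is_intermediate = !state.canvas.shouldAutoSave;

  // The save node writes the image file and opts out of the invocation cache,
  // so re-running an identical graph still produces an output image.
  graph.nodes[SAVE_IMAGE] = {
    type: 'save_image',
    id: SAVE_IMAGE,
    is_intermediate,
    use_cache: false,
  };

  // Route the previous output node's image into the save node
  // (field names are assumptions).
  graph.edges.push({
    source: { node_id: CANVAS_OUTPUT, field: 'image' },
    destination: { node_id: SAVE_IMAGE, field: 'image' },
  });
};
```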


@@ -11,7 +11,6 @@ import {
RangeOfSizeInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -49,6 +48,7 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Outpaint graph.
@@ -99,14 +99,10 @@ export const buildCanvasSDXLOutpaintGraph = (
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -144,37 +140,37 @@ export const buildCanvasSDXLOutpaintGraph = (
[MASK_FROM_ALPHA]: {
type: 'tomask',
id: MASK_FROM_ALPHA,
is_intermediate: true,
is_intermediate,
image: canvasInitImage,
},
[MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate: true,
is_intermediate,
mask2: canvasMaskImage,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate: true,
is_intermediate,
fp32,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[INPAINT_CREATE_MASK]: {
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate: true,
is_intermediate,
fp32,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -187,18 +183,18 @@ export const buildCanvasSDXLOutpaintGraph = (
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[CANVAS_COHERENCE_NOISE_INCREMENT]: {
type: 'add',
id: CANVAS_COHERENCE_NOISE_INCREMENT,
b: 1,
is_intermediate: true,
is_intermediate,
},
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: CANVAS_COHERENCE_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
steps: canvasCoherenceSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -208,18 +204,18 @@ export const buildCanvasSDXLOutpaintGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: true,
is_intermediate,
fp32,
},
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate: true,
is_intermediate,
// seed - must be connected manually
// start: 0,
size: iterations,
@@ -228,7 +224,7 @@ export const buildCanvasSDXLOutpaintGraph = (
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate: true,
is_intermediate,
},
},
edges: [
@@ -488,7 +484,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_patchmatch',
id: INPAINT_INFILL,
is_intermediate: true,
is_intermediate,
downscale: infillPatchmatchDownscaleSize,
};
}
@@ -497,7 +493,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_lama',
id: INPAINT_INFILL,
is_intermediate: true,
is_intermediate,
};
}
@@ -505,7 +501,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_cv2',
id: INPAINT_INFILL,
is_intermediate: true,
is_intermediate,
};
}
@@ -513,7 +509,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_tile',
id: INPAINT_INFILL,
is_intermediate: true,
is_intermediate,
tile_size: infillTileSize,
};
}
@@ -527,7 +523,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate: true,
is_intermediate,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
@@ -535,28 +531,28 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate: true,
is_intermediate,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
graph.nodes[INPAINT_INFILL_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_INFILL_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate: true,
is_intermediate,
width: width,
height: height,
};
@@ -716,7 +712,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
is_intermediate: true,
is_intermediate,
fp32,
};
@@ -763,7 +759,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
type: 'mask_edge',
id: CANVAS_COHERENCE_MASK_EDGE,
is_intermediate: true,
is_intermediate,
edge_blur: maskBlur,
edge_size: maskBlur * 2,
low_threshold: 100,
@@ -869,9 +865,6 @@ export const buildCanvasSDXLOutpaintGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -883,5 +876,7 @@ export const buildCanvasSDXLOutpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -8,7 +8,6 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -30,6 +29,7 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Text to Image graph.
@@ -56,14 +56,10 @@ export const buildCanvasSDXLTextToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -95,7 +91,7 @@ export const buildCanvasSDXLTextToImageGraph = (
? {
type: 't2l_onnx',
id: SDXL_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
cfg_scale,
scheduler,
steps,
@@ -103,7 +99,7 @@ export const buildCanvasSDXLTextToImageGraph = (
: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
cfg_scale,
scheduler,
steps,
@@ -132,27 +128,27 @@ export const buildCanvasSDXLTextToImageGraph = (
[modelLoaderNodeId]: {
type: modelLoaderNodeType,
id: modelLoaderNodeId,
is_intermediate: true,
is_intermediate,
model,
},
[POSITIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: positivePrompt,
style: craftedPositiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: negativePrompt,
style: craftedNegativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate: true,
is_intermediate,
width: !isUsingScaledDimensions
? width
: scaledBoundingBoxDimensions.width,
@@ -254,14 +250,14 @@ export const buildCanvasSDXLTextToImageGraph = (
graph.nodes[LATENTS_TO_IMAGE] = {
id: LATENTS_TO_IMAGE,
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
is_intermediate: true,
is_intermediate,
fp32,
};
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate: !shouldAutoSave,
is_intermediate,
width: width,
height: height,
};
@@ -292,7 +288,7 @@ export const buildCanvasSDXLTextToImageGraph = (
graph.nodes[CANVAS_OUTPUT] = {
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
fp32,
};
@@ -373,9 +369,6 @@ export const buildCanvasSDXLTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -387,5 +380,7 @@ export const buildCanvasSDXLTextToImageGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -8,9 +8,9 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
@@ -54,14 +54,10 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -90,7 +86,7 @@ export const buildCanvasTextToImageGraph = (
? {
type: 't2l_onnx',
id: DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
cfg_scale,
scheduler,
steps,
@@ -98,7 +94,7 @@ export const buildCanvasTextToImageGraph = (
: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
cfg_scale,
scheduler,
steps,
@@ -123,31 +119,31 @@ export const buildCanvasTextToImageGraph = (
[modelLoaderNodeId]: {
type: modelLoaderNodeType,
id: modelLoaderNodeId,
is_intermediate: true,
is_intermediate,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate: true,
is_intermediate,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate: true,
is_intermediate,
prompt: negativePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate: true,
is_intermediate,
width: !isUsingScaledDimensions
? width
: scaledBoundingBoxDimensions.width,
@@ -240,14 +236,14 @@ export const buildCanvasTextToImageGraph = (
graph.nodes[LATENTS_TO_IMAGE] = {
id: LATENTS_TO_IMAGE,
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
is_intermediate: true,
is_intermediate,
fp32,
};
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate: !shouldAutoSave,
is_intermediate,
width: width,
height: height,
};
@@ -278,7 +274,7 @@ export const buildCanvasTextToImageGraph = (
graph.nodes[CANVAS_OUTPUT] = {
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
is_intermediate,
fp32,
};
@@ -346,9 +342,6 @@ export const buildCanvasTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -360,5 +353,7 @@ export const buildCanvasTextToImageGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};
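
One behavioral detail worth calling out across the four canvas builders above: the output node used to flip between final and intermediate based on the auto-save setting, and is now always intermediate, since only the node added by `addSaveImageNode` emits a gallery image. A condensed before/after of just that field, shown with the inpaint builder's `color_correct` output (the text-to-image builders use `img_resize`/`l2i`); the values are placeholders, not the full node definitions:

```ts
// Placeholder for the setting the builders used to read from state.canvas.
const shouldAutoSave = true;

// Before: the output node itself was the final image when auto-save was on.
const canvasOutputBefore = {
  type: 'color_correct',
  id: 'canvas_output',
  is_intermediate: !shouldAutoSave,
};

// After: the output node is always intermediate; producing the gallery image
// is delegated to the dedicated save-image node appended at the end.
const canvasOutputAfter = {
  type: 'color_correct',
  id: 'canvas_output',
  is_intermediate: true,
};
```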


@@ -8,7 +8,6 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -28,6 +27,7 @@ import {
RESIZE,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Image to Image tab graph.
@@ -86,6 +86,7 @@ export const buildLinearImageToImageGraph = (
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
let modelLoaderNodeId = MAIN_MODEL_LOADER;
@@ -101,31 +102,37 @@ export const buildLinearImageToImageGraph = (
type: 'main_model_loader',
id: modelLoaderNodeId,
model,
is_intermediate,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers: clipSkip,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
is_intermediate,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
@@ -135,6 +142,7 @@ export const buildLinearImageToImageGraph = (
steps,
denoising_start: 1 - strength,
denoising_end: 1,
is_intermediate,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
@@ -144,6 +152,7 @@ export const buildLinearImageToImageGraph = (
// image_name: initialImage.image_name,
// },
fp32,
is_intermediate,
},
},
edges: [
@@ -365,9 +374,6 @@ export const buildLinearImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -379,5 +385,7 @@ export const buildLinearImageToImageGraph = (
addWatermarkerToGraph(state, graph);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -8,7 +8,6 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -30,6 +29,7 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Image to Image tab graph.
@@ -86,6 +86,7 @@ export const buildLinearSDXLImageToImageGraph = (
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
// Model Loader ID
let modelLoaderNodeId = SDXL_MODEL_LOADER;
@@ -106,28 +107,33 @@ export const buildLinearSDXLImageToImageGraph = (
type: 'sdxl_model_loader',
id: modelLoaderNodeId,
model,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: craftedPositiveStylePrompt,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: craftedNegativeStylePrompt,
is_intermediate,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
@@ -139,6 +145,7 @@ export const buildLinearSDXLImageToImageGraph = (
? Math.min(refinerStart, 1 - strength)
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
is_intermediate,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
@@ -148,6 +155,7 @@ export const buildLinearSDXLImageToImageGraph = (
// image_name: initialImage.image_name,
// },
fp32,
is_intermediate,
},
},
edges: [
@@ -385,9 +393,6 @@ export const buildLinearSDXLImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
@@ -402,5 +407,7 @@ export const buildLinearSDXLImageToImageGraph = (
addWatermarkerToGraph(state, graph);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -4,7 +4,6 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -24,6 +23,7 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
export const buildLinearSDXLTextToImageGraph = (
state: RootState
@@ -57,13 +57,13 @@ export const buildLinearSDXLTextToImageGraph = (
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
// Construct Style Prompt
const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
@@ -89,18 +89,21 @@ export const buildLinearSDXLTextToImageGraph = (
type: 'sdxl_model_loader',
id: modelLoaderNodeId,
model,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: craftedPositiveStylePrompt,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: craftedNegativeStylePrompt,
is_intermediate,
},
[NOISE]: {
type: 'noise',
@@ -108,6 +111,7 @@ export const buildLinearSDXLTextToImageGraph = (
width,
height,
use_cpu,
is_intermediate,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
@@ -117,11 +121,13 @@ export const buildLinearSDXLTextToImageGraph = (
steps,
denoising_start: 0,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
is_intermediate,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
},
},
edges: [
@@ -278,9 +284,6 @@ export const buildLinearSDXLTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
@@ -295,5 +298,7 @@ export const buildLinearSDXLTextToImageGraph = (
addWatermarkerToGraph(state, graph);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -8,7 +8,6 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -27,6 +26,7 @@ import {
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
export const buildLinearTextToImageGraph = (
state: RootState
@@ -59,7 +59,7 @@ export const buildLinearTextToImageGraph = (
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingOnnxModel = model.model_type === 'onnx';
let modelLoaderNodeId = isUsingOnnxModel
@@ -75,7 +75,7 @@ export const buildLinearTextToImageGraph = (
? {
type: 't2l_onnx',
id: DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
cfg_scale,
scheduler,
steps,
@@ -83,7 +83,7 @@ export const buildLinearTextToImageGraph = (
: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate: true,
is_intermediate,
cfg_scale,
scheduler,
steps,
@@ -109,26 +109,26 @@ export const buildLinearTextToImageGraph = (
[modelLoaderNodeId]: {
type: modelLoaderNodeType,
id: modelLoaderNodeId,
is_intermediate: true,
is_intermediate,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers: clipSkip,
is_intermediate: true,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
is_intermediate: true,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
is_intermediate: true,
is_intermediate,
},
[NOISE]: {
type: 'noise',
@@ -136,13 +136,14 @@ export const buildLinearTextToImageGraph = (
width,
height,
use_cpu,
is_intermediate: true,
is_intermediate,
},
[t2lNode.id]: t2lNode,
[LATENTS_TO_IMAGE]: {
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
},
},
edges: [
@@ -283,9 +284,6 @@ export const buildLinearTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -297,5 +295,7 @@ export const buildLinearTextToImageGraph = (
addWatermarkerToGraph(state, graph);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -55,6 +55,9 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
{} as Record<Exclude<string, 'id' | 'type'>, unknown>
);
// add reserved use_cache
transformedInputs['use_cache'] = node.data.useCache;
// Build this specific node
const graphNode = {
type,
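
The final hunk shown is the workflow-editor side of the feature: `buildNodesGraph` now copies each editor node's `useCache` flag into the serialized payload under the reserved `use_cache` key, alongside the node's other input values. Below is a self-contained sketch of that serialization step; the `EditorNodeData` shape and the input reduction are simplified assumptions for illustration, not the file's actual types.

```ts
// Simplified sketch of the node-serialization step shown above; the real
// buildNodesGraph iterates the editor's node and edge state and does more.
type EditorNodeData = {
  id: string;
  type: string;
  useCache: boolean;
  inputs: Record<string, { value: unknown }>;
};

const toGraphNode = (data: EditorNodeData) => {
  // Collect field values; id and type are emitted separately below.
  const transformedInputs = Object.entries(data.inputs).reduce(
    (acc, [name, field]) => {
      acc[name] = field.value;
      return acc;
    },
    {} as Record<string, unknown>
  );

  // add reserved use_cache (mirrors the hunk above)
  transformedInputs['use_cache'] = data.useCache;

  return {
    type: data.type,
    id: data.id,
    ...transformedInputs,
  };
};

// Example: a node that opts out of caching serializes with use_cache: false.
const example = toGraphNode({
  id: 'noise-1',
  type: 'noise',
  useCache: false,
  inputs: { seed: { value: 123 } },
});
// example => { type: 'noise', id: 'noise-1', seed: 123, use_cache: false }
```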

Some files were not shown because too many files have changed in this diff.