Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-17 10:48:03 -05:00)

Compare commits: v3.0.0+b8 ... onnx-testi (73 commits)
| SHA1 |
|---|
| e8299d0abb |
| a28ab654ef |
| 8699fd7050 |
| 9e65470ada |
| f4e52fafac |
| ee7b36cea5 |
| 487455ef2e |
| 632346b2e2 |
| e201ad2f51 |
| 0f18898865 |
| 700131fab2 |
| 634d6bb8a6 |
| 2fbf245c3d |
| 39c14eb2ac |
| c89fa4b635 |
| e943913f58 |
| 4ada094c5c |
| 893e199677 |
| 3c5a0c95b3 |
| 71a07ee5a7 |
| 2a0a765ec4 |
| ec08151009 |
| 186e98da5e |
| dea9a5da7a |
| bda0000acd |
| 4b678f2416 |
| 869f418b03 |
| 35d5ef9118 |
| 42c440c73f |
| 416afd2781 |
| afa84a149c |
| be659364c2 |
| 56098f370c |
| bcce70fca6 |
| 932112b640 |
| 91112167b1 |
| bd7b59910d |
| 524888bf3b |
| 0327eae509 |
| bb85608890 |
| 6c7668aaca |
| 7759b3f75a |
| 4d337f6abc |
| 92c86fd0b8 |
| 46dc751139 |
| 4cefe37723 |
| 82b73c50a0 |
| 7df7a95299 |
| 85b4b359c2 |
| cfe81b5e00 |
| b0c4451324 |
| d4931522d4 |
| 17e2a35228 |
| 91016d8b29 |
| 9fda21cf40 |
| 809ec7163e |
| 7c9a939b47 |
| 9634c96020 |
| e0c105f413 |
| f0bf32c476 |
| 28373dbb98 |
| 4133d77772 |
| 61c426f502 |
| bf0577c882 |
| 24673fd859 |
| dc669d1447 |
| ce4110b9f4 |
| 0f3b7d2b3d |
| 16dc78f6c6 |
| 7a66856785 |
| c8dfa49d86 |
| 76dd749b1e |
| 67d05d2066 |
@@ -1,6 +1,14 @@
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField

from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher

import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import (Blend, Conjunction,
@@ -518,8 +518,8 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_scale"] = "img_scale"

# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
@@ -22,7 +22,8 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
    PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import ModelPatcher
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
    InvocationConfig, InvocationContext)
@@ -54,6 +54,7 @@ class MainModelField(BaseModel):

model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")


class LoRAModelField(BaseModel):
@@ -221,6 +222,9 @@ class LoraLoaderInvocation(BaseInvocation):
base_model = self.lora.base_model
lora_name = self.lora.model_name

# TODO: ui rewrite
base_model = BaseModelType.StableDiffusion1

if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
invokeai/app/invocations/onnx.py (new file, 591 lines)
@@ -0,0 +1,591 @@
|
||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import re
|
||||
import inspect
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.util import choose_torch_device
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
from tqdm import tqdm
|
||||
from .model import ClipField
|
||||
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
|
||||
from .compel import CompelOutput
|
||||
|
||||
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(uint16)": np.uint16,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(uint32)": np.uint32,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(uint64)": np.uint64,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(double)": np.float64,
|
||||
}
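The mapping above converts ONNX Runtime's tensor type strings into numpy dtypes; later in this file it is used to cast the UNet's `timestep` input to whatever the session declares. A minimal sketch of that lookup (the helper name and session are illustrative, not part of the diff):

```python
import numpy as np

def cast_for_input(session, input_name: str, value) -> np.ndarray:
    # look up the declared ORT type string for the named input, e.g. "tensor(float16)"
    ort_type = next(
        (i.type for i in session.get_inputs() if i.name == input_name),
        "tensor(float)",
    )
    return np.asarray(value, dtype=ORT_TO_NP_TYPE[ort_type])
```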
|
||||
|
||||
|
||||
class ONNXPromptInvocation(BaseInvocation):
|
||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer,\
|
||||
text_encoder_info as text_encoder,\
|
||||
ExitStack() as stack:
|
||||
|
||||
#loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
#stack.enter_context(
|
||||
# context.services.model_manager.get_model(
|
||||
# model_name=name,
|
||||
# base_model=self.clip.text_encoder.base_model,
|
||||
# model_type=ModelType.TextualInversion,
|
||||
# )
|
||||
#)
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model
|
||||
)
|
||||
except Exception:
|
||||
#print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
|
||||
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
|
||||
|
||||
text_encoder.create_session()
|
||||
|
||||
# copy from
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
|
||||
text_inputs = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
"""
|
||||
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
||||
|
||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
"""
|
||||
|
||||
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
|
||||
text_encoder.release_session()
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
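For readers who want to try the encoding step outside the invocation framework, here is a rough standalone sketch of the same tokenize-then-encode flow. The paths are placeholders, and the ONNX text encoder is assumed to follow the diffusers export layout (int32 `input_ids`):

```python
import numpy as np
from onnxruntime import InferenceSession
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("path/to/model/tokenizer")
text_encoder = InferenceSession("path/to/model/text_encoder/model.onnx")

text_inputs = tokenizer(
    "a photo of an astronaut",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="np",
)
# the exported CLIP text encoder expects int32 token ids
prompt_embeds = text_encoder.run(
    None, {"input_ids": text_inputs.input_ids.astype(np.int32)}
)[0]
```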
|
||||
|
||||
# Text to image
|
||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than or equal to 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than or equal to 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
if isinstance(c, torch.Tensor):
|
||||
c = c.cpu().numpy()
|
||||
if isinstance(uc, torch.Tensor):
|
||||
uc = uc.cpu().numpy()
|
||||
device = torch.device(choose_torch_device())
|
||||
prompt_embeds = np.concatenate([uc, c])
|
||||
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
if isinstance(latents, torch.Tensor):
|
||||
latents = latents.cpu().numpy()
|
||||
|
||||
# TODO: better execution device handling
|
||||
latents = latents.astype(np.float16)
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
do_classifier_free_guidance = True
|
||||
#latents_dtype = prompt_embeds.dtype
|
||||
#latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
#if latents.shape != latents_shape:
|
||||
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
def torch2numpy(latent: torch.Tensor):
|
||||
return latent.cpu().numpy()
|
||||
|
||||
def numpy2torch(latent, device):
|
||||
return torch.from_numpy(latent).to(device)
|
||||
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
scheduler.set_timesteps(self.steps)
|
||||
latents = latents * np.float64(scheduler.init_noise_sigma)
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
|
||||
with unet_info as unet,\
|
||||
ExitStack() as stack:
|
||||
|
||||
#loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
||||
# TODO:
|
||||
unet.create_session()
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
import time
|
||||
times = []
|
||||
for i in tqdm(range(len(scheduler.timesteps))):
|
||||
t = scheduler.timesteps[i]
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
|
||||
latent_model_input = latent_model_input.cpu().numpy()
|
||||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
start_time = time.time()
|
||||
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
|
||||
times.append(time.time() - start_time)
|
||||
noise_pred = noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = scheduler.step(
|
||||
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
|
||||
)
|
||||
latents = torch2numpy(scheduler_output.prev_sample)
|
||||
|
||||
state = PipelineIntermediateState(
|
||||
run_id= "test",
|
||||
step=i,
|
||||
timestep=timestep,
|
||||
latents=scheduler_output.prev_sample
|
||||
)
|
||||
dispatch_progress(
|
||||
self,
|
||||
context=context,
|
||||
source_node_id=source_node_id,
|
||||
intermediate_state=state
|
||||
)
|
||||
|
||||
# call the callback, if provided
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
print(times)
|
||||
unet.release_session()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
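The denoising loop above mixes node plumbing with the sampling math. Below is a condensed sketch of just the math, assuming `unet_session`, `scheduler` (with `set_timesteps` already called), `latents`, `prompt_embeds`, and `cfg_scale` are prepared as in the invocation, and that the UNet uses the standard diffusers ONNX input names:

```python
import numpy as np
import torch

latents = latents * np.float64(scheduler.init_noise_sigma)
for t in scheduler.timesteps:
    # one UNet call covers the unconditional and conditional halves
    latent_in = np.concatenate([latents] * 2)
    latent_in = scheduler.scale_model_input(torch.from_numpy(latent_in), t).numpy()

    noise_pred = unet_session.run(
        None,
        {
            "sample": latent_in,
            "timestep": np.array([t], dtype=np.float16),
            "encoder_hidden_states": prompt_embeds,
        },
    )[0]

    # classifier-free guidance: move the prediction away from the unconditional one
    uncond, text = np.split(noise_pred, 2)
    noise_pred = uncond + cfg_scale * (text - uncond)

    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(
        torch.from_numpy(noise_pred), t, torch.from_numpy(latents)
    ).prev_sample.numpy()
```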
|
||||
|
||||
# Latent to image
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
#tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with vae_info as vae:
|
||||
vae.create_session()
|
||||
|
||||
# copied from
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# it seems like there is a strange result when using a half-precision vae decoder with batch size > 1
|
||||
image = np.concatenate(
|
||||
[vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
|
||||
)
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
||||
|
||||
vae.release_session()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
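The decode path boils down to a few array operations. A standalone sketch of the same math, with a plain ONNX session standing in for the managed VAE decoder:

```python
import numpy as np
from diffusers.image_processor import VaeImageProcessor

def decode_latents(vae_session, latents: np.ndarray):
    latents = latents / 0.18215               # undo the SD-1.x latent scaling factor
    images = np.concatenate(
        # decode one sample at a time; fp16 VAE decoders can misbehave with batch > 1
        [vae_session.run(None, {"latent_sample": latents[i : i + 1]})[0]
         for i in range(latents.shape[0])]
    )
    images = np.clip(images / 2 + 0.5, 0, 1)  # [-1, 1] -> [0, 1]
    images = images.transpose((0, 2, 3, 1))   # NCHW -> NHWC
    return VaeImageProcessor.numpy_to_pil(images)
```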
|
||||
|
||||
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
#fmt: off
|
||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
|
||||
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
|
||||
#fmt: on
|
||||
|
||||
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loading submodels of selected model."""
|
||||
|
||||
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
|
||||
|
||||
model_name: str = Field(default="", description="Model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {
|
||||
"model_name": "model" # TODO: rename to model_name?
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
|
||||
model_name = "stable-diffusion-v1-5"
|
||||
base_model = BaseModelType.StableDiffusion1
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.ONNX,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {model_name}!")
|
||||
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
class OnnxModelField(BaseModel):
|
||||
"""Onnx model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
||||
|
||||
model: OnnxModelField = Field(description="The model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Onnx Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.ONNX
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
),
|
||||
)
|
||||
)
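For comparison, the four nodes in this file (model loader, prompt encoder, t2l, l2i) decompose what diffusers ships as a single monolithic ONNX pipeline. A rough equivalent with placeholder paths:

```python
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "path/to/onnx-stable-diffusion-v1-5",
    provider="CPUExecutionProvider",
)
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")
```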
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
from pathlib import Path, PosixPath
|
||||
from typing import Literal, Union, cast
|
||||
from pathlib import Path
|
||||
from typing import Literal, Union
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
@@ -16,19 +16,20 @@ from .image import ImageOutput
|
||||
|
||||
# TODO: Populate this from disk?
|
||||
# TODO: Use model manager to load?
|
||||
REALESRGAN_MODELS = Literal[
|
||||
ESRGAN_MODELS = Literal[
|
||||
"RealESRGAN_x4plus.pth",
|
||||
"RealESRGAN_x4plus_anime_6B.pth",
|
||||
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"RealESRGAN_x2plus.pth",
|
||||
]
|
||||
|
||||
|
||||
class RealESRGANInvocation(BaseInvocation):
|
||||
class ESRGANInvocation(BaseInvocation):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
type: Literal["realesrgan"] = "realesrgan"
|
||||
type: Literal["esrgan"] = "esrgan"
|
||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||
model_name: REALESRGAN_MODELS = Field(
|
||||
model_name: ESRGAN_MODELS = Field(
|
||||
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
||||
)
|
||||
|
||||
@@ -73,19 +74,17 @@ class RealESRGANInvocation(BaseInvocation):
|
||||
scale=4,
|
||||
)
|
||||
netscale = 4
|
||||
# TODO: add x2 models handling?
|
||||
# elif self.model_name in ["RealESRGAN_x2plus"]:
|
||||
# # x2 RRDBNet model
|
||||
# model = RRDBNet(
|
||||
# num_in_ch=3,
|
||||
# num_out_ch=3,
|
||||
# num_feat=64,
|
||||
# num_block=23,
|
||||
# num_grow_ch=32,
|
||||
# scale=2,
|
||||
# )
|
||||
# model_path = Path()
|
||||
# netscale = 2
|
||||
elif self.model_name in ["RealESRGAN_x2plus.pth"]:
|
||||
# x2 RRDBNet model
|
||||
rrdbnet_model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=2,
|
||||
)
|
||||
netscale = 2
|
||||
else:
|
||||
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
||||
context.services.logger.error(msg)
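For context, a sketch of how an x2 RRDBNet like the one above is typically wired into the Real-ESRGAN upsampler; this assumes the `realesrgan`/`basicsr` package APIs and an already-downloaded weight file:

```python
import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

rrdbnet_model = RRDBNet(
    num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2
)
upsampler = RealESRGANer(
    scale=2,
    model_path="models/core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
    model=rrdbnet_model,
    half=False,
)
img = cv2.imread("input.png", cv2.IMREAD_UNCHANGED)
upscaled, _ = upsampler.enhance(img, outscale=2)
cv2.imwrite("output.png", upscaled)
```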
|
||||
|
||||
@@ -466,7 +466,6 @@ class Generator:
|
||||
dtype=samples.dtype,
|
||||
device=samples.device,
|
||||
)
|
||||
|
||||
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
|
||||
@@ -222,7 +222,7 @@ def download_conversion_models():
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_realesrgan():
|
||||
logger.info("Installing RealESRGAN models...")
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
dict(
|
||||
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
@@ -239,6 +239,11 @@ def download_realesrgan():
|
||||
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
),
|
||||
dict(
|
||||
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
description = "RealESRGAN_x2plus.pth",
|
||||
),
|
||||
]
|
||||
for model in URLs:
|
||||
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
||||
|
||||
@@ -10,7 +10,7 @@ from tempfile import TemporaryDirectory
|
||||
from typing import List, Dict, Callable, Union, Set
|
||||
|
||||
import requests
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||
from omegaconf import OmegaConf
|
||||
@@ -310,6 +310,8 @@ class ModelInstall(object):
|
||||
if key := self.reverse_paths.get(path_name):
|
||||
(name, base, mtype) = ModelManager.parse_key(key)
|
||||
return name
|
||||
elif location.is_dir():
|
||||
return location.name
|
||||
else:
|
||||
return location.stem
|
||||
|
||||
@@ -365,7 +367,7 @@ class ModelInstall(object):
|
||||
model = None
|
||||
for revision in revisions:
|
||||
try:
|
||||
model = StableDiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
|
||||
model = DiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
|
||||
except: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||
pass
|
||||
if model:
|
||||
|
||||
@@ -3,6 +3,7 @@ Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||
from .model_cache import ModelCache
|
||||
from .lora import ModelPatcher, ONNXModelPatcher
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
|
||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||
|
||||
|
||||
@@ -6,11 +6,22 @@ from typing import Optional, Dict, Tuple, Any, Union, List
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import OrtValue
|
||||
import numpy as np
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# TODO: rename and split this file
|
||||
|
||||
class LoRALayerBase:
|
||||
#rank: Optional[int]
|
||||
#alpha: Optional[float]
|
||||
@@ -708,3 +719,185 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
|
||||
return new_token_ids
|
||||
|
||||
|
||||
class ONNXModelPatcher:
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: OnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
):
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
# based on
|
||||
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: IAIOnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
from .models.base import IAIOnnxRuntimeModel
|
||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||
|
||||
orig_weights = dict()
|
||||
|
||||
try:
|
||||
blended_loras = dict()
|
||||
|
||||
for lora, lora_weight in loras:
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
layer_key = layer_key.replace(prefix, "")
|
||||
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
|
||||
if layer_key in blended_loras:
|
||||
blended_loras[layer_key] += layer_weight
|
||||
else:
|
||||
blended_loras[layer_key] = layer_weight
|
||||
|
||||
node_names = dict()
|
||||
for node in model.nodes.values():
|
||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
||||
|
||||
for layer_key, lora_weight in blended_loras.items():
|
||||
conv_key = layer_key + "_Conv"
|
||||
gemm_key = layer_key + "_Gemm"
|
||||
matmul_key = layer_key + "_MatMul"
|
||||
|
||||
if conv_key in node_names or gemm_key in node_names:
|
||||
if conv_key in node_names:
|
||||
conv_node = model.nodes[node_names[conv_key]]
|
||||
else:
|
||||
conv_node = model.nodes[node_names[gemm_key]]
|
||||
|
||||
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
||||
orig_weight = model.tensors[weight_name]
|
||||
|
||||
if orig_weight.shape[-2:] == (1, 1):
|
||||
if lora_weight.shape[-2:] == (1, 1):
|
||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
|
||||
else:
|
||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
|
||||
|
||||
new_weight = np.expand_dims(new_weight, (2, 3))
|
||||
else:
|
||||
if orig_weight.shape != lora_weight.shape:
|
||||
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
|
||||
else:
|
||||
new_weight = orig_weight + lora_weight
|
||||
|
||||
orig_weights[weight_name] = orig_weight
|
||||
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
|
||||
|
||||
elif matmul_key in node_names:
|
||||
weight_node = model.nodes[node_names[matmul_key]]
|
||||
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
||||
|
||||
orig_weight = model.tensors[matmul_name]
|
||||
new_weight = orig_weight + lora_weight.transpose()
|
||||
|
||||
orig_weights[matmul_name] = orig_weight
|
||||
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
|
||||
|
||||
else:
|
||||
# warn? err?
|
||||
pass
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# restore original weights
|
||||
for name, orig_weight in orig_weights.items():
|
||||
model.tensors[name] = orig_weight
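The per-initializer arithmetic that `apply_lora()` performs, shown in isolation; `up`/`down` stand for the LoRA factor matrices whose product `layer.get_weight()` would produce, and all names here are illustrative:

```python
import numpy as np

def merge_lora_weight(orig_weight: np.ndarray, up: np.ndarray, down: np.ndarray, scale: float) -> np.ndarray:
    delta = scale * (up @ down)                      # low-rank update, roughly (out, in)
    if orig_weight.ndim == 4 and orig_weight.shape[-2:] == (1, 1):
        # 1x1 conv kernels: merge in 2D, then restore the trailing spatial dims
        merged = orig_weight.squeeze((3, 2)) + delta
        return np.expand_dims(merged, (2, 3)).astype(orig_weight.dtype)
    return (orig_weight + delta.reshape(orig_weight.shape)).astype(orig_weight.dtype)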
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_ti(
|
||||
cls,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: IAIOnnxRuntimeModel,
|
||||
ti_list: List[Any],
|
||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||
from .models.base import IAIOnnxRuntimeModel
|
||||
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
|
||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||
|
||||
orig_embeddings = None
|
||||
|
||||
try:
|
||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
|
||||
def _get_trigger(ti, index):
|
||||
trigger = ti.name
|
||||
if index > 0:
|
||||
trigger += f"-!pad-{i}"
|
||||
return f"<{trigger}>"
|
||||
|
||||
# modify tokenizer
|
||||
new_tokens_added = 0
|
||||
for ti in ti_list:
|
||||
for i in range(ti.embedding.shape[0]):
|
||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
||||
|
||||
# modify text_encoder
|
||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
||||
|
||||
embeddings = np.concatenate(
|
||||
(
|
||||
np.copy(orig_embeddings),
|
||||
np.zeros((new_tokens_added, orig_embeddings.shape[1]))
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
|
||||
for ti in ti_list:
|
||||
ti_tokens = []
|
||||
for i in range(ti.embedding.shape[0]):
|
||||
embedding = ti.embedding[i].detach().numpy()
|
||||
trigger = _get_trigger(ti, i)
|
||||
|
||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||
if token_id == ti_tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
|
||||
|
||||
if embeddings[token_id].shape != embedding.shape:
|
||||
raise ValueError(
|
||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
|
||||
)
|
||||
|
||||
embeddings[token_id] = embedding
|
||||
ti_tokens.append(token_id)
|
||||
|
||||
if len(ti_tokens) > 1:
|
||||
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
|
||||
|
||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(orig_embeddings.dtype)
|
||||
|
||||
yield ti_tokenizer, ti_manager
|
||||
|
||||
finally:
|
||||
# restore
|
||||
if orig_embeddings is not None:
|
||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
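The tokenizer/embedding bookkeeping in `apply_ti()` reduces to: add the trigger tokens, grow the embedding table, and write each textual-inversion vector at its new token id. A plain numpy sketch of that idea (names and shapes are illustrative):

```python
import numpy as np
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
embeddings = np.zeros((len(tokenizer), 768), dtype=np.float32)   # stand-in weight table

ti_vectors = {"<my-style>": np.random.randn(768).astype(np.float32)}  # hypothetical TI

added = tokenizer.add_tokens(list(ti_vectors))                   # new ids are appended at the end
embeddings = np.concatenate(
    (embeddings, np.zeros((added, embeddings.shape[1]), dtype=embeddings.dtype)), axis=0
)
for trigger, vector in ti_vectors.items():
    token_id = tokenizer.convert_tokens_to_ids(trigger)
    embeddings[token_id] = vector                                # "<my-style>" in a prompt now hits this row
```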
|
||||
|
||||
@@ -328,6 +328,25 @@ class ModelCache(object):
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# manually clear local variable references from just-finished function calls
# for some reason python doesn't want to collect them even with an immediate gc.collect()
|
||||
if refs > 2:
|
||||
while True:
|
||||
cleared = False
|
||||
for referrer in gc.get_referrers(cache_entry.model):
|
||||
if type(referrer).__name__ == "frame":
|
||||
# RuntimeError: cannot clear an executing frame
|
||||
with suppress(RuntimeError):
|
||||
referrer.clear()
|
||||
cleared = True
|
||||
#break
|
||||
|
||||
# repeat if the referrers changed (due to frame clear), else exit the loop
|
||||
if cleared:
|
||||
gc.collect()
|
||||
else:
|
||||
break
|
||||
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
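A small illustration (not from the diff) of the mechanism this loop relies on: `gc.get_referrers()` can expose the frames, among other objects, that still hold a reference to a model, and clearing those frames is what lets the cache actually release memory:

```python
import gc
import sys

class FakeModel:
    pass

model = FakeModel()
print(sys.getrefcount(model))              # baseline reference count
for referrer in gc.get_referrers(model):
    print(type(referrer).__name__)         # e.g. 'dict' for module globals, 'frame' for live calls
```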
|
||||
|
||||
@@ -363,6 +382,9 @@ class ModelCache(object):
|
||||
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
|
||||
vram_in_use += mem.vram_used # note vram_used is negative
|
||||
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||
sha = hashlib.sha256()
|
||||
|
||||
@@ -106,16 +106,16 @@ providing information about a model defined in models.yaml. For example:
|
||||
|
||||
>>> models = mgr.list_models()
|
||||
>>> json.dumps(models[0])
|
||||
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
|
||||
"model_format": "diffusers",
|
||||
"name": "canny",
|
||||
"base_model": "sd-1",
|
||||
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
|
||||
"model_format": "diffusers",
|
||||
"name": "canny",
|
||||
"base_model": "sd-1",
|
||||
"type": "controlnet"
|
||||
}
|
||||
|
||||
You can filter by model type and base model as shown here:
|
||||
|
||||
|
||||
|
||||
controlnets = mgr.list_models(model_type=ModelType.ControlNet,
|
||||
base_model=BaseModelType.StableDiffusion1)
|
||||
for c in controlnets:
|
||||
@@ -140,14 +140,14 @@ Layout of the `models` directory:
|
||||
|
||||
models
|
||||
├── sd-1
|
||||
│ ├── controlnet
|
||||
│ ├── lora
|
||||
│ ├── main
|
||||
│ └── embedding
|
||||
│ ├── controlnet
|
||||
│ ├── lora
|
||||
│ ├── main
|
||||
│ └── embedding
|
||||
├── sd-2
|
||||
│ ├── controlnet
|
||||
│ ├── lora
|
||||
│ ├── main
|
||||
│ ├── controlnet
|
||||
│ ├── lora
|
||||
│ ├── main
|
||||
│ └── embedding
|
||||
└── core
|
||||
├── face_reconstruction
|
||||
@@ -195,7 +195,7 @@ name, base model, type and a dict of model attributes. See
|
||||
`invokeai/backend/model_management/models` for the attributes required
|
||||
by each model type.
|
||||
|
||||
A model can be deleted using `del_model()`, providing the same
|
||||
A model can be deleted using `del_model()`, providing the same
|
||||
identifying information as `get_model()`
|
||||
|
||||
The `heuristic_import()` method will take a set of strings
|
||||
@@ -304,7 +304,7 @@ class ModelManager(object):
|
||||
logger: types.ModuleType = logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
@@ -323,7 +323,7 @@ class ModelManager(object):
|
||||
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
|
||||
# TODO: metadata not found
|
||||
# TODO: version check
|
||||
|
||||
|
||||
self.app_config = InvokeAIAppConfig.get_config()
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
@@ -431,7 +431,7 @@ class ModelManager(object):
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
:param model_type: ModelType enum indicating the type of model to return
|
||||
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||
:param submodel_type: a SubModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
@@ -456,7 +456,7 @@ class ModelManager(object):
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
# TODO:
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
override_path = getattr(model_config, submodel_type)
|
||||
if override_path:
|
||||
@@ -489,7 +489,7 @@ class ModelManager(object):
|
||||
self.cache_keys[model_key].add(model_context.key)
|
||||
|
||||
model_hash = "<NO_HASH>" # TODO:
|
||||
|
||||
|
||||
return ModelInfo(
|
||||
context = model_context,
|
||||
name = model_name,
|
||||
@@ -518,7 +518,7 @@ class ModelManager(object):
|
||||
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Return a list of (str, BaseModelType, ModelType) corresponding to all models
|
||||
Return a list of (str, BaseModelType, ModelType) corresponding to all models
|
||||
known to the configuration.
|
||||
"""
|
||||
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||
@@ -692,12 +692,12 @@ class ModelManager(object):
|
||||
if new_name is None and new_base is None:
|
||||
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
||||
return
|
||||
|
||||
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
model_cfg = self.models.get(model_key, None)
|
||||
if not model_cfg:
|
||||
raise ModelNotFoundException(f"Unknown model: {model_key}")
|
||||
|
||||
|
||||
old_path = self.app_config.root_path / model_cfg.path
|
||||
new_name = new_name or model_name
|
||||
new_base = new_base or base_model
|
||||
@@ -726,7 +726,7 @@ class ModelManager(object):
|
||||
self.models.pop(model_key, None) # delete
|
||||
self.models[new_key] = model_cfg
|
||||
self.commit()
|
||||
|
||||
|
||||
def convert_model (
|
||||
self,
|
||||
model_name: str,
|
||||
@@ -776,12 +776,12 @@ class ModelManager(object):
|
||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||
rmtree(new_diffusers_path)
|
||||
raise
|
||||
|
||||
|
||||
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
|
||||
checkpoint_path.unlink()
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def search_models(self, search_folder):
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||
@@ -824,10 +824,14 @@ class ModelManager(object):
|
||||
assert config_file_path is not None,'no config file path to write to'
|
||||
config_file_path = self.app_config.root_path / config_file_path
|
||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(self.preamble())
|
||||
outfile.write(yaml_str)
|
||||
os.replace(tmpfile, config_file_path)
|
||||
try:
|
||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(self.preamble())
|
||||
outfile.write(yaml_str)
|
||||
os.replace(tmpfile, config_file_path)
|
||||
except OSError as err:
|
||||
self.logger.warning(f"Could not modify the config file at {config_file_path}")
|
||||
self.logger.warning(err)
|
||||
|
||||
def preamble(self) -> str:
|
||||
"""
|
||||
@@ -977,13 +981,12 @@ class ModelManager(object):
|
||||
# avoid circular import here
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
successfully_installed = dict()
|
||||
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
prediction_type_helper = prediction_type_helper,
|
||||
model_manager = self)
|
||||
for thing in items_to_import:
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
self.commit()
|
||||
self.commit()
|
||||
return successfully_installed
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class ModelProbeInfo(object):
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
format: Literal['diffusers','checkpoint', 'lycoris']
|
||||
format: Literal['diffusers','checkpoint', 'lycoris', 'olive']
|
||||
image_size: int
|
||||
|
||||
class ProbeBase(object):
|
||||
|
||||
@@ -10,8 +10,11 @@ from .lora import LoRAModel
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
from .textual_inversion import TextualInversionModel
|
||||
|
||||
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
|
||||
|
||||
MODEL_CLASSES = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelType.ONNX: ONNXStableDiffusion1Model,
|
||||
ModelType.Main: StableDiffusion1Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
@@ -19,6 +22,7 @@ MODEL_CLASSES = {
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.Main: StableDiffusion2Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
@@ -32,6 +36,7 @@ MODEL_CLASSES = {
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelType.Main: StableDiffusionXLModel,
|
||||
@@ -40,6 +45,7 @@ MODEL_CLASSES = {
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
},
|
||||
#BaseModelType.Kandinsky2_1: {
|
||||
# ModelType.Main: Kandinsky2_1Model,
|
||||
|
||||
@@ -8,13 +8,19 @@ from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from picklescan.scanner import scan_file_path
|
||||
import torch
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
from diffusers import DiffusionPipeline, ConfigMixin
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
|
||||
|
||||
from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
from onnx.external_data_helper import set_external_data
|
||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel
|
||||
class InvalidModelException(Exception):
|
||||
pass
|
||||
|
||||
@@ -29,6 +35,7 @@ class BaseModelType(str, Enum):
|
||||
#Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
class ModelType(str, Enum):
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
@@ -42,6 +49,8 @@ class SubModelType(str, Enum):
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
#MoVQ = "movq"
|
||||
@@ -254,16 +263,18 @@ class DiffusersModel(ModelBase):
|
||||
try:
|
||||
# TODO: set cache_dir to /dev/null to be sure that cache not used?
|
||||
model = self.child_types[child_type].from_pretrained(
|
||||
self.model_path,
|
||||
subfolder=child_type.value,
|
||||
os.path.join(self.model_path, child_type.value),
|
||||
#subfolder=child_type.value,
|
||||
torch_dtype=torch_dtype,
|
||||
variant=variant,
|
||||
local_files_only=True,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
#print("====ERR LOAD====")
|
||||
#print(f"{variant}: {e}")
|
||||
print("====ERR LOAD====")
|
||||
print(f"{variant}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
||||
@@ -430,3 +441,188 @@ class SilenceWarnings(object):
|
||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||
warnings.simplefilter('default')
|
||||
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
class IAIOnnxRuntimeModel:
|
||||
class _tensor_access:
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.indexes = dict()
|
||||
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return self.model.data[key].numpy()
|
||||
|
||||
def __setitem__(self, key: str, value: np.ndarray):
|
||||
new_node = numpy_helper.from_array(value)
|
||||
# set_external_data(new_node, location="in-memory-location")
|
||||
new_node.name = key
|
||||
# new_node.ClearField("raw_data")
|
||||
del self.model.proto.graph.initializer[self.indexes[key]]
|
||||
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
|
||||
self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
|
||||
|
||||
# __delitem__
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key in self.model.data
|
||||
|
||||
def items(self):
|
||||
raise NotImplementedError("tensor.items")
|
||||
#return [(obj.name, obj) for obj in self.raw_proto]
|
||||
|
||||
def keys(self):
|
||||
return self.model.data.keys()
|
||||
|
||||
def values(self):
|
||||
raise NotImplementedError("tensor.values")
|
||||
#return [obj for obj in self.raw_proto]
|
||||
|
||||
|
||||
|
||||
class _access_helper:
|
||||
def __init__(self, raw_proto):
|
||||
self.indexes = dict()
|
||||
self.raw_proto = raw_proto
|
||||
for idx, obj in enumerate(raw_proto):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return self.raw_proto[self.indexes[key]]
|
||||
|
||||
def __setitem__(self, key: str, value):
|
||||
index = self.indexes[key]
|
||||
del self.raw_proto[index]
|
||||
self.raw_proto.insert(index, value)
|
||||
|
||||
# __delitem__
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key in self.indexes
|
||||
|
||||
def items(self):
|
||||
return [(obj.name, obj) for obj in self.raw_proto]
|
||||
|
||||
def keys(self):
|
||||
return self.indexes.keys()
|
||||
|
||||
def values(self):
|
||||
return [obj for obj in self.raw_proto]
|
||||
|
||||
def __init__(self, model_path: str, provider: Optional[str]):
|
||||
        self.path = model_path
        self.session = None
        self.provider = provider or "CPUExecutionProvider"
        """
        self.data_path = self.path + "_data"
        if not os.path.exists(self.data_path):
            print(f"Moving model tensors to separate file: {self.data_path}")
            tmp_proto = onnx.load(model_path, load_external_data=True)
            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
            del tmp_proto
            gc.collect()

        self.proto = onnx.load(model_path, load_external_data=False)
        """

        self.proto = onnx.load(model_path, load_external_data=True)
        self.data = dict()
        for tensor in self.proto.graph.initializer:
            name = tensor.name

            if tensor.HasField("raw_data"):
                npt = numpy_helper.to_array(tensor)
                orv = OrtValue.ortvalue_from_numpy(npt)
                self.data[name] = orv
                # set_external_data(tensor, location="in-memory-location")
                tensor.name = name
                # tensor.ClearField("raw_data")

        self.nodes = self._access_helper(self.proto.graph.node)
        self.initializers = self._access_helper(self.proto.graph.initializer)
        # print(self.proto.graph.input)
        # print(self.proto.graph.initializer)

        self.tensors = self._tensor_access(self)

    # TODO: integrate with model manager/cache
    def create_session(self):
        if self.session is None:
            # onnx.save(self.proto, "tmp.onnx")
            # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # TODO: a way to get weights when they have already been moved outside the model proto
            # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            # self._external_data.update(**external_data)
            # sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
            # sess.enable_profiling = True

            # sess.intra_op_num_threads = 1
            # sess.inter_op_num_threads = 1
            # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
            # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
            # sess.enable_cpu_mem_arena = True
            # sess.enable_mem_pattern = True
            # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1")  ########### It's the key code

            sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
            sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
            sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
            sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
            sess.add_free_dimension_override_by_name("unet_sample_height", 64)
            sess.add_free_dimension_override_by_name("unet_sample_width", 64)
            sess.add_free_dimension_override_by_name("unet_time_batch", 1)
            self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
            # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
            self.io_binding = self.session.io_binding()

    def release_session(self):
        self.session = None
        import gc
        gc.collect()

    def __call__(self, **kwargs):
        if self.session is None:
            raise Exception("create_session must be called before running the model")

        inputs = {k: np.array(v) for k, v in kwargs.items()}
        output_names = self.session.get_outputs()
        for k in inputs:
            self.io_binding.bind_cpu_input(k, inputs[k])
        for name in output_names:
            self.io_binding.bind_output(name.name)
        self.session.run_with_iobinding(self.io_binding, None)
        return self.io_binding.copy_outputs_to_cpu()

    # compatibility with diffusers load code
    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        subfolder: Union[str, Path] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs,
    ):
        file_name = file_name or ONNX_WEIGHTS_NAME

        if os.path.isdir(model_id):
            model_path = model_id
            if subfolder is not None:
                model_path = os.path.join(model_path, subfolder)
            model_path = os.path.join(model_path, file_name)
        else:
            model_path = model_id

        # load model from local directory
        if not os.path.isfile(model_path):
            raise Exception(f"Model not found: {model_path}")

        # TODO: session options
        return cls(model_path, provider=provider)
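A minimal usage sketch of the wrapper above, assuming a UNet exported in the usual diffusers ONNX layout; the input names, tensor shapes, and the model path are illustrative assumptions, not part of this branch:

# Hypothetical driver for IAIOnnxRuntimeModel; input names/shapes are assumptions.
import numpy as np

unet = IAIOnnxRuntimeModel.from_pretrained(
    "/path/to/sd-1.5-onnx", subfolder="unet", provider="CPUExecutionProvider"
)
unet.create_session()  # builds the InferenceSession and its IO binding
noise_pred = unet(
    sample=np.zeros((2, 4, 64, 64), dtype=np.float32),
    timestep=np.array([1], dtype=np.int64),
    encoder_hidden_states=np.zeros((2, 77, 768), dtype=np.float32),
)[0]
unet.release_session()  # drop the session so memory can be reclaimed

Note that __call__ copies the bound outputs back to the CPU, so noise_pred here is a plain numpy array.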
@@ -0,0 +1,156 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
    ModelBase,
    ModelConfigBase,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    DiffusersModel,
    SchedulerPredictionType,
    SilenceWarnings,
    read_checkpoint_meta,
    classproperty,
    OnnxRuntimeModel,
    IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig


class ONNXStableDiffusion1Model(DiffusersModel):

    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel

        # TODO: check that no optimum models are provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4  # TODO:

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 1.* model format")

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path


class ONNXStableDiffusion2Model(DiffusersModel):

    # TODO: check that configs are overwritten properly
    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType
        prediction_type: SchedulerPredictionType
        upcast_attention: bool

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel
        # TODO: check that no optimum models are provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4  # TODO:

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 2.* model format")

        if variant == ModelVariantType.Normal:
            prediction_type = SchedulerPredictionType.VPrediction
            upcast_attention = True
        else:
            prediction_type = SchedulerPredictionType.Epsilon
            upcast_attention = False

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
            prediction_type=prediction_type,
            upcast_attention=upcast_attention,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path
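A hedged sketch of how these config classes are exercised; the folder path is a placeholder, and the submodel names in the comments are assumptions about the diffusers ONNX layout rather than something this diff defines:

# Sketch only: probe an ONNX SD-1 model folder and build its config.
config = ONNXStableDiffusion1Model.probe_config("/models/sd-1/onnx/example-model")
print(config.variant)  # always ModelVariantType.Normal until in_channels probing is implemented

model = ONNXStableDiffusion1Model(
    model_path="/models/sd-1/onnx/example-model",
    base_model=BaseModelType.StableDiffusion1,
    model_type=ModelType.ONNX,
)
# Any child previously typed as OnnxRuntimeModel (e.g. the unet or text encoder)
# is now loaded through IAIOnnxRuntimeModel instead of the optimum wrapper.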
@@ -1,10 +1,4 @@
# This file predefines a few models that the user may want to install.
sd-1/main/stable-diffusion-xdl-base:
  description: Stable Diffusion XL base model - NOT YET RELEASED!! (70 GB)
  repo_id: stabilityai/stable-diffusion-xl-base
sd-1/main/stable-diffusion-xdl-refiner:
  description: Stable Diffusion XL refiner model - NOT YET RELEASED!! (60 GB)
  repo_id: stabilityai/stable-diffusion-xl-refiner
sd-1/main/stable-diffusion-v1-5:
  description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
  repo_id: runwayml/stable-diffusion-v1-5
@@ -22,6 +16,14 @@ sd-2/main/stable-diffusion-2-inpainting:
  description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-inpainting
  recommended: False
sdxl/main/stable-diffusion-xl-base-0-9:
  description: Stable Diffusion XL base model (12 GB; access token required)
  repo_id: stabilityai/stable-diffusion-xl-base-0.9
  recommended: False
sdxl-refiner/main/stable-diffusion-xl-refiner-0-9:
  description: Stable Diffusion XL refiner model (12 GB; access token required)
  repo_id: stabilityai/stable-diffusion-xl-refiner-0.9
  recommended: False
sd-1/main/Analog-Diffusion:
  description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
  repo_id: wavymulder/Analog-Diffusion
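Each key in this file follows the base/type/name convention used by the model manager, with description, repo_id, and an optional recommended flag nested under it. A small sketch of reading it, assuming the file lives at its usual path and that recommended defaults to true when omitted (both assumptions):

# Sketch: list starter models from INITIAL_MODELS.yaml (path and default are assumptions).
from pathlib import Path
import yaml

models = yaml.safe_load(Path("invokeai/configs/INITIAL_MODELS.yaml").read_text())
for key, info in models.items():
    base, model_type, name = key.split("/")  # e.g. "sdxl/main/stable-diffusion-xl-base-0-9"
    if info.get("recommended", True):
        print(f"{name} ({base}/{model_type}): {info['repo_id']}")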
169   invokeai/frontend/web/dist/assets/App-879ff07f.js (vendored, new file): diff suppressed because one or more lines are too long
169   invokeai/frontend/web/dist/assets/App-891779fc.js (vendored): diff suppressed because one or more lines are too long
9     invokeai/frontend/web/dist/assets/ThemeLocaleProvider-8d49f92d.css (vendored, new file): diff suppressed because one or more lines are too long
125   invokeai/frontend/web/dist/assets/index-ba194473.js (vendored, new file): diff suppressed because one or more lines are too long
125   invokeai/frontend/web/dist/assets/index-c1428e4e.js (vendored): diff suppressed because one or more lines are too long
2     invokeai/frontend/web/dist/index.html (vendored)
@@ -12,7 +12,7 @@
      margin: 0;
    }
  </style>
  <script type="module" crossorigin src="./assets/index-c1428e4e.js"></script>
  <script type="module" crossorigin src="./assets/index-ba194473.js"></script>
</head>

<body dir="ltr">
@@ -59,15 +59,8 @@ export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {

export type Scheduler = (typeof SCHEDULER_NAMES)[number];

// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
  { label: '2x', value: '2' },
  { label: '4x', value: '4' },
];
export const NUMPY_RAND_MIN = 0;

export const NUMPY_RAND_MAX = 2147483647;

export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const;

export const NODE_MIN_WIDTH = 250;
@@ -90,6 +90,7 @@ import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';

export const listenerMiddleware = createListenerMiddleware();

@@ -228,3 +229,5 @@ addModelSelectedListener();
addAppStartedListener();
addModelsLoadedListener();
addAppConfigReceivedListener();

addUpscaleRequestedListener();
@@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionCreated } from 'services/api/thunks/session';
import { serializeError } from 'serialize-error';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';

const moduleLog = log.child({ namespace: 'session' });
@@ -0,0 +1,37 @@
import { createAction } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger';
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graphBuilders/buildAdHocUpscaleGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';

const moduleLog = log.child({ namespace: 'upscale' });

export const upscaleRequested = createAction<{ image_name: string }>(
  `upscale/upscaleRequested`
);

export const addUpscaleRequestedListener = () => {
  startAppListening({
    actionCreator: upscaleRequested,
    effect: async (
      action,
      { dispatch, getState, take, unsubscribe, subscribe }
    ) => {
      const { image_name } = action.payload;
      const { esrganModelName } = getState().postprocessing;

      const graph = buildAdHocUpscaleGraph({
        image_name,
        esrganModelName,
      });

      // Create a session to run the graph & wait til it's ready to invoke
      dispatch(sessionCreated({ graph }));

      await take(sessionCreated.fulfilled.match);

      dispatch(sessionReadyToInvoke());
    },
  });
};
@@ -11,7 +11,7 @@ interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {

const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
  ({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
    <Tooltip label={tooltip} placement="top" hasArrow>
    <Tooltip label={tooltip} placement="top" hasArrow openDelay={500}>
      <Box ref={ref} {...others}>
        <Box>
          <Text>{label}</Text>
@@ -5,34 +5,26 @@ import {
|
||||
ButtonGroup,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Link,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
} from '@chakra-ui/react';
|
||||
// import { runESRGAN, runFacetool } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIPopover from 'common/components/IAIPopover';
|
||||
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton';
|
||||
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
|
||||
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldShowImageDetails,
|
||||
setShouldShowProgressInViewer,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
@@ -42,38 +34,25 @@ import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
FaAsterisk,
|
||||
FaCode,
|
||||
FaCopy,
|
||||
FaDownload,
|
||||
FaExpandArrowsAlt,
|
||||
FaGrinStars,
|
||||
FaHourglassHalf,
|
||||
FaQuoteRight,
|
||||
FaSeedling,
|
||||
FaShare,
|
||||
FaShareAlt,
|
||||
} from 'react-icons/fa';
|
||||
import {
|
||||
useGetImageDTOQuery,
|
||||
useGetImageMetadataQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||
import { menuListMotionProps } from 'theme/components/menu';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import { sentImageToImg2Img } from '../../store/actions';
|
||||
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ gallery, system, postprocessing, ui }, activeTabName) => {
|
||||
const {
|
||||
isProcessing,
|
||||
isConnected,
|
||||
isGFPGANAvailable,
|
||||
isESRGANAvailable,
|
||||
shouldConfirmOnDelete,
|
||||
progressImage,
|
||||
} = system;
|
||||
|
||||
const { upscalingLevel, facetoolStrength } = postprocessing;
|
||||
({ gallery, system, ui }, activeTabName) => {
|
||||
const { isProcessing, isConnected, shouldConfirmOnDelete, progressImage } =
|
||||
system;
|
||||
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
@@ -88,10 +67,6 @@ const currentImageButtonsSelector = createSelector(
|
||||
shouldConfirmOnDelete,
|
||||
isProcessing,
|
||||
isConnected,
|
||||
isGFPGANAvailable,
|
||||
isESRGANAvailable,
|
||||
upscalingLevel,
|
||||
facetoolStrength,
|
||||
shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage,
|
||||
shouldShowImageDetails,
|
||||
activeTabName,
|
||||
@@ -114,27 +89,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
const {
|
||||
isProcessing,
|
||||
isConnected,
|
||||
isGFPGANAvailable,
|
||||
isESRGANAvailable,
|
||||
upscalingLevel,
|
||||
facetoolStrength,
|
||||
shouldDisableToolbarButtons,
|
||||
shouldShowImageDetails,
|
||||
activeTabName,
|
||||
lastSelectedImage,
|
||||
shouldShowProgressInViewer,
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
||||
const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;
|
||||
|
||||
const toaster = useAppToaster();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
||||
useCopyImageToClipboard();
|
||||
|
||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||
useRecallParameters();
|
||||
|
||||
@@ -155,42 +120,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
|
||||
const metadata = metadataData?.metadata;
|
||||
|
||||
const handleCopyImageLink = useCallback(() => {
|
||||
const getImageUrl = () => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (imageDTO.image_url.startsWith('http')) {
|
||||
return imageDTO.image_url;
|
||||
}
|
||||
|
||||
return window.location.toString() + imageDTO.image_url;
|
||||
};
|
||||
|
||||
const url = getImageUrl();
|
||||
|
||||
if (!url) {
|
||||
toaster({
|
||||
title: t('toast.problemCopyingImageLink'),
|
||||
status: 'error',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
navigator.clipboard.writeText(url).then(() => {
|
||||
toaster({
|
||||
title: t('toast.imageLinkCopied'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
});
|
||||
}, [toaster, t, imageDTO]);
|
||||
|
||||
const handleClickUseAllParameters = useCallback(() => {
|
||||
recallAllParameters(metadata);
|
||||
}, [metadata, recallAllParameters]);
|
||||
@@ -223,8 +152,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);
|
||||
|
||||
const handleClickUpscale = useCallback(() => {
|
||||
// selectedImage && dispatch(runESRGAN(selectedImage));
|
||||
}, []);
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
dispatch(upscaleRequested({ image_name: imageDTO.image_name }));
|
||||
}, [dispatch, imageDTO]);
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
@@ -242,53 +174,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
enabled: () =>
|
||||
Boolean(
|
||||
isUpscalingEnabled &&
|
||||
isESRGANAvailable &&
|
||||
!shouldDisableToolbarButtons &&
|
||||
isConnected &&
|
||||
!isProcessing &&
|
||||
upscalingLevel
|
||||
!isProcessing
|
||||
),
|
||||
},
|
||||
[
|
||||
isUpscalingEnabled,
|
||||
imageDTO,
|
||||
isESRGANAvailable,
|
||||
shouldDisableToolbarButtons,
|
||||
isConnected,
|
||||
isProcessing,
|
||||
upscalingLevel,
|
||||
]
|
||||
);
|
||||
|
||||
const handleClickFixFaces = useCallback(() => {
|
||||
// selectedImage && dispatch(runFacetool(selectedImage));
|
||||
}, []);
|
||||
|
||||
useHotkeys(
|
||||
'Shift+R',
|
||||
() => {
|
||||
handleClickFixFaces();
|
||||
},
|
||||
{
|
||||
enabled: () =>
|
||||
Boolean(
|
||||
isFaceRestoreEnabled &&
|
||||
isGFPGANAvailable &&
|
||||
!shouldDisableToolbarButtons &&
|
||||
isConnected &&
|
||||
!isProcessing &&
|
||||
facetoolStrength
|
||||
),
|
||||
},
|
||||
|
||||
[
|
||||
isFaceRestoreEnabled,
|
||||
imageDTO,
|
||||
isGFPGANAvailable,
|
||||
shouldDisableToolbarButtons,
|
||||
isConnected,
|
||||
isProcessing,
|
||||
facetoolStrength,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -297,25 +193,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
[dispatch, shouldShowImageDetails]
|
||||
);
|
||||
|
||||
const handleSendToCanvas = useCallback(() => {
|
||||
if (!imageDTO) return;
|
||||
dispatch(sentImageToCanvas());
|
||||
|
||||
dispatch(setInitialCanvasImage(imageDTO));
|
||||
dispatch(requestCanvasRescale());
|
||||
|
||||
if (activeTabName !== 'unifiedCanvas') {
|
||||
dispatch(setActiveTab('unifiedCanvas'));
|
||||
}
|
||||
|
||||
toaster({
|
||||
title: t('toast.sentToUnifiedCanvas'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
}, [imageDTO, dispatch, activeTabName, toaster, t]);
|
||||
|
||||
useHotkeys(
|
||||
'i',
|
||||
() => {
|
||||
@@ -337,13 +214,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
|
||||
}, [dispatch, shouldShowProgressInViewer]);
|
||||
|
||||
const handleCopyImage = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
copyImageToClipboard(imageDTO.image_url);
|
||||
}, [copyImageToClipboard, imageDTO]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex
|
||||
@@ -396,72 +266,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
{(isUpscalingEnabled || isFaceRestoreEnabled) && (
|
||||
{isUpscalingEnabled && (
|
||||
<ButtonGroup
|
||||
isAttached={true}
|
||||
isDisabled={shouldDisableToolbarButtons}
|
||||
>
|
||||
{isFaceRestoreEnabled && (
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<FaGrinStars />}
|
||||
aria-label={t('parameters.restoreFaces')}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
rowGap: 4,
|
||||
}}
|
||||
>
|
||||
<FaceRestoreSettings />
|
||||
<IAIButton
|
||||
isDisabled={
|
||||
!isGFPGANAvailable ||
|
||||
!imageDTO ||
|
||||
!(isConnected && !isProcessing) ||
|
||||
!facetoolStrength
|
||||
}
|
||||
onClick={handleClickFixFaces}
|
||||
>
|
||||
{t('parameters.restoreFaces')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
)}
|
||||
|
||||
{isUpscalingEnabled && (
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<FaExpandArrowsAlt />}
|
||||
aria-label={t('parameters.upscale')}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
gap: 4,
|
||||
}}
|
||||
>
|
||||
<UpscaleSettings />
|
||||
<IAIButton
|
||||
isDisabled={
|
||||
!isESRGANAvailable ||
|
||||
!imageDTO ||
|
||||
!(isConnected && !isProcessing) ||
|
||||
!upscalingLevel
|
||||
}
|
||||
onClick={handleClickUpscale}
|
||||
>
|
||||
{t('parameters.upscaleImage')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
)}
|
||||
{isUpscalingEnabled && <ParamUpscalePopover imageDTO={imageDTO} />}
|
||||
</ButtonGroup>
|
||||
)}
|
||||
|
||||
|
||||
@@ -12,7 +12,10 @@ import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainM
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import {
|
||||
useGetMainModelsQuery,
|
||||
useGetOnnxModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
@@ -23,6 +26,7 @@ const ModelInputFieldComponent = (
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: onnxModels } = useGetOnnxModelsQuery();
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
@@ -44,17 +48,39 @@ const ModelInputFieldComponent = (
|
||||
});
|
||||
});
|
||||
|
||||
if (onnxModels) {
|
||||
forEach(onnxModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
}
|
||||
return data;
|
||||
}, [mainModels]);
|
||||
}, [mainModels, onnxModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
mainModels?.entities[
|
||||
(mainModels?.entities[
|
||||
`${field.value?.base_model}/main/${field.value?.model_name}`
|
||||
] ?? null,
|
||||
[field.value?.base_model, field.value?.model_name, mainModels?.entities]
|
||||
] ||
|
||||
onnxModels?.entities[
|
||||
`${field.value?.base_model}/onnx/${field.value?.model_name}`
|
||||
]) ??
|
||||
null,
|
||||
[
|
||||
field.value?.base_model,
|
||||
field.value?.model_name,
|
||||
mainModels?.entities,
|
||||
onnxModels?.entities,
|
||||
]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
CLIP_SKIP,
|
||||
LORA_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
ONNX_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
@@ -17,7 +18,8 @@ import {
|
||||
export const addLoRAsToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
baseNodeId: string,
|
||||
modelLoader: string = MAIN_MODEL_LOADER
|
||||
): void => {
|
||||
/**
|
||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||
@@ -40,6 +42,10 @@ export const addLoRAsToGraph = (
|
||||
!(
|
||||
e.source.node_id === MAIN_MODEL_LOADER &&
|
||||
['unet'].includes(e.source.field)
|
||||
) &&
|
||||
!(
|
||||
e.source.node_id === ONNX_MODEL_LOADER &&
|
||||
['unet'].includes(e.source.field)
|
||||
)
|
||||
);
|
||||
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
|
||||
@@ -74,12 +80,11 @@ export const addLoRAsToGraph = (
|
||||
|
||||
// add to graph
|
||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||
|
||||
if (currentLoraIndex === 0) {
|
||||
// first lora = start the lora chain, attach directly to model loader
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: modelLoader,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
|
||||
@@ -9,13 +9,15 @@ import {
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
ONNX_MODEL_LOADER,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
|
||||
export const addVAEToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph
|
||||
graph: NonNullableGraph,
|
||||
modelLoader: string = MAIN_MODEL_LOADER
|
||||
): void => {
|
||||
const { vae } = state.generation;
|
||||
|
||||
@@ -31,12 +33,12 @@ export const addVAEToGraph = (
|
||||
vae_model: vae,
|
||||
};
|
||||
}
|
||||
|
||||
const isOnnxModel = modelLoader == ONNX_MODEL_LOADER;
|
||||
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
node_id: isAutoVae ? modelLoader : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
@@ -48,8 +50,8 @@ export const addVAEToGraph = (
|
||||
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
node_id: isAutoVae ? modelLoader : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
@@ -61,8 +63,8 @@ export const addVAEToGraph = (
|
||||
if (graph.id === INPAINT_GRAPH) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
node_id: isAutoVae ? modelLoader : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT,
|
||||
|
||||
@@ -0,0 +1,32 @@
import { NonNullableGraph } from 'features/nodes/types/types';
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import { Graph, ESRGANInvocation } from 'services/api/types';
import { REALESRGAN as ESRGAN } from './constants';

type Arg = {
  image_name: string;
  esrganModelName: ESRGANModelName;
};

export const buildAdHocUpscaleGraph = ({
  image_name,
  esrganModelName,
}: Arg): Graph => {
  const realesrganNode: ESRGANInvocation = {
    id: ESRGAN,
    type: 'esrgan',
    image: { image_name },
    model_name: esrganModelName,
    is_intermediate: false,
  };

  const graph: NonNullableGraph = {
    id: `adhoc-esrgan-graph`,
    nodes: {
      [ESRGAN]: realesrganNode,
    },
    edges: [],
  };

  return graph;
};
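The payload this builder produces is just a single ESRGAN node with no edges; a hedged Python-side mirror of that graph (the image name is a placeholder) can be handy when exercising the backend invocation directly:

# Hypothetical Python equivalent of buildAdHocUpscaleGraph's output.
adhoc_upscale_graph = {
    "id": "adhoc-esrgan-graph",
    "nodes": {
        "esrgan": {  # node id matches the REALESRGAN constant
            "id": "esrgan",
            "type": "esrgan",
            "image": {"image_name": "example-image.png"},  # placeholder
            "model_name": "RealESRGAN_x2plus.pth",
            "is_intermediate": False,
        },
    },
    "edges": [],
}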
@@ -18,6 +18,7 @@ import {
|
||||
LATENTS_TO_IMAGE,
|
||||
LATENTS_TO_LATENTS,
|
||||
MAIN_MODEL_LOADER,
|
||||
ONNX_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
@@ -59,6 +60,9 @@ export const buildCanvasImageToImageGraph = (
|
||||
? shouldUseCpuNoise
|
||||
: initialGenerationState.shouldUseCpuNoise;
|
||||
|
||||
const onnx_model_type = model.model_type.includes('onnx');
|
||||
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
@@ -69,16 +73,17 @@ export const buildCanvasImageToImageGraph = (
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
// TODO: Actually create the graph correctly for ONNX
|
||||
const graph: NonNullableGraph = {
|
||||
id: IMAGE_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
@@ -87,9 +92,9 @@ export const buildCanvasImageToImageGraph = (
|
||||
id: NOISE,
|
||||
use_cpu,
|
||||
},
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
[model_loader]: {
|
||||
type: model_loader,
|
||||
id: model_loader,
|
||||
model,
|
||||
},
|
||||
[CLIP_SKIP]: {
|
||||
@@ -98,11 +103,11 @@ export const buildCanvasImageToImageGraph = (
|
||||
skipped_layers: clipSkip,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
},
|
||||
[LATENTS_TO_LATENTS]: {
|
||||
type: 'l2l',
|
||||
type: onnx_model_type ? 'l2l_onnx' : 'l2l',
|
||||
id: LATENTS_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
@@ -110,7 +115,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
strength,
|
||||
},
|
||||
[IMAGE_TO_LATENTS]: {
|
||||
type: 'i2l',
|
||||
type: onnx_model_type ? 'i2l_onnx' : 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||
// image: {
|
||||
@@ -121,7 +126,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@@ -181,7 +186,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@@ -313,10 +318,10 @@ export const buildCanvasImageToImageGraph = (
|
||||
});
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
|
||||
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS, model_loader);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph);
|
||||
addVAEToGraph(state, graph, model_loader);
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
INPAINT_GRAPH,
|
||||
ITERATE,
|
||||
MAIN_MODEL_LOADER,
|
||||
ONNX_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
@@ -63,6 +64,11 @@ export const buildCanvasInpaintGraph = (
|
||||
// We may need to set the inpaint width and height to scale the image
|
||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||
|
||||
const model_loader = model.model_type.includes('onnx')
|
||||
? ONNX_MODEL_LOADER
|
||||
: MAIN_MODEL_LOADER;
|
||||
|
||||
// TODO: Actually create the graph correctly for ONNX
|
||||
const graph: NonNullableGraph = {
|
||||
id: INPAINT_GRAPH,
|
||||
nodes: {
|
||||
@@ -107,9 +113,9 @@ export const buildCanvasInpaintGraph = (
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
[model_loader]: {
|
||||
type: model_loader,
|
||||
id: model_loader,
|
||||
model,
|
||||
},
|
||||
[CLIP_SKIP]: {
|
||||
@@ -133,7 +139,7 @@ export const buildCanvasInpaintGraph = (
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@@ -143,7 +149,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
CLIP_SKIP,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
ONNX_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
@@ -49,7 +50,8 @@ export const buildCanvasTextToImageGraph = (
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
: initialGenerationState.shouldUseCpuNoise;
|
||||
|
||||
const onnx_model_type = model.model_type.includes('onnx');
|
||||
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
@@ -60,16 +62,17 @@ export const buildCanvasTextToImageGraph = (
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
// TODO: Actually create the graph correctly for ONNX
|
||||
const graph: NonNullableGraph = {
|
||||
id: TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
@@ -81,15 +84,15 @@ export const buildCanvasTextToImageGraph = (
|
||||
use_cpu,
|
||||
},
|
||||
[TEXT_TO_LATENTS]: {
|
||||
type: 't2l',
|
||||
type: onnx_model_type ? 't2l_onnx' : 't2l',
|
||||
id: TEXT_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
},
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
[model_loader]: {
|
||||
type: model_loader,
|
||||
id: model_loader,
|
||||
model,
|
||||
},
|
||||
[CLIP_SKIP]: {
|
||||
@@ -98,7 +101,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
skipped_layers: clipSkip,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
},
|
||||
},
|
||||
@@ -125,7 +128,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@@ -155,7 +158,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@@ -219,10 +222,10 @@ export const buildCanvasTextToImageGraph = (
|
||||
});
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
|
||||
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph);
|
||||
addVAEToGraph(state, graph, model_loader);
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
LATENTS_TO_IMAGE,
|
||||
LATENTS_TO_LATENTS,
|
||||
MAIN_MODEL_LOADER,
|
||||
ONNX_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
@@ -82,13 +83,17 @@ export const buildLinearImageToImageGraph = (
|
||||
? shouldUseCpuNoise
|
||||
: initialGenerationState.shouldUseCpuNoise;
|
||||
|
||||
const onnx_model_type = model.model_type.includes('onnx');
|
||||
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
// TODO: Actually create the graph correctly for ONNX
|
||||
const graph: NonNullableGraph = {
|
||||
id: IMAGE_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
[model_loader]: {
|
||||
type: model_loader,
|
||||
id: model_loader,
|
||||
model,
|
||||
},
|
||||
[CLIP_SKIP]: {
|
||||
@@ -97,12 +102,12 @@ export const buildLinearImageToImageGraph = (
|
||||
skipped_layers: clipSkip,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
@@ -112,11 +117,11 @@ export const buildLinearImageToImageGraph = (
|
||||
use_cpu,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
},
|
||||
[LATENTS_TO_LATENTS]: {
|
||||
type: 'l2l',
|
||||
type: onnx_model_type ? 'l2l_onnx' : 'l2l',
|
||||
id: LATENTS_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
@@ -124,7 +129,7 @@ export const buildLinearImageToImageGraph = (
|
||||
strength,
|
||||
},
|
||||
[IMAGE_TO_LATENTS]: {
|
||||
type: 'i2l',
|
||||
type: onnx_model_type ? 'i2l_onnx' : 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||
// image: {
|
||||
@@ -135,7 +140,7 @@ export const buildLinearImageToImageGraph = (
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@@ -145,7 +150,7 @@ export const buildLinearImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@@ -366,10 +371,10 @@ export const buildLinearImageToImageGraph = (
|
||||
});
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
|
||||
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS, model_loader);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph);
|
||||
addVAEToGraph(state, graph, model_loader);
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
@@ -6,10 +6,13 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { BaseModelType, OnnxModelField } from 'services/api/types';
|
||||
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
ONNX_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
@@ -46,6 +49,8 @@ export const buildLinearTextToImageGraph = (
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
const onnx_model_type = model.model_type.includes('onnx');
|
||||
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
@@ -56,12 +61,14 @@ export const buildLinearTextToImageGraph = (
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
|
||||
// TODO: Actually create the graph correctly for ONNX
|
||||
const graph: NonNullableGraph = {
|
||||
id: TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
[model_loader]: {
|
||||
type: model_loader,
|
||||
id: model_loader,
|
||||
model,
|
||||
},
|
||||
[CLIP_SKIP]: {
|
||||
@@ -70,12 +77,12 @@ export const buildLinearTextToImageGraph = (
|
||||
skipped_layers: clipSkip,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
@@ -87,21 +94,21 @@ export const buildLinearTextToImageGraph = (
|
||||
use_cpu,
|
||||
},
|
||||
[TEXT_TO_LATENTS]: {
|
||||
type: 't2l',
|
||||
type: onnx_model_type ? 't2l_onnx' : 't2l',
|
||||
id: TEXT_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@@ -111,7 +118,7 @@ export const buildLinearTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
node_id: model_loader,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@@ -215,10 +222,10 @@ export const buildLinearTextToImageGraph = (
|
||||
});
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
|
||||
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph);
|
||||
addVAEToGraph(state, graph, model_loader);
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
@@ -8,6 +8,7 @@ export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const MAIN_MODEL_LOADER = 'main_model_loader';
export const ONNX_MODEL_LOADER = 'onnx_model_loader';
export const VAE_LOADER = 'vae_loader';
export const LORA_LOADER = 'lora_loader';
export const CLIP_SKIP = 'clip_skip';
@@ -20,6 +21,9 @@ export const DYNAMIC_PROMPT = 'dynamic_prompt';
export const IMAGE_COLLECTION = 'image_collection';
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
export const REALESRGAN = 'esrgan';
export const DIVIDE = 'divide';
export const SCALE = 'scale_image';

// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
@@ -0,0 +1,17 @@
import { BaseModelType, MainModelField, ModelType } from 'services/api/types';

/**
 * Crudely converts a model id to a main model field
 * TODO: Make better
 */
export const modelIdToMainModelField = (modelId: string): MainModelField => {
  const [base_model, model_type, model_name] = modelId.split('/');

  const field: MainModelField = {
    base_model: base_model as BaseModelType,
    model_type: model_type as ModelType,
    model_name,
  };

  return field;
};
@@ -0,0 +1,17 @@
import { BaseModelType, OnnxModelField, ModelType } from 'services/api/types';

/**
 * Crudely converts a model id to an ONNX model field
 * TODO: Make better
 */
export const modelIdToOnnxModelField = (modelId: string): OnnxModelField => {
  const [base_model, model_type, model_name] = modelId.split('/');

  const field: OnnxModelField = {
    base_model: base_model as BaseModelType,
    model_name,
    model_type: model_type as ModelType,
  };

  return field;
};
@@ -1,34 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setCodeformerFidelity } from 'features/parameters/store/postprocessingSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function CodeformerFidelity() {
|
||||
const isGFPGANAvailable = useAppSelector(
|
||||
(state: RootState) => state.system.isGFPGANAvailable
|
||||
);
|
||||
|
||||
const codeformerFidelity = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.codeformerFidelity
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
label={t('parameters.codeformerFidelity')}
|
||||
step={0.05}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={(v) => dispatch(setCodeformerFidelity(v))}
|
||||
handleReset={() => dispatch(setCodeformerFidelity(1))}
|
||||
value={codeformerFidelity}
|
||||
withReset
|
||||
withSliderMarks
|
||||
withInput
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
import { VStack } from '@chakra-ui/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import FaceRestoreType from './FaceRestoreType';
|
||||
import FaceRestoreStrength from './FaceRestoreStrength';
|
||||
import CodeformerFidelity from './CodeformerFidelity';
|
||||
|
||||
/**
|
||||
* Displays face-fixing/GFPGAN options (strength).
|
||||
*/
|
||||
const FaceRestoreSettings = () => {
|
||||
const facetoolType = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.facetoolType
|
||||
);
|
||||
|
||||
return (
|
||||
<VStack gap={2} alignItems="stretch">
|
||||
<FaceRestoreType />
|
||||
<FaceRestoreStrength />
|
||||
{facetoolType === 'codeformer' && <CodeformerFidelity />}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default FaceRestoreSettings;
|
||||
@@ -1,34 +0,0 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setFacetoolStrength } from 'features/parameters/store/postprocessingSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function FaceRestoreStrength() {
|
||||
const isGFPGANAvailable = useAppSelector(
|
||||
(state: RootState) => state.system.isGFPGANAvailable
|
||||
);
|
||||
|
||||
const facetoolStrength = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.facetoolStrength
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
label={t('parameters.strength')}
|
||||
step={0.05}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={(v) => dispatch(setFacetoolStrength(v))}
|
||||
handleReset={() => dispatch(setFacetoolStrength(0.75))}
|
||||
value={facetoolStrength}
|
||||
withReset
|
||||
withSliderMarks
|
||||
withInput
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { setShouldRunFacetool } from 'features/parameters/store/postprocessingSlice';
|
||||
import { ChangeEvent } from 'react';
|
||||
|
||||
export default function FaceRestoreToggle() {
|
||||
const isGFPGANAvailable = useAppSelector(
|
||||
(state: RootState) => state.system.isGFPGANAvailable
|
||||
);
|
||||
|
||||
const shouldRunFacetool = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.shouldRunFacetool
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleChangeShouldRunFacetool = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldRunFacetool(e.target.checked));
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
isChecked={shouldRunFacetool}
|
||||
onChange={handleChangeShouldRunFacetool}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
import { FACETOOL_TYPES } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import {
|
||||
FacetoolType,
|
||||
setFacetoolType,
|
||||
} from 'features/parameters/store/postprocessingSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function FaceRestoreType() {
|
||||
const facetoolType = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.facetoolType
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChangeFacetoolType = (v: string) =>
|
||||
dispatch(setFacetoolType(v as FacetoolType));
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('parameters.type')}
|
||||
data={FACETOOL_TYPES.concat()}
|
||||
value={facetoolType}
|
||||
onChange={handleChangeFacetoolType}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ParamHiresStrength } from './ParamHiresStrength';
|
||||
import { ParamHiresToggle } from './ParamHiresToggle';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined;
|
||||
|
||||
return { activeLabel };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamHiresCollapse = () => {
|
||||
const { t } = useTranslation();
|
||||
const { activeLabel } = useAppSelector(selector);
|
||||
|
||||
const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled;
|
||||
|
||||
if (!isHiresEnabled) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<IAICollapse label={t('parameters.hiresOptim')} activeLabel={activeLabel}>
|
||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||
<ParamHiresToggle />
|
||||
<ParamHiresStrength />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamHiresCollapse);
|
||||
@@ -1,3 +0,0 @@
|
||||
// TODO
|
||||
|
||||
export default {};
|
||||
@@ -1,3 +0,0 @@
|
||||
// TODO
|
||||
|
||||
export default {};
|
||||
@@ -1,51 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||
import { setHiresStrength } from 'features/parameters/store/postprocessingSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const hiresStrengthSelector = createSelector(
|
||||
[postprocessingSelector],
|
||||
({ hiresFix, hiresStrength }) => ({ hiresFix, hiresStrength }),
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const ParamHiresStrength = () => {
|
||||
const { hiresFix, hiresStrength } = useAppSelector(hiresStrengthSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleHiresStrength = (v: number) => {
|
||||
dispatch(setHiresStrength(v));
|
||||
};
|
||||
|
||||
const handleHiResStrengthReset = () => {
|
||||
dispatch(setHiresStrength(0.75));
|
||||
};
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={t('parameters.hiresStrength')}
|
||||
step={0.01}
|
||||
min={0.01}
|
||||
max={0.99}
|
||||
onChange={handleHiresStrength}
|
||||
value={hiresStrength}
|
||||
isInteger={false}
|
||||
withInput
|
||||
withSliderMarks
|
||||
// inputWidth={22}
|
||||
withReset
|
||||
handleReset={handleHiResStrengthReset}
|
||||
isDisabled={!hiresFix}
|
||||
/>
|
||||
);
|
||||
};
|
||||
@@ -1,30 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
/**
|
||||
* Hires Fix Toggle
|
||||
*/
|
||||
export const ParamHiresToggle = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const hiresFix = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.hiresFix
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChangeHiresFix = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setHiresFix(e.target.checked));
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label={t('parameters.hiresOptim')}
|
||||
isChecked={hiresFix}
|
||||
onChange={handleChangeHiresFix}
|
||||
/>
|
||||
);
|
||||
};
|
||||
@@ -1,3 +0,0 @@
|
||||
// TODO
|
||||
|
||||
export default {};
|
||||
@@ -12,7 +12,11 @@ import { modelSelected } from 'features/parameters/store/actions';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import {
  useGetMainModelsQuery,
  useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import { modelIdToOnnxModelField } from 'features/nodes/util/modelIdToOnnxModelField';

const selector = createSelector(
  stateSelector,
@@ -27,6 +31,7 @@ const ParamMainModelSelect = () => {
  const { model } = useAppSelector(selector);

  const { data: mainModels, isLoading } = useGetMainModelsQuery();
  const { data: onnxModels, isLoading: onnxLoading } = useGetOnnxModelsQuery();

  const data = useMemo(() => {
    if (!mainModels) {
@@ -46,17 +51,31 @@ const ParamMainModelSelect = () => {
        group: MODEL_TYPE_MAP[model.base_model],
      });
    });
    forEach(onnxModels?.entities, (model, id) => {
      if (!model) {
        return;
      }

      data.push({
        value: id,
        label: model.model_name,
        group: MODEL_TYPE_MAP[model.base_model],
      });
    });

    return data;
  }, [mainModels]);
  }, [mainModels, onnxModels]);

  // grab the full model entity from the RTK Query cache
  // TODO: maybe we should just store the full model entity in state?
  const selectedModel = useMemo(
    () =>
      mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
      (mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ||
        onnxModels?.entities[
          `${model?.base_model}/onnx/${model?.model_name}`
        ]) ??
      null,
    [mainModels?.entities, model]
    [mainModels?.entities, model, onnxModels?.entities]
  );

  const handleChangeModel = useCallback(
@@ -65,7 +84,11 @@ const ParamMainModelSelect = () => {
        return;
      }

      const newModel = modelIdToMainModelParam(v);
      let newModel = modelIdToMainModelParam(v);

      if (v.includes('onnx')) {
        newModel = modelIdToOnnxModelField(v);
      }

      if (!newModel) {
        return;
@@ -76,7 +99,7 @@ const ParamMainModelSelect = () => {
    [dispatch]
  );

  return isLoading ? (
  return isLoading || onnxLoading ? (
    <IAIMantineSearchableSelect
      label={t('modelManager.model')}
      placeholder="Loading..."
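The helper modelIdToOnnxModelField referenced above is not included in this diff. A minimal sketch of what it plausibly does, assuming ONNX model ids follow the same `${base_model}/onnx/${model_name}` pattern used for the cache lookup above and that OnnxModelField carries the same name/base/type triple as the main-model field (both assumptions, not confirmed by this diff):

// Hypothetical sketch — not the actual implementation in features/nodes/util/modelIdToOnnxModelField.
// Assumes ids are formatted as `${base_model}/onnx/${model_name}`, matching the
// onnxModels?.entities[...] lookup in ParamMainModelSelect above.
import { OnnxModelField } from 'services/api/types';

export const modelIdToOnnxModelField = (modelId: string): OnnxModelField => {
  const [base_model, model_type, model_name] = modelId.split('/');
  return { base_model, model_name, model_type } as OnnxModelField;
};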
@@ -0,0 +1,58 @@
import { SelectItem } from '@mantine/core';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import {
  ESRGANModelName,
  esrganModelNameChanged,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';

export const ESRGAN_MODEL_NAMES: SelectItem[] = [
  {
    label: 'RealESRGAN x2 Plus',
    value: 'RealESRGAN_x2plus.pth',
    tooltip: 'Attempts to retain sharpness, low smoothing',
    group: 'x2 Upscalers',
  },
  {
    label: 'RealESRGAN x4 Plus',
    value: 'RealESRGAN_x4plus.pth',
    tooltip: 'Best for photos and highly detailed images, medium smoothing',
    group: 'x4 Upscalers',
  },
  {
    label: 'RealESRGAN x4 Plus (anime 6B)',
    value: 'RealESRGAN_x4plus_anime_6B.pth',
    tooltip: 'Best for anime/manga, high smoothing',
    group: 'x4 Upscalers',
  },
  {
    label: 'ESRGAN SRx4',
    value: 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
    tooltip: 'Retains sharpness, low smoothing',
    group: 'x4 Upscalers',
  },
];

export default function ParamESRGANModel() {
  const esrganModelName = useAppSelector(
    (state: RootState) => state.postprocessing.esrganModelName
  );

  const dispatch = useAppDispatch();

  const handleChange = (v: string) =>
    dispatch(esrganModelNameChanged(v as ESRGANModelName));

  return (
    <IAIMantineSelect
      label="ESRGAN Model"
      value={esrganModelName}
      itemComponent={IAIMantineSelectItemWithTooltip}
      onChange={handleChange}
      data={ESRGAN_MODEL_NAMES}
    />
  );
}
@@ -0,0 +1,62 @@
import { Flex, useDisclosure } from '@chakra-ui/react';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaExpandArrowsAlt } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import ParamESRGANModel from './ParamRealESRGANModel';

type Props = { imageDTO?: ImageDTO };

const ParamUpscalePopover = (props: Props) => {
  const { imageDTO } = props;
  const dispatch = useAppDispatch();
  const isBusy = useAppSelector(selectIsBusy);
  const { t } = useTranslation();
  const { isOpen, onOpen, onClose } = useDisclosure();

  const handleClickUpscale = useCallback(() => {
    onClose();
    if (!imageDTO) {
      return;
    }
    dispatch(upscaleRequested({ image_name: imageDTO.image_name }));
  }, [dispatch, imageDTO, onClose]);

  return (
    <IAIPopover
      isOpen={isOpen}
      onClose={onClose}
      triggerComponent={
        <IAIIconButton
          onClick={onOpen}
          icon={<FaExpandArrowsAlt />}
          aria-label={t('parameters.upscale')}
        />
      }
    >
      <Flex
        sx={{
          flexDirection: 'column',
          gap: 4,
        }}
      >
        <ParamESRGANModel />
        <IAIButton
          size="sm"
          isDisabled={!imageDTO || isBusy}
          onClick={handleClickUpscale}
        >
          {t('parameters.upscaleImage')}
        </IAIButton>
      </Flex>
    </IAIPopover>
  );
};

export default ParamUpscalePopover;
@@ -1,36 +0,0 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setUpscalingDenoising } from 'features/parameters/store/postprocessingSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function UpscaleDenoisingStrength() {
|
||||
const isESRGANAvailable = useAppSelector(
|
||||
(state: RootState) => state.system.isESRGANAvailable
|
||||
);
|
||||
|
||||
const upscalingDenoising = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.upscalingDenoising
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={t('parameters.denoisingStrength')}
|
||||
value={upscalingDenoising}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
onChange={(v) => {
|
||||
dispatch(setUpscalingDenoising(v));
|
||||
}}
|
||||
handleReset={() => dispatch(setUpscalingDenoising(0.75))}
|
||||
withSliderMarks
|
||||
withInput
|
||||
withReset
|
||||
isDisabled={!isESRGANAvailable}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
import { UPSCALING_LEVELS } from 'app/constants';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import {
|
||||
UpscalingLevel,
|
||||
setUpscalingLevel,
|
||||
} from 'features/parameters/store/postprocessingSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function UpscaleScale() {
|
||||
const isESRGANAvailable = useAppSelector(
|
||||
(state: RootState) => state.system.isESRGANAvailable
|
||||
);
|
||||
|
||||
const upscalingLevel = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.upscalingLevel
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleChangeLevel = (v: string) =>
|
||||
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
disabled={!isESRGANAvailable}
|
||||
label={t('parameters.scale')}
|
||||
value={String(upscalingLevel)}
|
||||
onChange={handleChangeLevel}
|
||||
data={UPSCALING_LEVELS}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
import { VStack } from '@chakra-ui/react';
|
||||
import UpscaleDenoisingStrength from './UpscaleDenoisingStrength';
|
||||
import UpscaleStrength from './UpscaleStrength';
|
||||
import UpscaleScale from './UpscaleScale';
|
||||
|
||||
/**
|
||||
* Displays upscaling/ESRGAN options (level and strength).
|
||||
*/
|
||||
const UpscaleSettings = () => {
|
||||
return (
|
||||
<VStack gap={2} alignItems="stretch">
|
||||
<UpscaleScale />
|
||||
<UpscaleDenoisingStrength />
|
||||
<UpscaleStrength />
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default UpscaleSettings;
|
||||
@@ -1,33 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setUpscalingStrength } from 'features/parameters/store/postprocessingSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function UpscaleStrength() {
|
||||
const isESRGANAvailable = useAppSelector(
|
||||
(state: RootState) => state.system.isESRGANAvailable
|
||||
);
|
||||
const upscalingStrength = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.upscalingStrength
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={`${t('parameters.upscale')} ${t('parameters.strength')}`}
|
||||
value={upscalingStrength}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.05}
|
||||
onChange={(v) => dispatch(setUpscalingStrength(v))}
|
||||
handleReset={() => dispatch(setUpscalingStrength(0.75))}
|
||||
withSliderMarks
|
||||
withInput
|
||||
withReset
|
||||
isDisabled={!isESRGANAvailable}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { setShouldRunESRGAN } from 'features/parameters/store/postprocessingSlice';
|
||||
import { ChangeEvent } from 'react';
|
||||
|
||||
export default function UpscaleToggle() {
|
||||
const isESRGANAvailable = useAppSelector(
|
||||
(state: RootState) => state.system.isESRGANAvailable
|
||||
);
|
||||
|
||||
const shouldRunESRGAN = useAppSelector(
|
||||
(state: RootState) => state.postprocessing.shouldRunESRGAN
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldRunESRGAN(e.target.checked));
|
||||
return (
|
||||
<IAISwitch
|
||||
isDisabled={!isESRGANAvailable}
|
||||
isChecked={shouldRunESRGAN}
|
||||
onChange={handleChangeShouldRunESRGAN}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,10 +1,10 @@
import { createAction } from '@reduxjs/toolkit';
import { ImageDTO, MainModelField } from 'services/api/types';
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';

export const initialImageSelected = createAction<ImageDTO | string | undefined>(
  'generation/initialImageSelected'
);

export const modelSelected = createAction<MainModelField>(
export const modelSelected = createAction<MainModelField | OnnxModelField>(
  'generation/modelSelected'
);
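With modelSelected widened to accept either field type, ONNX models are dispatched through the same action as diffusers models. A small illustrative sketch (the model values below are hypothetical, and OnnxModelField is assumed to carry the same name/base/type triple as MainModelField):

// Illustrative sketch only — the model values are hypothetical, not taken from this diff.
import { AppDispatch } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/actions';

export const selectOnnxModelExample = (dispatch: AppDispatch) => {
  dispatch(
    modelSelected({
      model_name: 'stable-diffusion-v1-5', // hypothetical installed ONNX model
      base_model: 'sd-1',
      model_type: 'onnx',
    })
  );
};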
@@ -8,7 +8,7 @@ import {
  setShouldShowAdvancedOptions,
} from 'features/ui/store/uiSlice';
import { clamp } from 'lodash-es';
import { ImageDTO, MainModelField } from 'services/api/types';
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';
import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import {
  CfgScaleParam,
@@ -54,7 +54,7 @@ export interface GenerationState {
  shouldUseSymmetry: boolean;
  horizontalSymmetrySteps: number;
  verticalSymmetrySteps: number;
  model: MainModelField | null;
  model: MainModelField | OnnxModelField | null;
  vae: VaeModelParam | null;
  seamlessXAxis: boolean;
  seamlessYAxis: boolean;
@@ -227,7 +227,10 @@ export const generationSlice = createSlice({
      const { image_name, width, height } = action.payload;
      state.initialImage = { imageName: image_name, width, height };
    },
    modelChanged: (state, action: PayloadAction<MainModelParam | null>) => {
    modelChanged: (
      state,
      action: PayloadAction<MainModelParam | OnnxModelField | null>
    ) => {
      state.model = action.payload;

      if (state.model === null) {
@@ -259,6 +262,7 @@ export const generationSlice = createSlice({
      const result = zMainModel.safeParse({
        model_name,
        base_model,
        model_type,
      });

      if (result.success) {
@@ -1,98 +1,27 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { FACETOOL_TYPES } from 'app/constants';
|
||||
import { ESRGANInvocation } from 'services/api/types';
|
||||
|
||||
export type UpscalingLevel = 2 | 4;
|
||||
|
||||
export type FacetoolType = (typeof FACETOOL_TYPES)[number];
|
||||
export type ESRGANModelName = NonNullable<ESRGANInvocation['model_name']>;
|
||||
|
||||
export interface PostprocessingState {
|
||||
codeformerFidelity: number;
|
||||
facetoolStrength: number;
|
||||
facetoolType: FacetoolType;
|
||||
hiresFix: boolean;
|
||||
hiresStrength: number;
|
||||
shouldLoopback: boolean;
|
||||
shouldRunESRGAN: boolean;
|
||||
shouldRunFacetool: boolean;
|
||||
upscalingLevel: UpscalingLevel;
|
||||
upscalingDenoising: number;
|
||||
upscalingStrength: number;
|
||||
esrganModelName: ESRGANModelName;
|
||||
}
|
||||
|
||||
export const initialPostprocessingState: PostprocessingState = {
|
||||
codeformerFidelity: 0.75,
|
||||
facetoolStrength: 0.75,
|
||||
facetoolType: 'gfpgan',
|
||||
hiresFix: false,
|
||||
hiresStrength: 0.75,
|
||||
shouldLoopback: false,
|
||||
shouldRunESRGAN: false,
|
||||
shouldRunFacetool: false,
|
||||
upscalingLevel: 4,
|
||||
upscalingDenoising: 0.75,
|
||||
upscalingStrength: 0.75,
|
||||
esrganModelName: 'RealESRGAN_x4plus.pth',
|
||||
};
|
||||
|
||||
export const postprocessingSlice = createSlice({
|
||||
name: 'postprocessing',
|
||||
initialState: initialPostprocessingState,
|
||||
reducers: {
|
||||
setFacetoolStrength: (state, action: PayloadAction<number>) => {
|
||||
state.facetoolStrength = action.payload;
|
||||
},
|
||||
setCodeformerFidelity: (state, action: PayloadAction<number>) => {
|
||||
state.codeformerFidelity = action.payload;
|
||||
},
|
||||
setUpscalingLevel: (state, action: PayloadAction<UpscalingLevel>) => {
|
||||
state.upscalingLevel = action.payload;
|
||||
},
|
||||
setUpscalingDenoising: (state, action: PayloadAction<number>) => {
|
||||
state.upscalingDenoising = action.payload;
|
||||
},
|
||||
setUpscalingStrength: (state, action: PayloadAction<number>) => {
|
||||
state.upscalingStrength = action.payload;
|
||||
},
|
||||
setHiresFix: (state, action: PayloadAction<boolean>) => {
|
||||
state.hiresFix = action.payload;
|
||||
},
|
||||
setHiresStrength: (state, action: PayloadAction<number>) => {
|
||||
state.hiresStrength = action.payload;
|
||||
},
|
||||
resetPostprocessingState: (state) => {
|
||||
return {
|
||||
...state,
|
||||
...initialPostprocessingState,
|
||||
};
|
||||
},
|
||||
setShouldRunFacetool: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldRunFacetool = action.payload;
|
||||
},
|
||||
setFacetoolType: (state, action: PayloadAction<FacetoolType>) => {
|
||||
state.facetoolType = action.payload;
|
||||
},
|
||||
setShouldRunESRGAN: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldRunESRGAN = action.payload;
|
||||
},
|
||||
setShouldLoopback: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldLoopback = action.payload;
|
||||
esrganModelNameChanged: (state, action: PayloadAction<ESRGANModelName>) => {
|
||||
state.esrganModelName = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
resetPostprocessingState,
|
||||
setCodeformerFidelity,
|
||||
setFacetoolStrength,
|
||||
setFacetoolType,
|
||||
setHiresFix,
|
||||
setHiresStrength,
|
||||
setShouldLoopback,
|
||||
setShouldRunESRGAN,
|
||||
setShouldRunFacetool,
|
||||
setUpscalingLevel,
|
||||
setUpscalingDenoising,
|
||||
setUpscalingStrength,
|
||||
} = postprocessingSlice.actions;
|
||||
export const { esrganModelNameChanged } = postprocessingSlice.actions;
|
||||
|
||||
export default postprocessingSlice.reducer;
|
||||
|
||||
@@ -126,6 +126,14 @@ export type HeightParam = z.infer<typeof zHeight>;
|
||||
export const isValidHeight = (val: unknown): val is HeightParam =>
|
||||
zHeight.safeParse(val).success;
|
||||
|
||||
const zModelType = z.enum([
|
||||
'vae',
|
||||
'lora',
|
||||
'onnx',
|
||||
'main',
|
||||
'controlnet',
|
||||
'embedding',
|
||||
]);
|
||||
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
|
||||
export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||
@@ -137,6 +145,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||
export const zMainModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
model_type: zModelType,
|
||||
});
|
||||
|
||||
/**
|
||||
|
||||
@@ -14,6 +14,7 @@ export const modelIdToMainModelParam = (
|
||||
const result = zMainModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
model_type,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
|
||||
@@ -5,7 +5,11 @@ import { useRef } from 'react';
|
||||
import { useHoverDirty } from 'react-use';
|
||||
import { useGetAppVersionQuery } from 'services/api/endpoints/appInfo';
|
||||
|
||||
const InvokeAILogoComponent = () => {
|
||||
interface Props {
|
||||
showVersion?: boolean;
|
||||
}
|
||||
|
||||
const InvokeAILogoComponent = ({ showVersion = true }: Props) => {
|
||||
const { data: appVersion } = useGetAppVersionQuery();
|
||||
const ref = useRef(null);
|
||||
const isHovered = useHoverDirty(ref);
|
||||
@@ -28,7 +32,7 @@ const InvokeAILogoComponent = () => {
|
||||
invoke <strong>ai</strong>
|
||||
</Text>
|
||||
<AnimatePresence>
|
||||
{isHovered && appVersion && (
|
||||
{showVersion && isHovered && appVersion && (
|
||||
<motion.div
|
||||
key="statusText"
|
||||
initial={{
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { modelIdToMainModelField } from 'features/nodes/util/modelIdToMainModelField';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { forEach } from 'lodash-es';
|
||||
import {
|
||||
useGetMainModelsQuery,
|
||||
useGetOnnxModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { modelIdToOnnxModelField } from 'features/nodes/util/modelIdToOnnxModelField';
|
||||
|
||||
export const MODEL_TYPE_MAP = {
|
||||
'sd-1': 'Stable Diffusion 1.x',
|
||||
'sd-2': 'Stable Diffusion 2.x',
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => ({ currentModel: state.generation.model }),
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { currentModel } = useAppSelector(selector);
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
const { data: onnxModels, isLoading: onnxLoading } = useGetOnnxModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
forEach(onnxModels?.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [mainModels, onnxModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
mainModels?.entities[
|
||||
`${currentModel?.base_model}/main/${currentModel?.model_name}`
|
||||
] ||
|
||||
onnxModels?.entities[
|
||||
`${currentModel?.base_model}/onnx/${currentModel?.model_name}`
|
||||
],
|
||||
[mainModels?.entities, onnxModels?.entities, currentModel]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
let modelField = modelIdToMainModelField(v);
|
||||
if (v.includes('onnx')) {
|
||||
modelField = modelIdToOnnxModelField(v);
|
||||
}
|
||||
dispatch(modelSelected(modelField));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return isLoading || onnxLoading ? (
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.model')}
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={t('modelManager.model')}
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelSelect);
|
||||
@@ -0,0 +1,56 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
StableDiffusion1ModelCheckpointConfig,
|
||||
StableDiffusion1ModelDiffusersConfig,
|
||||
} from 'services/api';
|
||||
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
|
||||
export type SD1PipelineModel = (
|
||||
| StableDiffusion1ModelCheckpointConfig
|
||||
| StableDiffusion1ModelDiffusersConfig
|
||||
) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const sd1PipelineModelsAdapter = createEntityAdapter<SD1PipelineModel>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const sd1InitialPipelineModelsState =
|
||||
sd1PipelineModelsAdapter.getInitialState();
|
||||
|
||||
export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState;
|
||||
|
||||
export const sd1PipelineModelsSlice = createSlice({
|
||||
name: 'sd1PipelineModels',
|
||||
initialState: sd1InitialPipelineModelsState,
|
||||
reducers: {
|
||||
modelAdded: sd1PipelineModelsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
||||
if (action.meta.arg.baseModel !== 'sd-1') return;
|
||||
sd1PipelineModelsAdapter.setAll(state, action.payload);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectAllSD1PipelineModels,
|
||||
selectById: selectByIdSD1PipelineModels,
|
||||
selectEntities: selectEntitiesSD1PipelineModels,
|
||||
selectIds: selectIdsSD1PipelineModels,
|
||||
selectTotal: selectTotalSD1PipelineModels,
|
||||
} = sd1PipelineModelsAdapter.getSelectors<RootState>(
|
||||
(state) => state.sd1pipelinemodels
|
||||
);
|
||||
|
||||
export const { modelAdded } = sd1PipelineModelsSlice.actions;
|
||||
|
||||
export default sd1PipelineModelsSlice.reducer;
|
||||
@@ -0,0 +1,56 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
StableDiffusion2ModelCheckpointConfig,
|
||||
StableDiffusion2ModelDiffusersConfig,
|
||||
} from 'services/api';
|
||||
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
|
||||
export type SD2PipelineModel = (
|
||||
| StableDiffusion2ModelCheckpointConfig
|
||||
| StableDiffusion2ModelDiffusersConfig
|
||||
) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const sd2PipelineModelsAdapater = createEntityAdapter<SD2PipelineModel>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const sd2InitialPipelineModelsState =
|
||||
sd2PipelineModelsAdapater.getInitialState();
|
||||
|
||||
export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState;
|
||||
|
||||
export const sd2PipelineModelsSlice = createSlice({
|
||||
name: 'sd2PipelineModels',
|
||||
initialState: sd2InitialPipelineModelsState,
|
||||
reducers: {
|
||||
modelAdded: sd2PipelineModelsAdapater.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
||||
if (action.meta.arg.baseModel !== 'sd-2') return;
|
||||
sd2PipelineModelsAdapater.setAll(state, action.payload);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectAllSD2PipelineModels,
|
||||
selectById: selectByIdSD2PipelineModels,
|
||||
selectEntities: selectEntitiesSD2PipelineModels,
|
||||
selectIds: selectIdsSD2PipelineModels,
|
||||
selectTotal: selectTotalSD2PipelineModels,
|
||||
} = sd2PipelineModelsAdapater.getSelectors<RootState>(
|
||||
(state) => state.sd2pipelinemodels
|
||||
);
|
||||
|
||||
export const { modelAdded } = sd2PipelineModelsSlice.actions;
|
||||
|
||||
export default sd2PipelineModelsSlice.reducer;
|
||||
@@ -17,7 +17,7 @@ type ModelListProps = {
|
||||
setSelectedModelId: (name: string | undefined) => void;
|
||||
};
|
||||
|
||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
|
||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive';
|
||||
|
||||
const ModelList = (props: ModelListProps) => {
|
||||
const { selectedModelId, setSelectedModelId } = props;
|
||||
|
||||
@@ -4,7 +4,6 @@ import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Adv
|
||||
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
|
||||
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
||||
@@ -26,7 +25,6 @@ const TextToImageTabParameters = () => {
|
||||
<ParamVariationCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
<ParamSymmetryCollapse />
|
||||
<ParamHiresCollapse />
|
||||
<ParamSeamlessCollapse />
|
||||
<ParamAdvancedCollapse />
|
||||
</>
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
ImportModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
OnnxModelConfig,
|
||||
MergeModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VaeModelConfig,
|
||||
@@ -27,6 +28,8 @@ export type MainModelConfigEntity =
|
||||
| DiffusersModelConfigEntity
|
||||
| CheckpointModelConfigEntity;
|
||||
|
||||
export type OnnxModelConfigEntity = OnnxModelConfig & { id: string };
|
||||
|
||||
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
||||
|
||||
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
||||
@@ -41,6 +44,7 @@ export type VaeModelConfigEntity = VaeModelConfig & { id: string };
|
||||
|
||||
type AnyModelConfigEntity =
|
||||
| MainModelConfigEntity
|
||||
| OnnxModelConfigEntity
|
||||
| LoRAModelConfigEntity
|
||||
| ControlNetModelConfigEntity
|
||||
| TextualInversionModelConfigEntity
|
||||
@@ -104,6 +108,10 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
|
||||
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
@@ -141,6 +149,38 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getOnnxModels: build.query<EntityState<OnnxModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'onnx' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ id: 'OnnxModel', type: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'OnnxModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (
|
||||
response: { models: OnnxModelConfig[] },
|
||||
meta,
|
||||
arg
|
||||
) => {
|
||||
const entities = createModelEntities<OnnxModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return onnxModelsAdapter.setAll(
|
||||
onnxModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}),
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
@@ -413,6 +453,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
|
||||
export const {
|
||||
useGetMainModelsQuery,
|
||||
useGetOnnxModelsQuery,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
|
||||
@@ -8,6 +8,132 @@ import {
|
||||
} from '@reduxjs/toolkit/query/react';
|
||||
import { $authToken, $baseUrl } from 'services/api/client';
|
||||
|
||||
export type { AddInvocation } from './models/AddInvocation';
|
||||
export type { BoardChanges } from './models/BoardChanges';
|
||||
export type { BoardDTO } from './models/BoardDTO';
|
||||
export type { Body_create_board_image } from './models/Body_create_board_image';
|
||||
export type { Body_remove_board_image } from './models/Body_remove_board_image';
|
||||
export type { Body_upload_image } from './models/Body_upload_image';
|
||||
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
|
||||
export type { CkptModelInfo } from './models/CkptModelInfo';
|
||||
export type { ClipField } from './models/ClipField';
|
||||
export type { CollectInvocation } from './models/CollectInvocation';
|
||||
export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
|
||||
export type { ColorField } from './models/ColorField';
|
||||
export type { CompelInvocation } from './models/CompelInvocation';
|
||||
export type { CompelOutput } from './models/CompelOutput';
|
||||
export type { ConditioningField } from './models/ConditioningField';
|
||||
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
||||
export type { ControlField } from './models/ControlField';
|
||||
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
||||
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
|
||||
export type { ControlOutput } from './models/ControlOutput';
|
||||
export type { CreateModelRequest } from './models/CreateModelRequest';
|
||||
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
||||
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
|
||||
export type { DivideInvocation } from './models/DivideInvocation';
|
||||
export type { DynamicPromptInvocation } from './models/DynamicPromptInvocation';
|
||||
export type { Edge } from './models/Edge';
|
||||
export type { EdgeConnection } from './models/EdgeConnection';
|
||||
export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
|
||||
export type { FloatLinearRangeInvocation } from './models/FloatLinearRangeInvocation';
|
||||
export type { FloatOutput } from './models/FloatOutput';
|
||||
export type { Graph } from './models/Graph';
|
||||
export type { GraphExecutionState } from './models/GraphExecutionState';
|
||||
export type { GraphInvocation } from './models/GraphInvocation';
|
||||
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
|
||||
export type { HTTPValidationError } from './models/HTTPValidationError';
|
||||
export type { ImageBlurInvocation } from './models/ImageBlurInvocation';
|
||||
export type { ImageChannelInvocation } from './models/ImageChannelInvocation';
|
||||
export type { ImageConvertInvocation } from './models/ImageConvertInvocation';
|
||||
export type { ImageCropInvocation } from './models/ImageCropInvocation';
|
||||
export type { ImageDTO } from './models/ImageDTO';
|
||||
export type { ImageField } from './models/ImageField';
|
||||
export type { ImageInverseLerpInvocation } from './models/ImageInverseLerpInvocation';
|
||||
export type { ImageLerpInvocation } from './models/ImageLerpInvocation';
|
||||
export type { ImageMetadata } from './models/ImageMetadata';
|
||||
export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation';
|
||||
export type { ImageOutput } from './models/ImageOutput';
|
||||
export type { ImagePasteInvocation } from './models/ImagePasteInvocation';
|
||||
export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation';
|
||||
export type { ImageRecordChanges } from './models/ImageRecordChanges';
|
||||
export type { ImageResizeInvocation } from './models/ImageResizeInvocation';
|
||||
export type { ImageScaleInvocation } from './models/ImageScaleInvocation';
|
||||
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
|
||||
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
|
||||
export type { InfillColorInvocation } from './models/InfillColorInvocation';
|
||||
export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation';
|
||||
export type { InfillTileInvocation } from './models/InfillTileInvocation';
|
||||
export type { InpaintInvocation } from './models/InpaintInvocation';
|
||||
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
||||
export type { IntOutput } from './models/IntOutput';
|
||||
export type { IterateInvocation } from './models/IterateInvocation';
|
||||
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
||||
export type { LatentsField } from './models/LatentsField';
|
||||
export type { LatentsOutput } from './models/LatentsOutput';
|
||||
export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation';
|
||||
export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation';
|
||||
export type { LineartAnimeImageProcessorInvocation } from './models/LineartAnimeImageProcessorInvocation';
|
||||
export type { LineartImageProcessorInvocation } from './models/LineartImageProcessorInvocation';
|
||||
export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
||||
export type { LoraInfo } from './models/LoraInfo';
|
||||
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
||||
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
||||
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
||||
export type { MaskOutput } from './models/MaskOutput';
|
||||
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
||||
export type { MidasDepthImageProcessorInvocation } from './models/MidasDepthImageProcessorInvocation';
|
||||
export type { MlsdImageProcessorInvocation } from './models/MlsdImageProcessorInvocation';
|
||||
export type { ModelInfo } from './models/ModelInfo';
|
||||
export type { ModelLoaderOutput } from './models/ModelLoaderOutput';
|
||||
export type { ModelsList } from './models/ModelsList';
|
||||
export type { ModelType } from './models/ModelType';
|
||||
export type { MultiplyInvocation } from './models/MultiplyInvocation';
|
||||
export type { NoiseInvocation } from './models/NoiseInvocation';
|
||||
export type { NoiseOutput } from './models/NoiseOutput';
|
||||
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
|
||||
export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_';
|
||||
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
|
||||
export type { ONNXLatentsToImageInvocation } from './models/ONNXLatentsToImageInvocation';
|
||||
export type { ONNXModelLoaderOutput } from './models/ONNXModelLoaderOutput';
|
||||
export type { ONNXPromptInvocation } from './models/ONNXPromptInvocation';
|
||||
export type { ONNXSD1ModelLoaderInvocation } from './models/ONNXSD1ModelLoaderInvocation';
|
||||
export type { ONNXStableDiffusion1ModelConfig } from './models/ONNXStableDiffusion1ModelConfig';
|
||||
export type { ONNXStableDiffusion2ModelConfig } from './models/ONNXStableDiffusion2ModelConfig';
|
||||
export type { ONNXTextToLatentsInvocation } from './models/ONNXTextToLatentsInvocation';
|
||||
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
|
||||
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
|
||||
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
|
||||
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
||||
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
|
||||
export type { PipelineModelField } from './models/PipelineModelField';
|
||||
export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
|
||||
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
|
||||
export type { PromptOutput } from './models/PromptOutput';
|
||||
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
||||
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
|
||||
export type { RangeInvocation } from './models/RangeInvocation';
|
||||
export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation';
|
||||
export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
|
||||
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
||||
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
||||
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
||||
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
|
||||
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
|
||||
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
|
||||
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
|
||||
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
||||
export type { SubModelType } from './models/SubModelType';
|
||||
export type { SubtractInvocation } from './models/SubtractInvocation';
|
||||
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
||||
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
|
||||
export type { UNetField } from './models/UNetField';
|
||||
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
||||
export type { VaeField } from './models/VaeField';
|
||||
export type { VaeModelConfig } from './models/VaeModelConfig';
|
||||
export type { VaeRepo } from './models/VaeRepo';
|
||||
export type { ValidationError } from './models/ValidationError';
|
||||
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
||||
export const tagTypes = ['Board', 'Image', 'ImageMetadata', 'Model'];
|
||||
export type ApiFullTagDescription = FullTagDescription<
|
||||
(typeof tagTypes)[number]
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Adds two numbers
|
||||
*/
|
||||
export type AddInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'add';
|
||||
/**
|
||||
* The first number
|
||||
*/
|
||||
'a'?: number;
|
||||
/**
|
||||
* The second number
|
||||
*/
|
||||
'b'?: number;
|
||||
};
|
||||
@@ -0,0 +1,14 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type BoardChanges = {
|
||||
/**
|
||||
* The board's new name.
|
||||
*/
|
||||
board_name?: string;
|
||||
/**
|
||||
* The name of the board's new cover image.
|
||||
*/
|
||||
cover_image_name?: string;
|
||||
};
|
||||
invokeai/frontend/web/src/services/api/models/BoardDTO.ts (new file, 37 lines)
@@ -0,0 +1,37 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Deserialized board record with cover image URL and image count.
|
||||
*/
|
||||
export type BoardDTO = {
|
||||
/**
|
||||
* The unique ID of the board.
|
||||
*/
|
||||
board_id: string;
|
||||
/**
|
||||
* The name of the board.
|
||||
*/
|
||||
board_name: string;
|
||||
/**
|
||||
* The created timestamp of the board.
|
||||
*/
|
||||
created_at: string;
|
||||
/**
|
||||
* The updated timestamp of the board.
|
||||
*/
|
||||
updated_at: string;
|
||||
/**
|
||||
* The deleted timestamp of the board.
|
||||
*/
|
||||
deleted_at?: string;
|
||||
/**
|
||||
* The name of the board's cover image.
|
||||
*/
|
||||
cover_image_name?: string;
|
||||
/**
|
||||
* The number of images in the board.
|
||||
*/
|
||||
image_count: number;
|
||||
};
|
||||
@@ -0,0 +1,14 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type Body_create_board_image = {
|
||||
/**
|
||||
* The id of the board to add to
|
||||
*/
|
||||
board_id: string;
|
||||
/**
|
||||
* The name of the image to add
|
||||
*/
|
||||
image_name: string;
|
||||
};
|
||||
@@ -0,0 +1,14 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type Body_remove_board_image = {
|
||||
/**
|
||||
* The id of the board
|
||||
*/
|
||||
board_id: string;
|
||||
/**
|
||||
* The name of the image to remove
|
||||
*/
|
||||
image_name: string;
|
||||
};
|
||||
@@ -0,0 +1,7 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type Body_upload_image = {
|
||||
file: Blob;
|
||||
};
|
||||
@@ -0,0 +1,32 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
/**
|
||||
* Canny edge detection for ControlNet
|
||||
*/
|
||||
export type CannyImageProcessorInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'canny_image_processor';
|
||||
/**
|
||||
* The image to process
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* The low threshold of the Canny pixel gradient (0-255)
|
||||
*/
|
||||
low_threshold?: number;
|
||||
/**
|
||||
* The high threshold of the Canny pixel gradient (0-255)
|
||||
*/
|
||||
high_threshold?: number;
|
||||
};
|
||||
@@ -0,0 +1,39 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type CkptModelInfo = {
|
||||
/**
|
||||
* A description of the model
|
||||
*/
|
||||
description?: string;
|
||||
/**
|
||||
* The name of the model
|
||||
*/
|
||||
model_name: string;
|
||||
/**
|
||||
* The type of the model
|
||||
*/
|
||||
model_type: string;
|
||||
format?: 'ckpt';
|
||||
/**
|
||||
* The path to the model config
|
||||
*/
|
||||
config: string;
|
||||
/**
|
||||
* The path to the model weights
|
||||
*/
|
||||
weights: string;
|
||||
/**
|
||||
* The path to the model VAE
|
||||
*/
|
||||
vae: string;
|
||||
/**
|
||||
* The width of the model
|
||||
*/
|
||||
width?: number;
|
||||
/**
|
||||
* The height of the model
|
||||
*/
|
||||
height?: number;
|
||||
};
|
||||
invokeai/frontend/web/src/services/api/models/ClipField.ts (new file, 21 lines)
@@ -0,0 +1,21 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { LoraInfo } from './LoraInfo';
|
||||
import type { ModelInfo } from './ModelInfo';
|
||||
|
||||
export type ClipField = {
|
||||
/**
|
||||
* Info to load tokenizer submodel
|
||||
*/
|
||||
tokenizer: ModelInfo;
|
||||
/**
|
||||
* Info to load text_encoder submodel
|
||||
*/
|
||||
text_encoder: ModelInfo;
|
||||
/**
|
||||
* Loras to apply on model loading
|
||||
*/
|
||||
loras: Array<LoraInfo>;
|
||||
};
|
||||
@@ -0,0 +1,26 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Collects values into a collection
|
||||
*/
|
||||
export type CollectInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'collect';
|
||||
/**
|
||||
* The item to collect (all inputs must be of the same type)
|
||||
*/
|
||||
item?: any;
|
||||
/**
|
||||
* The collection, will be provided on execution
|
||||
*/
|
||||
collection?: Array<any>;
|
||||
};
|
||||
@@ -0,0 +1,14 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Base class for all invocation outputs
|
||||
*/
|
||||
export type CollectInvocationOutput = {
|
||||
type: 'collect_output';
|
||||
/**
|
||||
* The collection of input items
|
||||
*/
|
||||
collection: Array<any>;
|
||||
};
|
||||
invokeai/frontend/web/src/services/api/models/ColorField.ts (new file, 22 lines)
@@ -0,0 +1,22 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type ColorField = {
|
||||
/**
|
||||
* The red component
|
||||
*/
|
||||
'r': number;
|
||||
/**
|
||||
* The green component
|
||||
*/
|
||||
'g': number;
|
||||
/**
|
||||
* The blue component
|
||||
*/
|
||||
'b': number;
|
||||
/**
|
||||
* The alpha component
|
||||
*/
|
||||
'a': number;
|
||||
};
|
||||
@@ -0,0 +1,28 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ClipField } from './ClipField';
|
||||
|
||||
/**
|
||||
* Parse prompt using compel package to conditioning.
|
||||
*/
|
||||
export type CompelInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'compel';
|
||||
/**
|
||||
* Prompt
|
||||
*/
|
||||
prompt?: string;
|
||||
/**
|
||||
* Clip to use
|
||||
*/
|
||||
clip?: ClipField;
|
||||
};
|
||||
@@ -0,0 +1,16 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ConditioningField } from './ConditioningField';
|
||||
|
||||
/**
|
||||
* Compel parser output
|
||||
*/
|
||||
export type CompelOutput = {
|
||||
type?: 'compel_output';
|
||||
/**
|
||||
* Conditioning
|
||||
*/
|
||||
conditioning?: ConditioningField;
|
||||
};
|
||||
@@ -0,0 +1,10 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type ConditioningField = {
|
||||
/**
|
||||
* The name of conditioning data
|
||||
*/
|
||||
conditioning_name: string;
|
||||
};
|
||||
@@ -0,0 +1,44 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
/**
|
||||
* Applies content shuffle processing to image
|
||||
*/
|
||||
export type ContentShuffleImageProcessorInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'content_shuffle_image_processor';
|
||||
/**
|
||||
* The image to process
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* The pixel resolution for detection
|
||||
*/
|
||||
detect_resolution?: number;
|
||||
/**
|
||||
* The pixel resolution for the output image
|
||||
*/
|
||||
image_resolution?: number;
|
||||
/**
|
||||
* Content shuffle `h` parameter
|
||||
*/
|
||||
'h'?: number;
|
||||
/**
|
||||
* Content shuffle `w` parameter
|
||||
*/
|
||||
'w'?: number;
|
||||
/**
|
||||
* Content shuffle `f` parameter
|
||||
*/
|
||||
'f'?: number;
|
||||
};
|
||||
@@ -0,0 +1,28 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
export type ControlField = {
|
||||
/**
|
||||
* The control image
|
||||
*/
|
||||
image: ImageField;
|
||||
/**
|
||||
* The ControlNet model to use
|
||||
*/
|
||||
control_model: string;
|
||||
/**
|
||||
* The weight given to the ControlNet
|
||||
*/
|
||||
control_weight: (number | Array<number>);
|
||||
/**
|
||||
* When the ControlNet is first applied (% of total steps)
|
||||
*/
|
||||
begin_step_percent: number;
|
||||
/**
|
||||
* When the ControlNet is last applied (% of total steps)
|
||||
*/
|
||||
end_step_percent: number;
|
||||
};
|
||||
@@ -0,0 +1,40 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
/**
|
||||
* Collects ControlNet info to pass to other nodes
|
||||
*/
|
||||
export type ControlNetInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'controlnet';
|
||||
/**
|
||||
* The control image
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* control model used
|
||||
*/
|
||||
control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace';
|
||||
/**
|
||||
* The weight given to the ControlNet
|
||||
*/
|
||||
control_weight?: (number | Array<number>);
|
||||
/**
|
||||
* When the ControlNet is first applied (% of total steps)
|
||||
*/
|
||||
begin_step_percent?: number;
|
||||
/**
|
||||
* When the ControlNet is last applied (% of total steps)
|
||||
*/
|
||||
end_step_percent?: number;
|
||||
};
|
||||
@@ -0,0 +1,17 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { BaseModelType } from './BaseModelType';
|
||||
import type { ControlNetModelFormat } from './ControlNetModelFormat';
|
||||
import type { ModelError } from './ModelError';
|
||||
|
||||
export type ControlNetModelConfig = {
|
||||
name: string;
|
||||
base_model: BaseModelType;
|
||||
type: 'controlnet';
|
||||
path: string;
|
||||
description?: string;
|
||||
model_format: ControlNetModelFormat;
|
||||
error?: ModelError;
|
||||
};
|
||||
@@ -0,0 +1,16 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ControlField } from './ControlField';
|
||||
|
||||
/**
|
||||
* node output for ControlNet info
|
||||
*/
|
||||
export type ControlOutput = {
|
||||
type?: 'control_output';
|
||||
/**
|
||||
* The control info
|
||||
*/
|
||||
control?: ControlField;
|
||||
};
|
||||
@@ -0,0 +1,17 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { CkptModelInfo } from './CkptModelInfo';
|
||||
import type { DiffusersModelInfo } from './DiffusersModelInfo';
|
||||
|
||||
export type CreateModelRequest = {
|
||||
/**
|
||||
* The name of the model
|
||||
*/
|
||||
name: string;
|
||||
/**
|
||||
* The model info
|
||||
*/
|
||||
info: (CkptModelInfo | DiffusersModelInfo);
|
||||
};
|
||||
@@ -0,0 +1,28 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
/**
|
||||
* Simple inpaint using opencv.
|
||||
*/
|
||||
export type CvInpaintInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'cv_inpaint';
|
||||
/**
|
||||
* The image to inpaint
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* The mask to use when inpainting
|
||||
*/
|
||||
mask?: ImageField;
|
||||
};
|
||||
@@ -0,0 +1,33 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { VaeRepo } from './VaeRepo';
|
||||
|
||||
export type DiffusersModelInfo = {
|
||||
/**
|
||||
* A description of the model
|
||||
*/
|
||||
description?: string;
|
||||
/**
|
||||
* The name of the model
|
||||
*/
|
||||
model_name: string;
|
||||
/**
|
||||
* The type of the model
|
||||
*/
|
||||
model_type: string;
|
||||
format?: 'folder';
|
||||
/**
|
||||
* The VAE repo to use for this model
|
||||
*/
|
||||
vae?: VaeRepo;
|
||||
/**
|
||||
* The repo ID to use for this model
|
||||
*/
|
||||
repo_id?: string;
|
||||
/**
|
||||
* The path to the model
|
||||
*/
|
||||
path?: string;
|
||||
};
|
||||
Some files were not shown because too many files have changed in this diff.