Compare commits

..

73 Commits

Author SHA1 Message Date
Brandon Rising
e8299d0abb Comment out erroniously removed del statement, comment out opt tests 2023-07-18 23:23:34 -04:00
Brandon Rising
a28ab654ef Setup dist folder 2023-07-18 23:18:46 -04:00
Brandon Rising
8699fd7050 Fix invoke UI graphs for onnx 2023-07-18 23:16:51 -04:00
Brandon Rising
9e65470ada Setup dist 2023-07-18 23:07:31 -04:00
Brandon Rising
f4e52fafac Fix as part of merging main in 2023-07-18 23:05:33 -04:00
Brandon Rising
ee7b36cea5 Merge branch 'main' into onnx-testing 2023-07-18 22:56:41 -04:00
Brandon Rising
487455ef2e Add model_type to the model state object 2023-07-18 22:40:27 -04:00
Lincoln Stein
632346b2e2 Release/invokeai 3 0 beta - mkdocs fix (#3821)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [X ] Documentation Update


## Have you discussed this change with the InvokeAI team?
- [ X] Yes
- [ ] No, because:

      
## Description
This PR points mkdocs to the `main` branch again, so that the 3.0.0
documentation appears in gh-pages.

It also makes a minor tweak to the tooltip for model imports, so that
users know that URLs are accepted.

Also rebuilds frontend for use in beta testing.
2023-07-18 21:55:38 -04:00
Brandon Rising
e201ad2f51 Switch to io_binding for run, testing different session options 2023-07-18 21:54:54 -04:00
Lincoln Stein
0f18898865 Merge branch 'release/invokeai-3-0-beta' of github.com:invoke-ai/InvokeAI into release/invokeai-3-0-beta 2023-07-18 21:44:09 -04:00
Lincoln Stein
700131fab2 Pin to transformers 4.30.2
bump version
2023-07-18 21:43:40 -04:00
Lincoln Stein
634d6bb8a6 bump version 2023-07-18 21:42:24 -04:00
Lincoln Stein
2fbf245c3d Merge branch 'main' into release/invokeai-3-0-beta
-- this adds the upscaling support
2023-07-18 21:17:15 -04:00
Lincoln Stein
39c14eb2ac fix pretrained model download to work with xl 2023-07-18 21:10:33 -04:00
Lincoln Stein
c89fa4b635 install, nodes, ui: restore ad-hoc upscaling (#3800)
I've opted to leave out any additional upscaling parameters like scale
and denoising strength, which, from my review of the ESRGAN code, don't
do much:
- scale just resizes the image using CV2 after the AI upscaling, so
that's not particularly useful
- denoising strength is only valid for one class of model, which we are
no longer supporting

If there is demand, we can implement output size/scale UI and handle it
by passing the upscaled image to that a resize/scale node.

I also understand we previously had some functionality to blend the
upscaled image with the original. If that is desired, we would need to
implement that as a node that we can pass the upscaled image to.

Demo:


https://github.com/invoke-ai/InvokeAI/assets/4822129/32eee615-62a1-40ce-a183-87e7d935fbf1

---

[feat(nodes): add RealESRGAN_x2plus.pth, update upscale
nodes](dbc256c5b4)

- add `RealESRGAN_x2plus.pth` model to installer @lstein 
- add `RealESRGAN_x2plus.pth` to `realesrgan` node
- rename `RealESRGAN` to `ESRGAN` in nodes
- make `scale_factor` optional in `img_scale` node

[feat(ui): restore ad-hoc
upscaling](b3fd29e5ad)

- remove face restoration entirely
- add dropdown for ESRGAN model select
- add ad-hoc upscaling graph and workflow
2023-07-18 21:00:17 -04:00
Lincoln Stein
e943913f58 Merge branch 'main' into release/invokeai-3-0-beta 2023-07-18 20:42:10 -04:00
psychedelicious
4ada094c5c tests(nodes): fix tests due to referencing renamed node 2023-07-19 09:45:26 +10:00
Lincoln Stein
893e199677 Merge branch 'main' into feat/ui/upscale 2023-07-18 19:18:55 -04:00
blessedcoolant
3c5a0c95b3 add option to hide version on logo (#3825)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ ] No, because:

      
## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-07-19 11:03:09 +12:00
blessedcoolant
71a07ee5a7 Merge branch 'main' into maryhipp/optional-version 2023-07-19 11:02:24 +12:00
blessedcoolant
2a0a765ec4 Cleanup vram after models offloading (#3826)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Description
There no vram cleanup on models offload which leads to filling vram and
slow generation speed.
2023-07-19 10:17:17 +12:00
Lincoln Stein
ec08151009 add correct requirements for installing SDXL models 2023-07-18 18:15:37 -04:00
blessedcoolant
186e98da5e Merge branch 'main' into fix/mem_cleanup 2023-07-19 10:10:32 +12:00
Eugene Brodsky
dea9a5da7a Avoid crash if unable to modify the model config file (#3824)
* fix whitespace; remove invisible characters
* log error and proceed if unable to modify the model config
2023-07-18 16:33:19 -04:00
Sergey Borisov
bda0000acd Cleanup vram after models offloading, tweak to cleanup local variable references on ram offload 2023-07-18 23:21:18 +03:00
Mary Hipp
4b678f2416 add toggle to not show version on logo 2023-07-18 16:16:35 -04:00
Brandon Rising
869f418b03 Setup onnx on linear text2image 2023-07-18 14:27:54 -04:00
Brandon Rising
35d5ef9118 Emit step completions 2023-07-18 12:35:07 -04:00
psychedelicious
42c440c73f Merge branch 'main' into feat/ui/upscale 2023-07-18 22:08:02 +10:00
psychedelicious
416afd2781 chore(ui): regen types 2023-07-18 15:04:43 +10:00
psychedelicious
afa84a149c feat(ui): restore ad-hoc upscaling
- remove face restoration entirely
- add dropdown for ESRGAN model select
- add ad-hoc upscaling graph and workflow
2023-07-18 14:57:47 +10:00
psychedelicious
be659364c2 chore(ui): regen types 2023-07-18 14:55:39 +10:00
psychedelicious
56098f370c feat(nodes): add RealESRGAN_x2plus.pth, update upscale nodes
- add `RealESRGAN_x2plus.pth` model to installer
- add `RealESRGAN_x2plus.pth` to `realesrgan` node
- rename `RealESRGAN` to `ESRGAN` in nodes
- make `scale_factor` optional in `img_scale` node
2023-07-18 14:55:18 +10:00
Brandon Rising
bcce70fca6 Testing different session opts, added timings for testing 2023-07-17 16:27:33 -04:00
Brandon Rising
932112b640 testing being super wasteful with data 2023-07-16 00:17:33 -04:00
Brandon Rising
91112167b1 Fix syntax err 2023-07-15 23:56:48 -04:00
Brandon Rising
bd7b59910d Testing onnx in new ui updates 2023-07-14 14:24:15 -04:00
Brandon Rising
524888bf3b Merge branch 'main' into feat/onnx 2023-07-13 14:23:57 -04:00
blessedcoolant
0327eae509 chore: Regen API 2023-06-23 05:21:06 +12:00
blessedcoolant
bb85608890 Merge branch 'main' into feat/onnx 2023-06-23 05:18:41 +12:00
Sergey Borisov
6c7668aaca Update onnx model structure, change code according 2023-06-22 20:03:17 +03:00
Sergey Borisov
7759b3f75a Small refactor 2023-06-21 04:24:25 +03:00
Sergey Borisov
4d337f6abc ONNX Model/runtime first implementation 2023-06-21 02:12:21 +03:00
Sergey Borisov
92c86fd0b8 Set model type to const value in openapi schema, add model format enums to model schema(as they not not referenced in case of Literal definition) 2023-06-20 03:44:58 +03:00
Sergey Borisov
46dc751139 Update model format field to use enums 2023-06-20 03:30:09 +03:00
Sergey Borisov
4cefe37723 Rename format to model_format(still named format when work with config) 2023-06-20 03:25:08 +03:00
Sergey Borisov
82b73c50a0 Remove default model logic 2023-06-20 03:13:10 +03:00
blessedcoolant
7df7a95299 Merge branch 'main' into model-manager-ui-30 2023-06-19 23:26:11 +12:00
blessedcoolant
85b4b359c2 tweal: UI colors 2023-06-19 23:16:14 +12:00
blessedcoolant
cfe81b5e00 fix: Adjust the Schedular select width
So the long names do not get cut off.
2023-06-19 23:05:32 +12:00
blessedcoolant
b0c4451324 Merge branch 'main' into model-manager-ui-30 2023-06-19 23:02:59 +12:00
blessedcoolant
d4931522d4 Merge branch 'main' into model-manager-ui-30 2023-06-19 22:53:13 +12:00
blessedcoolant
17e2a35228 fix: merge conflicts 2023-06-18 22:25:48 +12:00
blessedcoolant
91016d8b29 Merge branch 'main' into model-manager-ui-30 2023-06-18 22:23:18 +12:00
blessedcoolant
9fda21cf40 Revert "feat: Port Schedulers to Mantine"
This reverts commit e0c105f413.
2023-06-18 22:22:56 +12:00
blessedcoolant
809ec7163e fix: Remove type from Model type name 2023-06-18 19:41:30 +12:00
blessedcoolant
7c9a939b47 fix: Unserialization key issue 2023-06-18 19:38:15 +12:00
blessedcoolant
9634c96020 revert: getModels to receivedModels 2023-06-18 19:35:46 +12:00
blessedcoolant
e0c105f413 feat: Port Schedulers to Mantine 2023-06-18 19:31:53 +12:00
blessedcoolant
f0bf32c476 Merge branch 'main' into model-manager-ui-30 2023-06-18 17:37:34 +12:00
blessedcoolant
28373dbb98 cleanup: Updated model slice names to be more descriptive
Basically updated all slices to be more descriptive in their names. Did so in order to make sure theres good naming scheme available for secondary models.
2023-06-18 17:36:23 +12:00
blessedcoolant
4133d77772 wip: Move Model Selector to own file 2023-06-18 09:19:13 +12:00
blessedcoolant
61c426f502 feat: Enable 2.x Model Generation in Linear UI 2023-06-18 08:27:13 +12:00
blessedcoolant
bf0577c882 fix: 2.1 models breaking generation
Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com>
2023-06-18 08:26:25 +12:00
blessedcoolant
24673fd859 chore: Rebuild API - base_model and type added 2023-06-18 07:50:28 +12:00
Sergey Borisov
dc669d1447 Add name, base_mode, type fields to model info 2023-06-17 22:48:44 +03:00
blessedcoolant
ce4110b9f4 wip: Add 2.x Models to the Model List 2023-06-18 07:01:44 +12:00
blessedcoolant
0f3b7d2b3d chore: Rebuild API with new Model API names 2023-06-18 03:00:16 +12:00
Sergey Borisov
16dc78f6c6 Generate config names for openapi 2023-06-17 17:15:36 +03:00
blessedcoolant
7a66856785 wip: Update Linear UI Txt2Img and Img2Img Graphs
Update the text to imaeg and image to image graphs to work with the new model loader. Currently only supports 1.x models. Will update this soon to make it work with all models.
2023-06-18 01:38:01 +12:00
blessedcoolant
c8dfa49d86 fix: Update missing name types to new names 2023-06-17 22:04:28 +12:00
blessedcoolant
76dd749b1e chore: Rebuild API 2023-06-17 21:29:32 +12:00
blessedcoolant
67d05d2066 chore: Update model config type names 2023-06-17 21:28:43 +12:00
215 changed files with 7229 additions and 1273 deletions

View File

@@ -1,6 +1,14 @@
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import (Blend, Conjunction,

View File

@@ -518,8 +518,8 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_scale"] = "img_scale"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on

View File

@@ -22,7 +22,8 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import ModelPatcher
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)

View File

@@ -54,6 +54,7 @@ class MainModelField(BaseModel):
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
class LoRAModelField(BaseModel):
@@ -221,6 +222,9 @@ class LoraLoaderInvocation(BaseInvocation):
base_model = self.lora.base_model
lora_name = self.lora.model_name
# TODO: ui rewrite
base_model = BaseModelType.StableDiffusion1
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,

View File

@@ -0,0 +1,591 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import re
import inspect
from pydantic import BaseModel, Field, validator
import torch
import numpy as np
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ONNXModelPatcher
from ...backend.util import choose_torch_device
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.backend import BaseModelType, ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.stable_diffusion import PipelineIntermediateState
from tqdm import tqdm
from .model import ClipField
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
from .compel import CompelOutput
ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
"tensor(int16)": np.int16,
"tensor(uint16)": np.uint16,
"tensor(int32)": np.int32,
"tensor(uint32)": np.uint32,
"tensor(int64)": np.int64,
"tensor(uint64)": np.uint64,
"tensor(float16)": np.float16,
"tensor(float)": np.float32,
"tensor(double)": np.float64,
}
class ONNXPromptInvocation(BaseInvocation):
type: Literal["prompt_onnx"] = "prompt_onnx"
prompt: str = Field(default="", description="Prompt")
clip: ClipField = Field(None, description="Clip to use")
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder,\
ExitStack() as stack:
#loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
#stack.enter_context(
# context.services.model_manager.get_model(
# model_name=name,
# base_model=self.clip.text_encoder.base_model,
# model_type=ModelType.TextualInversion,
# )
#)
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except Exception:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
text_encoder.create_session()
# copy from
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
text_inputs = tokenizer(
self.prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
"""
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
"""
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_encoder.release_session()
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (prompt_embeds, None))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
# Text to image
class ONNXTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_onnx"] = "t2l_onnx"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
def invoke(self, context: InvocationContext) -> LatentsOutput:
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
if isinstance(c, torch.Tensor):
c = c.cpu().numpy()
if isinstance(uc, torch.Tensor):
uc = uc.cpu().numpy()
device = torch.device(choose_torch_device())
prompt_embeds = np.concatenate([uc, c])
latents = context.services.latents.get(self.noise.latents_name)
if isinstance(latents, torch.Tensor):
latents = latents.cpu().numpy()
# TODO: better execution device handling
latents = latents.astype(np.float16)
# get the initial random noise unless the user supplied it
do_classifier_free_guidance = True
#latents_dtype = prompt_embeds.dtype
#latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
#if latents.shape != latents_shape:
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
def torch2numpy(latent: torch.Tensor):
return latent.cpu().numpy()
def numpy2torch(latent, device):
return torch.from_numpy(latent).to(device)
def dispatch_progress(
self, context: InvocationContext, source_node_id: str,
intermediate_state: PipelineIntermediateState) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
scheduler.set_timesteps(self.steps)
latents = latents * np.float64(scheduler.init_noise_sigma)
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet,\
ExitStack() as stack:
#loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
with ONNXModelPatcher.apply_lora_unet(unet, loras):
# TODO:
unet.create_session()
timestep_dtype = next(
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
import time
times = []
for i in tqdm(range(len(scheduler.timesteps))):
t = scheduler.timesteps[i]
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
start_time = time.time()
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
times.append(time.time() - start_time)
noise_pred = noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = scheduler.step(
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
)
latents = torch2numpy(scheduler_output.prev_sample)
state = PipelineIntermediateState(
run_id= "test",
step=i,
timestep=timestep,
latents=scheduler_output.prev_sample
)
dispatch_progress(
self,
context=context,
source_node_id=source_node_id,
intermediate_state=state
)
# call the callback, if provided
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
print(times)
unet.release_session()
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
# Latent to image
class ONNXLatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i_onnx"] = "l2i_onnx"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
#tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
if self.vae.vae.submodel != SubModelType.VaeDecoder:
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
)
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
with vae_info as vae:
vae.create_session()
# copied from
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image = VaeImageProcessor.numpy_to_pil(image)[0]
vae.release_session()
torch.cuda.empty_cache()
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class ONNXModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
#fmt: on
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
model_name: str = Field(default="", description="Model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
}
},
}
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
model_name = "stable-diffusion-v1-5"
base_model = BaseModelType.StableDiffusion1
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.ONNX,
):
raise Exception(f"Unkown model name: {model_name}!")
return ONNXModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae_decoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.VaeDecoder,
),
),
vae_encoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.VaeEncoder,
),
)
)
class OnnxModelField(BaseModel):
"""Onnx model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["onnx_model_loader"] = "onnx_model_loader"
model: OnnxModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Onnx Model Loader",
"tags": ["model", "loader"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.ONNX
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ONNXModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae_decoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.VaeDecoder,
),
),
vae_encoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.VaeEncoder,
),
)
)

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path, PosixPath
from typing import Literal, Union, cast
from pathlib import Path
from typing import Literal, Union
import cv2 as cv
import numpy as np
@@ -16,19 +16,20 @@ from .image import ImageOutput
# TODO: Populate this from disk?
# TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[
ESRGAN_MODELS = Literal[
"RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"RealESRGAN_x2plus.pth",
]
class RealESRGANInvocation(BaseInvocation):
class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""
type: Literal["realesrgan"] = "realesrgan"
type: Literal["esrgan"] = "esrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: REALESRGAN_MODELS = Field(
model_name: ESRGAN_MODELS = Field(
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
)
@@ -73,19 +74,17 @@ class RealESRGANInvocation(BaseInvocation):
scale=4,
)
netscale = 4
# TODO: add x2 models handling?
# elif self.model_name in ["RealESRGAN_x2plus"]:
# # x2 RRDBNet model
# model = RRDBNet(
# num_in_ch=3,
# num_out_ch=3,
# num_feat=64,
# num_block=23,
# num_grow_ch=32,
# scale=2,
# )
# model_path = Path()
# netscale = 2
elif self.model_name in ["RealESRGAN_x2plus.pth"]:
# x2 RRDBNet model
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
netscale = 2
else:
msg = f"Invalid RealESRGAN model: {self.model_name}"
context.services.logger.error(msg)

View File

@@ -466,7 +466,6 @@ class Generator:
dtype=samples.dtype,
device=samples.device,
)
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
latents_ubyte = (
((latent_image + 1) / 2)

View File

@@ -222,7 +222,7 @@ def download_conversion_models():
# ---------------------------------------------
def download_realesrgan():
logger.info("Installing RealESRGAN models...")
logger.info("Installing ESRGAN Upscaling models...")
URLs = [
dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
@@ -239,6 +239,11 @@ def download_realesrgan():
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description = "ESRGAN_SRx4_DF2KOST_official.pth",
),
dict(
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description = "RealESRGAN_x2plus.pth",
),
]
for model in URLs:
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])

View File

@@ -10,7 +10,7 @@ from tempfile import TemporaryDirectory
from typing import List, Dict, Callable, Union, Set
import requests
from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf
@@ -310,6 +310,8 @@ class ModelInstall(object):
if key := self.reverse_paths.get(path_name):
(name, base, mtype) = ModelManager.parse_key(key)
return name
elif location.is_dir():
return location.name
else:
return location.stem
@@ -365,7 +367,7 @@ class ModelInstall(object):
model = None
for revision in revisions:
try:
model = StableDiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
model = DiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
except: # most errors are due to fp16 not being present. Fix this to catch other errors
pass
if model:

View File

@@ -3,6 +3,7 @@ Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
from .lora import ModelPatcher, ONNXModelPatcher
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@@ -6,11 +6,22 @@ from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path
import torch
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from onnx import numpy_helper
from onnxruntime import OrtValue
import numpy as np
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
# TODO: rename and split this file
class LoRALayerBase:
#rank: Optional[int]
#alpha: Optional[float]
@@ -708,3 +719,185 @@ class TextualInversionManager(BaseTextualInversionManager):
return new_token_ids
class ONNXModelPatcher:
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod
@contextmanager
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoraModel, float]],
prefix: str,
):
from .models.base import IAIOnnxRuntimeModel
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = dict()
try:
blended_loras = dict()
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
layer_key = layer_key.replace(prefix, "")
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
if layer_key is blended_loras:
blended_loras[layer_key] += layer_weight
else:
blended_loras[layer_key] = layer_weight
node_names = dict()
for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
for layer_key, lora_weight in blended_loras.items():
conv_key = layer_key + "_Conv"
gemm_key = layer_key + "_Gemm"
matmul_key = layer_key + "_MatMul"
if conv_key in node_names or gemm_key in node_names:
if conv_key in node_names:
conv_node = model.nodes[node_names[conv_key]]
else:
conv_node = model.nodes[node_names[gemm_key]]
weight_name = [n for n in conv_node.input if ".weight" in n][0]
orig_weight = model.tensors[weight_name]
if orig_weight.shape[-2:] == (1, 1):
if lora_weight.shape[-2:] == (1, 1):
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
else:
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
new_weight = np.expand_dims(new_weight, (2, 3))
else:
if orig_weight.shape != lora_weight.shape:
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
else:
new_weight = orig_weight + lora_weight
orig_weights[weight_name] = orig_weight
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
elif matmul_key in node_names:
weight_node = model.nodes[node_names[matmul_key]]
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
orig_weight = model.tensors[matmul_name]
new_weight = orig_weight + lora_weight.transpose()
orig_weights[matmul_name] = orig_weight
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
else:
# warn? err?
pass
yield
finally:
# restore original weights
for name, orig_weight in orig_weights.items():
model.tensors[name] = orig_weight
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_embeddings = None
try:
ti_tokenizer = copy.deepcopy(tokenizer)
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti, index):
trigger = ti.name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify tokenizer
new_tokens_added = 0
for ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
embeddings = np.concatenate(
(
np.copy(orig_embeddings),
np.zeros((new_tokens_added, orig_embeddings.shape[1]))
),
axis=0,
)
for ti in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i].detach().numpy()
trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if embeddings[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
)
embeddings[token_id] = embedding
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(orig_embeddings.dtype)
yield ti_tokenizer, ti_manager
finally:
# restore
if orig_embeddings is not None:
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings

View File

@@ -328,6 +328,25 @@ class ModelCache(object):
refs = sys.getrefcount(cache_entry.model)
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
#break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
@@ -363,6 +382,9 @@ class ModelCache(object):
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
vram_in_use += mem.vram_used # note vram_used is negative
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
gc.collect()
torch.cuda.empty_cache()
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()

View File

@@ -106,16 +106,16 @@ providing information about a model defined in models.yaml. For example:
>>> models = mgr.list_models()
>>> json.dumps(models[0])
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
"model_format": "diffusers",
"name": "canny",
"base_model": "sd-1",
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
"model_format": "diffusers",
"name": "canny",
"base_model": "sd-1",
"type": "controlnet"
}
You can filter by model type and base model as shown here:
controlnets = mgr.list_models(model_type=ModelType.ControlNet,
base_model=BaseModelType.StableDiffusion1)
for c in controlnets:
@@ -140,14 +140,14 @@ Layout of the `models` directory:
models
├── sd-1
   ├── controlnet
   ├── lora
   ├── main
   └── embedding
├── controlnet
├── lora
├── main
└── embedding
├── sd-2
   ├── controlnet
   ├── lora
   ├── main
├── controlnet
├── lora
├── main
│ └── embedding
└── core
├── face_reconstruction
@@ -195,7 +195,7 @@ name, base model, type and a dict of model attributes. See
`invokeai/backend/model_management/models` for the attributes required
by each model type.
A model can be deleted using `del_model()`, providing the same
A model can be deleted using `del_model()`, providing the same
identifying information as `get_model()`
The `heuristic_import()` method will take a set of strings
@@ -304,7 +304,7 @@ class ModelManager(object):
logger: types.ModuleType = logger,
):
"""
Initialize with the path to the models.yaml config file.
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
@@ -323,7 +323,7 @@ class ModelManager(object):
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
# TODO: metadata not found
# TODO: version check
self.app_config = InvokeAIAppConfig.get_config()
self.logger = logger
self.cache = ModelCache(
@@ -431,7 +431,7 @@ class ModelManager(object):
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param submode_typel: an ModelType enum indicating the portion of
:param submode_typel: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_class = MODEL_CLASSES[base_model][model_type]
@@ -456,7 +456,7 @@ class ModelManager(object):
raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override
# TODO:
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
@@ -489,7 +489,7 @@ class ModelManager(object):
self.cache_keys[model_key].add(model_context.key)
model_hash = "<NO_HASH>" # TODO:
return ModelInfo(
context = model_context,
name = model_name,
@@ -518,7 +518,7 @@ class ModelManager(object):
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Return a list of (str, BaseModelType, ModelType) corresponding to all models
Return a list of (str, BaseModelType, ModelType) corresponding to all models
known to the configuration.
"""
return [(self.parse_key(x)) for x in self.models.keys()]
@@ -692,12 +692,12 @@ class ModelManager(object):
if new_name is None and new_base is None:
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
return
model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None)
if not model_cfg:
raise ModelNotFoundException(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name
new_base = new_base or base_model
@@ -726,7 +726,7 @@ class ModelManager(object):
self.models.pop(model_key, None) # delete
self.models[new_key] = model_cfg
self.commit()
def convert_model (
self,
model_name: str,
@@ -776,12 +776,12 @@ class ModelManager(object):
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
raise
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
checkpoint_path.unlink()
return result
def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@@ -824,10 +824,14 @@ class ModelManager(object):
assert config_file_path is not None,'no config file path to write to'
config_file_path = self.app_config.root_path / config_file_path
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(self.preamble())
outfile.write(yaml_str)
os.replace(tmpfile, config_file_path)
try:
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(self.preamble())
outfile.write(yaml_str)
os.replace(tmpfile, config_file_path)
except OSError as err:
self.logger.warning(f"Could not modify the config file at {config_file_path}")
self.logger.warning(err)
def preamble(self) -> str:
"""
@@ -977,13 +981,12 @@ class ModelManager(object):
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = dict()
installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper,
model_manager = self)
for thing in items_to_import:
installed = installer.heuristic_import(thing)
successfully_installed.update(installed)
self.commit()
self.commit()
return successfully_installed

View File

@@ -23,7 +23,7 @@ class ModelProbeInfo(object):
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal['diffusers','checkpoint', 'lycoris']
format: Literal['diffusers','checkpoint', 'lycoris', 'olive']
image_size: int
class ProbeBase(object):

View File

@@ -10,8 +10,11 @@ from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
MODEL_CLASSES = {
BaseModelType.StableDiffusion1: {
ModelType.ONNX: ONNXStableDiffusion1Model,
ModelType.Main: StableDiffusion1Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
@@ -19,6 +22,7 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
@@ -32,6 +36,7 @@ MODEL_CLASSES = {
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
@@ -40,6 +45,7 @@ MODEL_CLASSES = {
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
},
#BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,

View File

@@ -8,13 +8,19 @@ from abc import ABCMeta, abstractmethod
from pathlib import Path
from picklescan.scanner import scan_file_path
import torch
import numpy as np
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from pathlib import Path
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
import onnx
from onnx import numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel
class InvalidModelException(Exception):
pass
@@ -29,6 +35,7 @@ class BaseModelType(str, Enum):
#Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
@@ -42,6 +49,8 @@ class SubModelType(str, Enum):
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
@@ -254,16 +263,18 @@ class DiffusersModel(ModelBase):
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
os.path.join(self.model_path, child_type.value),
#subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
#print("====ERR LOAD====")
#print(f"{variant}: {e}")
print("====ERR LOAD====")
print(f"{variant}: {e}")
import traceback
traceback.print_exc()
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -430,3 +441,188 @@ class SilenceWarnings(object):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')
ONNX_WEIGHTS_NAME = "model.onnx"
class IAIOnnxRuntimeModel:
class _tensor_access:
def __init__(self, model):
self.model = model
self.indexes = dict()
for idx, obj in enumerate(self.model.proto.graph.initializer):
self.indexes[obj.name] = idx
def __getitem__(self, key: str):
return self.model.data[key].numpy()
def __setitem__(self, key: str, value: np.ndarray):
new_node = numpy_helper.from_array(value)
# set_external_data(new_node, location="in-memory-location")
new_node.name = key
# new_node.ClearField("raw_data")
del self.model.proto.graph.initializer[self.indexes[key]]
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
# __delitem__
def __contains__(self, key: str):
return key in self.model.data
def items(self):
raise NotImplementedError("tensor.items")
#return [(obj.name, obj) for obj in self.raw_proto]
def keys(self):
return self.model.data.keys()
def values(self):
raise NotImplementedError("tensor.values")
#return [obj for obj in self.raw_proto]
class _access_helper:
def __init__(self, raw_proto):
self.indexes = dict()
self.raw_proto = raw_proto
for idx, obj in enumerate(raw_proto):
self.indexes[obj.name] = idx
def __getitem__(self, key: str):
return self.raw_proto[self.indexes[key]]
def __setitem__(self, key: str, value):
index = self.indexes[key]
del self.raw_proto[index]
self.raw_proto.insert(index, value)
# __delitem__
def __contains__(self, key: str):
return key in self.indexes
def items(self):
return [(obj.name, obj) for obj in self.raw_proto]
def keys(self):
return self.indexes.keys()
def values(self):
return [obj for obj in self.raw_proto]
def __init__(self, model_path: str, provider: Optional[str]):
self.path = model_path
self.session = None
self.provider = provider or "CPUExecutionProvider"
"""
self.data_path = self.path + "_data"
if not os.path.exists(self.data_path):
print(f"Moving model tensors to separate file: {self.data_path}")
tmp_proto = onnx.load(model_path, load_external_data=True)
onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
del tmp_proto
gc.collect()
self.proto = onnx.load(model_path, load_external_data=False)
"""
self.proto = onnx.load(model_path, load_external_data=True)
self.data = dict()
for tensor in self.proto.graph.initializer:
name = tensor.name
if tensor.HasField("raw_data"):
npt = numpy_helper.to_array(tensor)
orv = OrtValue.ortvalue_from_numpy(npt)
self.data[name] = orv
# set_external_data(tensor, location="in-memory-location")
tensor.name = name
# tensor.ClearField("raw_data")
self.nodes = self._access_helper(self.proto.graph.node)
self.initializers = self._access_helper(self.proto.graph.initializer)
# print(self.proto.graph.input)
# print(self.proto.graph.initializer)
self.tensors = self._tensor_access(self)
# TODO: integrate with model manager/cache
def create_session(self):
if self.session is None:
#onnx.save(self.proto, "tmp.onnx")
#onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
# TODO: something to be able to get weight when they already moved outside of model proto
#(trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
sess = SessionOptions()
#self._external_data.update(**external_data)
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
# sess.enable_profiling = True
# sess.intra_op_num_threads = 1
# sess.inter_op_num_threads = 1
# sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
# sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# sess.enable_cpu_mem_arena = True
# sess.enable_mem_pattern = True
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
sess.add_free_dimension_override_by_name("unet_sample_height", 64)
sess.add_free_dimension_override_by_name("unet_sample_width", 64)
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
#self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
self.io_binding = self.session.io_binding()
def release_session(self):
self.session = None
import gc
gc.collect()
def __call__(self, **kwargs):
if self.session is None:
raise Exception("You should call create_session before running model")
inputs = {k: np.array(v) for k, v in kwargs.items()}
output_names = self.session.get_outputs()
for k in inputs:
self.io_binding.bind_cpu_input(k, inputs[k])
for name in output_names:
self.io_binding.bind_output(name.name)
self.session.run_with_iobinding(self.io_binding, None)
return self.io_binding.copy_outputs_to_cpu()
# compatability with diffusers load code
@classmethod
def from_pretrained(
cls,
model_id: Union[str, Path],
subfolder: Union[str, Path] = None,
file_name: Optional[str] = None,
provider: Optional[str] = None,
sess_options: Optional["SessionOptions"] = None,
**kwargs,
):
file_name = file_name or ONNX_WEIGHTS_NAME
if os.path.isdir(model_id):
model_path = model_id
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
model_path = os.path.join(model_path, file_name)
else:
model_path = model_id
# load model from local directory
if not os.path.isfile(model_path):
raise Exception(f"Model not found: {model_path}")
# TODO: session options
return cls(model_path, provider=provider)

View File

@@ -0,0 +1,156 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
OnnxRuntimeModel,
IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
class ONNXStableDiffusion1Model(DiffusersModel):
class Config(ModelConfigBase):
model_format: None
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.ONNX
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.ONNX,
)
for child_name, child_type in self.child_types.items():
if child_type is OnnxRuntimeModel:
self.child_types[child_name] = IAIOnnxRuntimeModel
# TODO: check that no optimum models provided
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
in_channels = 4 # TODO:
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 1.* model format")
return cls.create_config(
path=path,
model_format=model_format,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
return None
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path
class ONNXStableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class Config(ModelConfigBase):
model_format: None
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.ONNX
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.ONNX,
)
for child_name, child_type in self.child_types.items():
if child_type is OnnxRuntimeModel:
self.child_types[child_name] = IAIOnnxRuntimeModel
# TODO: check that no optimum models provided
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
in_channels = 4 # TODO:
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if variant == ModelVariantType.Normal:
prediction_type = SchedulerPredictionType.VPrediction
upcast_attention = True
else:
prediction_type = SchedulerPredictionType.Epsilon
upcast_attention = False
return cls.create_config(
path=path,
model_format=model_format,
variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
return None
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path

View File

@@ -1,10 +1,4 @@
# This file predefines a few models that the user may want to install.
sd-1/main/stable-diffusion-xdl-base:
description: Stable Diffusion XL base model - NOT YET RELEASED!! (70 GB)
repo_id: stabilityai/stable-diffusion-xl-base
sd-1/main/stable-diffusion-xdl-refiner:
description: Stable Diffusion XL refiner model - NOT YET RELEASED!! (60 GB)
repo_id: stabilityai/stable-diffusion-xl-refiner
sd-1/main/stable-diffusion-v1-5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
repo_id: runwayml/stable-diffusion-v1-5
@@ -22,6 +16,14 @@ sd-2/main/stable-diffusion-2-inpainting:
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-inpainting
recommended: False
sdxl/main/stable-diffusion-xl-base-0-9:
description: Stable Diffusion XL base model (12 GB; access token required)
repo_id: stabilityai/stable-diffusion-xl-base-0.9
recommended: False
sdxl-refiner/main/stable-diffusion-xl-refiner-0-9:
description: Stable Diffusion XL refiner model (12 GB; access token required)
repo_id: stabilityai/stable-diffusion-xl-refiner-0.9
recommended: False
sd-1/main/Analog-Diffusion:
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
repo_id: wavymulder/Analog-Diffusion

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-c1428e4e.js"></script>
<script type="module" crossorigin src="./assets/index-ba194473.js"></script>
</head>
<body dir="ltr">

View File

@@ -59,15 +59,8 @@ export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
export type Scheduler = (typeof SCHEDULER_NAMES)[number];
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
{ label: '2x', value: '2' },
{ label: '4x', value: '4' },
];
export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 2147483647;
export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const;
export const NODE_MIN_WIDTH = 250;

View File

@@ -90,6 +90,7 @@ import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
export const listenerMiddleware = createListenerMiddleware();
@@ -228,3 +229,5 @@ addModelSelectedListener();
addAppStartedListener();
addModelsLoadedListener();
addAppConfigReceivedListener();
addUpscaleRequestedListener();

View File

@@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionCreated } from 'services/api/thunks/session';
import { serializeError } from 'serialize-error';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'session' });

View File

@@ -0,0 +1,37 @@
import { createAction } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger';
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graphBuilders/buildAdHocUpscaleGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'upscale' });
export const upscaleRequested = createAction<{ image_name: string }>(
`upscale/upscaleRequested`
);
export const addUpscaleRequestedListener = () => {
startAppListening({
actionCreator: upscaleRequested,
effect: async (
action,
{ dispatch, getState, take, unsubscribe, subscribe }
) => {
const { image_name } = action.payload;
const { esrganModelName } = getState().postprocessing;
const graph = buildAdHocUpscaleGraph({
image_name,
esrganModelName,
});
// Create a session to run the graph & wait til it's ready to invoke
dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
},
});
};

View File

@@ -11,7 +11,7 @@ interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
<Tooltip label={tooltip} placement="top" hasArrow>
<Tooltip label={tooltip} placement="top" hasArrow openDelay={500}>
<Box ref={ref} {...others}>
<Box>
<Text>{label}</Text>

View File

@@ -5,34 +5,26 @@ import {
ButtonGroup,
Flex,
FlexProps,
Link,
Menu,
MenuButton,
MenuItem,
MenuList,
} from '@chakra-ui/react';
// import { runESRGAN, runFacetool } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { stateSelector } from 'app/store/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
@@ -42,38 +34,25 @@ import { useTranslation } from 'react-i18next';
import {
FaAsterisk,
FaCode,
FaCopy,
FaDownload,
FaExpandArrowsAlt,
FaGrinStars,
FaHourglassHalf,
FaQuoteRight,
FaSeedling,
FaShare,
FaShareAlt,
} from 'react-icons/fa';
import {
useGetImageDTOQuery,
useGetImageMetadataQuery,
} from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
import { sentImageToImg2Img } from '../../store/actions';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
const currentImageButtonsSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ gallery, system, postprocessing, ui }, activeTabName) => {
const {
isProcessing,
isConnected,
isGFPGANAvailable,
isESRGANAvailable,
shouldConfirmOnDelete,
progressImage,
} = system;
const { upscalingLevel, facetoolStrength } = postprocessing;
({ gallery, system, ui }, activeTabName) => {
const { isProcessing, isConnected, shouldConfirmOnDelete, progressImage } =
system;
const {
shouldShowImageDetails,
@@ -88,10 +67,6 @@ const currentImageButtonsSelector = createSelector(
shouldConfirmOnDelete,
isProcessing,
isConnected,
isGFPGANAvailable,
isESRGANAvailable,
upscalingLevel,
facetoolStrength,
shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage,
shouldShowImageDetails,
activeTabName,
@@ -114,27 +89,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const {
isProcessing,
isConnected,
isGFPGANAvailable,
isESRGANAvailable,
upscalingLevel,
facetoolStrength,
shouldDisableToolbarButtons,
shouldShowImageDetails,
activeTabName,
lastSelectedImage,
shouldShowProgressInViewer,
} = useAppSelector(currentImageButtonsSelector);
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;
const toaster = useAppToaster();
const { t } = useTranslation();
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
@@ -155,42 +120,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const metadata = metadataData?.metadata;
const handleCopyImageLink = useCallback(() => {
const getImageUrl = () => {
if (!imageDTO) {
return;
}
if (imageDTO.image_url.startsWith('http')) {
return imageDTO.image_url;
}
return window.location.toString() + imageDTO.image_url;
};
const url = getImageUrl();
if (!url) {
toaster({
title: t('toast.problemCopyingImageLink'),
status: 'error',
duration: 2500,
isClosable: true,
});
return;
}
navigator.clipboard.writeText(url).then(() => {
toaster({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
});
}, [toaster, t, imageDTO]);
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata);
}, [metadata, recallAllParameters]);
@@ -223,8 +152,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);
const handleClickUpscale = useCallback(() => {
// selectedImage && dispatch(runESRGAN(selectedImage));
}, []);
if (!imageDTO) {
return;
}
dispatch(upscaleRequested({ image_name: imageDTO.image_name }));
}, [dispatch, imageDTO]);
const handleDelete = useCallback(() => {
if (!imageDTO) {
@@ -242,53 +174,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
enabled: () =>
Boolean(
isUpscalingEnabled &&
isESRGANAvailable &&
!shouldDisableToolbarButtons &&
isConnected &&
!isProcessing &&
upscalingLevel
!isProcessing
),
},
[
isUpscalingEnabled,
imageDTO,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
isProcessing,
upscalingLevel,
]
);
const handleClickFixFaces = useCallback(() => {
// selectedImage && dispatch(runFacetool(selectedImage));
}, []);
useHotkeys(
'Shift+R',
() => {
handleClickFixFaces();
},
{
enabled: () =>
Boolean(
isFaceRestoreEnabled &&
isGFPGANAvailable &&
!shouldDisableToolbarButtons &&
isConnected &&
!isProcessing &&
facetoolStrength
),
},
[
isFaceRestoreEnabled,
imageDTO,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
isProcessing,
facetoolStrength,
]
);
@@ -297,25 +193,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
[dispatch, shouldShowImageDetails]
);
const handleSendToCanvas = useCallback(() => {
if (!imageDTO) return;
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(imageDTO));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
dispatch(setActiveTab('unifiedCanvas'));
}
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
}, [imageDTO, dispatch, activeTabName, toaster, t]);
useHotkeys(
'i',
() => {
@@ -337,13 +214,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]);
const handleCopyImage = useCallback(() => {
if (!imageDTO) {
return;
}
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO]);
return (
<>
<Flex
@@ -396,72 +266,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
</ButtonGroup>
{(isUpscalingEnabled || isFaceRestoreEnabled) && (
{isUpscalingEnabled && (
<ButtonGroup
isAttached={true}
isDisabled={shouldDisableToolbarButtons}
>
{isFaceRestoreEnabled && (
<IAIPopover
triggerComponent={
<IAIIconButton
icon={<FaGrinStars />}
aria-label={t('parameters.restoreFaces')}
/>
}
>
<Flex
sx={{
flexDirection: 'column',
rowGap: 4,
}}
>
<FaceRestoreSettings />
<IAIButton
isDisabled={
!isGFPGANAvailable ||
!imageDTO ||
!(isConnected && !isProcessing) ||
!facetoolStrength
}
onClick={handleClickFixFaces}
>
{t('parameters.restoreFaces')}
</IAIButton>
</Flex>
</IAIPopover>
)}
{isUpscalingEnabled && (
<IAIPopover
triggerComponent={
<IAIIconButton
icon={<FaExpandArrowsAlt />}
aria-label={t('parameters.upscale')}
/>
}
>
<Flex
sx={{
flexDirection: 'column',
gap: 4,
}}
>
<UpscaleSettings />
<IAIButton
isDisabled={
!isESRGANAvailable ||
!imageDTO ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
onClick={handleClickUpscale}
>
{t('parameters.upscaleImage')}
</IAIButton>
</Flex>
</IAIPopover>
)}
{isUpscalingEnabled && <ParamUpscalePopover imageDTO={imageDTO} />}
</ButtonGroup>
)}

View File

@@ -12,7 +12,10 @@ import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainM
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ModelInputFieldComponent = (
@@ -23,6 +26,7 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: onnxModels } = useGetOnnxModelsQuery();
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const data = useMemo(() => {
@@ -44,17 +48,39 @@ const ModelInputFieldComponent = (
});
});
if (onnxModels) {
forEach(onnxModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
}
return data;
}, [mainModels]);
}, [mainModels, onnxModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
mainModels?.entities[
(mainModels?.entities[
`${field.value?.base_model}/main/${field.value?.model_name}`
] ?? null,
[field.value?.base_model, field.value?.model_name, mainModels?.entities]
] ||
onnxModels?.entities[
`${field.value?.base_model}/onnx/${field.value?.model_name}`
]) ??
null,
[
field.value?.base_model,
field.value?.model_name,
mainModels?.entities,
onnxModels?.entities,
]
);
const handleChangeModel = useCallback(

View File

@@ -9,6 +9,7 @@ import {
CLIP_SKIP,
LORA_LOADER,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
@@ -17,7 +18,8 @@ import {
export const addLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
baseNodeId: string,
modelLoader: string = MAIN_MODEL_LOADER
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
@@ -40,6 +42,10 @@ export const addLoRAsToGraph = (
!(
e.source.node_id === MAIN_MODEL_LOADER &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === ONNX_MODEL_LOADER &&
['unet'].includes(e.source.field)
)
);
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
@@ -74,12 +80,11 @@ export const addLoRAsToGraph = (
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
graph.edges.push({
source: {
node_id: MAIN_MODEL_LOADER,
node_id: modelLoader,
field: 'unet',
},
destination: {

View File

@@ -9,13 +9,15 @@ import {
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
ONNX_MODEL_LOADER,
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';
export const addVAEToGraph = (
state: RootState,
graph: NonNullableGraph
graph: NonNullableGraph,
modelLoader: string = MAIN_MODEL_LOADER
): void => {
const { vae } = state.generation;
@@ -31,12 +33,12 @@ export const addVAEToGraph = (
vae_model: vae,
};
}
const isOnnxModel = modelLoader == ONNX_MODEL_LOADER;
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
node_id: isAutoVae ? modelLoader : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
@@ -48,8 +50,8 @@ export const addVAEToGraph = (
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
node_id: isAutoVae ? modelLoader : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
@@ -61,8 +63,8 @@ export const addVAEToGraph = (
if (graph.id === INPAINT_GRAPH) {
graph.edges.push({
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
node_id: isAutoVae ? modelLoader : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: INPAINT,

View File

@@ -0,0 +1,32 @@
import { NonNullableGraph } from 'features/nodes/types/types';
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import { Graph, ESRGANInvocation } from 'services/api/types';
import { REALESRGAN as ESRGAN } from './constants';
type Arg = {
image_name: string;
esrganModelName: ESRGANModelName;
};
export const buildAdHocUpscaleGraph = ({
image_name,
esrganModelName,
}: Arg): Graph => {
const realesrganNode: ESRGANInvocation = {
id: ESRGAN,
type: 'esrgan',
image: { image_name },
model_name: esrganModelName,
is_intermediate: false,
};
const graph: NonNullableGraph = {
id: `adhoc-esrgan-graph`,
nodes: {
[ESRGAN]: realesrganNode,
},
edges: [],
};
return graph;
};

View File

@@ -18,6 +18,7 @@ import {
LATENTS_TO_IMAGE,
LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
@@ -59,6 +60,9 @@ export const buildCanvasImageToImageGraph = (
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
const onnx_model_type = model.model_type.includes('onnx');
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@@ -69,16 +73,17 @@ export const buildCanvasImageToImageGraph = (
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH,
nodes: {
[POSITIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
@@ -87,9 +92,9 @@ export const buildCanvasImageToImageGraph = (
id: NOISE,
use_cpu,
},
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
[model_loader]: {
type: model_loader,
id: model_loader,
model,
},
[CLIP_SKIP]: {
@@ -98,11 +103,11 @@ export const buildCanvasImageToImageGraph = (
skipped_layers: clipSkip,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
type: onnx_model_type ? 'l2l_onnx' : 'l2l',
id: LATENTS_TO_LATENTS,
cfg_scale,
scheduler,
@@ -110,7 +115,7 @@ export const buildCanvasImageToImageGraph = (
strength,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
type: onnx_model_type ? 'i2l_onnx' : 'i2l',
id: IMAGE_TO_LATENTS,
// must be set manually later, bc `fit` parameter may require a resize node inserted
// image: {
@@ -121,7 +126,7 @@ export const buildCanvasImageToImageGraph = (
edges: [
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'clip',
},
destination: {
@@ -181,7 +186,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'unet',
},
destination: {
@@ -313,10 +318,10 @@ export const buildCanvasImageToImageGraph = (
});
// add LoRA support
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS, model_loader);
// optionally add custom VAE
addVAEToGraph(state, graph);
addVAEToGraph(state, graph, model_loader);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);

View File

@@ -15,6 +15,7 @@ import {
INPAINT_GRAPH,
ITERATE,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
RANDOM_INT,
@@ -63,6 +64,11 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model_loader = model.model_type.includes('onnx')
? ONNX_MODEL_LOADER
: MAIN_MODEL_LOADER;
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = {
id: INPAINT_GRAPH,
nodes: {
@@ -107,9 +113,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
[model_loader]: {
type: model_loader,
id: model_loader,
model,
},
[CLIP_SKIP]: {
@@ -133,7 +139,7 @@ export const buildCanvasInpaintGraph = (
edges: [
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'unet',
},
destination: {
@@ -143,7 +149,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'clip',
},
destination: {

View File

@@ -10,6 +10,7 @@ import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
@@ -49,7 +50,8 @@ export const buildCanvasTextToImageGraph = (
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
const onnx_model_type = model.model_type.includes('onnx');
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@@ -60,16 +62,17 @@ export const buildCanvasTextToImageGraph = (
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = {
id: TEXT_TO_IMAGE_GRAPH,
nodes: {
[POSITIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
@@ -81,15 +84,15 @@ export const buildCanvasTextToImageGraph = (
use_cpu,
},
[TEXT_TO_LATENTS]: {
type: 't2l',
type: onnx_model_type ? 't2l_onnx' : 't2l',
id: TEXT_TO_LATENTS,
cfg_scale,
scheduler,
steps,
},
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
[model_loader]: {
type: model_loader,
id: model_loader,
model,
},
[CLIP_SKIP]: {
@@ -98,7 +101,7 @@ export const buildCanvasTextToImageGraph = (
skipped_layers: clipSkip,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
},
},
@@ -125,7 +128,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'clip',
},
destination: {
@@ -155,7 +158,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'unet',
},
destination: {
@@ -219,10 +222,10 @@ export const buildCanvasTextToImageGraph = (
});
// add LoRA support
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);
// optionally add custom VAE
addVAEToGraph(state, graph);
addVAEToGraph(state, graph, model_loader);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);

View File

@@ -17,6 +17,7 @@ import {
LATENTS_TO_IMAGE,
LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
@@ -82,13 +83,17 @@ export const buildLinearImageToImageGraph = (
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
const onnx_model_type = model.model_type.includes('onnx');
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
// copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH,
nodes: {
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
[model_loader]: {
type: model_loader,
id: model_loader,
model,
},
[CLIP_SKIP]: {
@@ -97,12 +102,12 @@ export const buildLinearImageToImageGraph = (
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
@@ -112,11 +117,11 @@ export const buildLinearImageToImageGraph = (
use_cpu,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
type: onnx_model_type ? 'l2l_onnx' : 'l2l',
id: LATENTS_TO_LATENTS,
cfg_scale,
scheduler,
@@ -124,7 +129,7 @@ export const buildLinearImageToImageGraph = (
strength,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
type: onnx_model_type ? 'i2l_onnx' : 'i2l',
id: IMAGE_TO_LATENTS,
// must be set manually later, bc `fit` parameter may require a resize node inserted
// image: {
@@ -135,7 +140,7 @@ export const buildLinearImageToImageGraph = (
edges: [
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'unet',
},
destination: {
@@ -145,7 +150,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'clip',
},
destination: {
@@ -366,10 +371,10 @@ export const buildLinearImageToImageGraph = (
});
// add LoRA support
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS, model_loader);
// optionally add custom VAE
addVAEToGraph(state, graph);
addVAEToGraph(state, graph, model_loader);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);

View File

@@ -6,10 +6,13 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { BaseModelType, OnnxModelField } from 'services/api/types';
import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
@@ -46,6 +49,8 @@ export const buildLinearTextToImageGraph = (
throw new Error('No model found in state');
}
const onnx_model_type = model.model_type.includes('onnx');
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@@ -56,12 +61,14 @@ export const buildLinearTextToImageGraph = (
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = {
id: TEXT_TO_IMAGE_GRAPH,
nodes: {
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
[model_loader]: {
type: model_loader,
id: model_loader,
model,
},
[CLIP_SKIP]: {
@@ -70,12 +77,12 @@ export const buildLinearTextToImageGraph = (
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
@@ -87,21 +94,21 @@ export const buildLinearTextToImageGraph = (
use_cpu,
},
[TEXT_TO_LATENTS]: {
type: 't2l',
type: onnx_model_type ? 't2l_onnx' : 't2l',
id: TEXT_TO_LATENTS,
cfg_scale,
scheduler,
steps,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
},
},
edges: [
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'clip',
},
destination: {
@@ -111,7 +118,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'unet',
},
destination: {
@@ -215,10 +222,10 @@ export const buildLinearTextToImageGraph = (
});
// add LoRA support
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);
// optionally add custom VAE
addVAEToGraph(state, graph);
addVAEToGraph(state, graph, model_loader);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);

View File

@@ -8,6 +8,7 @@ export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const MAIN_MODEL_LOADER = 'main_model_loader';
export const ONNX_MODEL_LOADER = 'onnx_model_loader';
export const VAE_LOADER = 'vae_loader';
export const LORA_LOADER = 'lora_loader';
export const CLIP_SKIP = 'clip_skip';
@@ -20,6 +21,9 @@ export const DYNAMIC_PROMPT = 'dynamic_prompt';
export const IMAGE_COLLECTION = 'image_collection';
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
export const REALESRGAN = 'esrgan';
export const DIVIDE = 'divide';
export const SCALE = 'scale_image';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@@ -0,0 +1,17 @@
import { BaseModelType, MainModelField, ModelType } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToMainModelField = (modelId: string): MainModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: MainModelField = {
base_model: base_model as BaseModelType,
model_type: model_type as ModelType,
model_name,
};
return field;
};

View File

@@ -0,0 +1,17 @@
import { BaseModelType, OnnxModelField, ModelType } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToOnnxModelField = (modelId: string): OnnxModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: OnnxModelField = {
base_model: base_model as BaseModelType,
model_name,
model_type: model_type as ModelType,
};
return field;
};

View File

@@ -1,34 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setCodeformerFidelity } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function CodeformerFidelity() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const codeformerFidelity = useAppSelector(
(state: RootState) => state.postprocessing.codeformerFidelity
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
isDisabled={!isGFPGANAvailable}
label={t('parameters.codeformerFidelity')}
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setCodeformerFidelity(v))}
handleReset={() => dispatch(setCodeformerFidelity(1))}
value={codeformerFidelity}
withReset
withSliderMarks
withInput
/>
);
}

View File

@@ -1,25 +0,0 @@
import { VStack } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import type { RootState } from 'app/store/store';
import FaceRestoreType from './FaceRestoreType';
import FaceRestoreStrength from './FaceRestoreStrength';
import CodeformerFidelity from './CodeformerFidelity';
/**
* Displays face-fixing/GFPGAN options (strength).
*/
const FaceRestoreSettings = () => {
const facetoolType = useAppSelector(
(state: RootState) => state.postprocessing.facetoolType
);
return (
<VStack gap={2} alignItems="stretch">
<FaceRestoreType />
<FaceRestoreStrength />
{facetoolType === 'codeformer' && <CodeformerFidelity />}
</VStack>
);
};
export default FaceRestoreSettings;

View File

@@ -1,34 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setFacetoolStrength } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreStrength() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const facetoolStrength = useAppSelector(
(state: RootState) => state.postprocessing.facetoolStrength
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
isDisabled={!isGFPGANAvailable}
label={t('parameters.strength')}
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setFacetoolStrength(v))}
handleReset={() => dispatch(setFacetoolStrength(0.75))}
value={facetoolStrength}
withReset
withSliderMarks
withInput
/>
);
}

View File

@@ -1,28 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldRunFacetool } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
export default function FaceRestoreToggle() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const shouldRunFacetool = useAppSelector(
(state: RootState) => state.postprocessing.shouldRunFacetool
);
const dispatch = useAppDispatch();
const handleChangeShouldRunFacetool = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunFacetool(e.target.checked));
return (
<IAISwitch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunFacetool}
onChange={handleChangeShouldRunFacetool}
/>
);
}

View File

@@ -1,30 +0,0 @@
import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
FacetoolType,
setFacetoolType,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreType() {
const facetoolType = useAppSelector(
(state: RootState) => state.postprocessing.facetoolType
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChangeFacetoolType = (v: string) =>
dispatch(setFacetoolType(v as FacetoolType));
return (
<IAIMantineSearchableSelect
label={t('parameters.type')}
data={FACETOOL_TYPES.concat()}
value={facetoolType}
onChange={handleChangeFacetoolType}
/>
);
}

View File

@@ -1,43 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ParamHiresStrength } from './ParamHiresStrength';
import { ParamHiresToggle } from './ParamHiresToggle';
const selector = createSelector(
stateSelector,
(state) => {
const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined;
return { activeLabel };
},
defaultSelectorOptions
);
const ParamHiresCollapse = () => {
const { t } = useTranslation();
const { activeLabel } = useAppSelector(selector);
const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled;
if (!isHiresEnabled) {
return null;
}
return (
<IAICollapse label={t('parameters.hiresOptim')} activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamHiresToggle />
<ParamHiresStrength />
</Flex>
</IAICollapse>
);
};
export default memo(ParamHiresCollapse);

View File

@@ -1,51 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { setHiresStrength } from 'features/parameters/store/postprocessingSlice';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
const hiresStrengthSelector = createSelector(
[postprocessingSelector],
({ hiresFix, hiresStrength }) => ({ hiresFix, hiresStrength }),
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const ParamHiresStrength = () => {
const { hiresFix, hiresStrength } = useAppSelector(hiresStrengthSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleHiresStrength = (v: number) => {
dispatch(setHiresStrength(v));
};
const handleHiResStrengthReset = () => {
dispatch(setHiresStrength(0.75));
};
return (
<IAISlider
label={t('parameters.hiresStrength')}
step={0.01}
min={0.01}
max={0.99}
onChange={handleHiresStrength}
value={hiresStrength}
isInteger={false}
withInput
withSliderMarks
// inputWidth={22}
withReset
handleReset={handleHiResStrengthReset}
isDisabled={!hiresFix}
/>
);
};

View File

@@ -1,30 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
/**
* Hires Fix Toggle
*/
export const ParamHiresToggle = () => {
const dispatch = useAppDispatch();
const hiresFix = useAppSelector(
(state: RootState) => state.postprocessing.hiresFix
);
const { t } = useTranslation();
const handleChangeHiresFix = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setHiresFix(e.target.checked));
return (
<IAISwitch
label={t('parameters.hiresOptim')}
isChecked={hiresFix}
onChange={handleChangeHiresFix}
/>
);
};

View File

@@ -12,7 +12,11 @@ import { modelSelected } from 'features/parameters/store/actions';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import { modelIdToOnnxModelField } from 'features/nodes/util/modelIdToOnnxModelField';
const selector = createSelector(
stateSelector,
@@ -27,6 +31,7 @@ const ParamMainModelSelect = () => {
const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const { data: onnxModels, isLoading: onnxLoading } = useGetOnnxModelsQuery();
const data = useMemo(() => {
if (!mainModels) {
@@ -46,17 +51,31 @@ const ParamMainModelSelect = () => {
group: MODEL_TYPE_MAP[model.base_model],
});
});
forEach(onnxModels?.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [mainModels]);
}, [mainModels, onnxModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
(mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ||
onnxModels?.entities[
`${model?.base_model}/onnx/${model?.model_name}`
]) ??
null,
[mainModels?.entities, model]
[mainModels?.entities, model, onnxModels?.entities]
);
const handleChangeModel = useCallback(
@@ -65,7 +84,11 @@ const ParamMainModelSelect = () => {
return;
}
const newModel = modelIdToMainModelParam(v);
let newModel = modelIdToMainModelParam(v);
if (v.includes('onnx')) {
newModel = modelIdToOnnxModelField(v);
}
if (!newModel) {
return;
@@ -76,7 +99,7 @@ const ParamMainModelSelect = () => {
[dispatch]
);
return isLoading ? (
return isLoading || onnxLoading ? (
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."

View File

@@ -0,0 +1,58 @@
import { SelectItem } from '@mantine/core';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import {
ESRGANModelName,
esrganModelNameChanged,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export const ESRGAN_MODEL_NAMES: SelectItem[] = [
{
label: 'RealESRGAN x2 Plus',
value: 'RealESRGAN_x2plus.pth',
tooltip: 'Attempts to retain sharpness, low smoothing',
group: 'x2 Upscalers',
},
{
label: 'RealESRGAN x4 Plus',
value: 'RealESRGAN_x4plus.pth',
tooltip: 'Best for photos and highly detailed images, medium smoothing',
group: 'x4 Upscalers',
},
{
label: 'RealESRGAN x4 Plus (anime 6B)',
value: 'RealESRGAN_x4plus_anime_6B.pth',
tooltip: 'Best for anime/manga, high smoothing',
group: 'x4 Upscalers',
},
{
label: 'ESRGAN SRx4',
value: 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
tooltip: 'Retains sharpness, low smoothing',
group: 'x4 Upscalers',
},
];
export default function ParamESRGANModel() {
const esrganModelName = useAppSelector(
(state: RootState) => state.postprocessing.esrganModelName
);
const dispatch = useAppDispatch();
const handleChange = (v: string) =>
dispatch(esrganModelNameChanged(v as ESRGANModelName));
return (
<IAIMantineSelect
label="ESRGAN Model"
value={esrganModelName}
itemComponent={IAIMantineSelectItemWithTooltip}
onChange={handleChange}
data={ESRGAN_MODEL_NAMES}
/>
);
}

View File

@@ -0,0 +1,62 @@
import { Flex, useDisclosure } from '@chakra-ui/react';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaExpandArrowsAlt } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import ParamESRGANModel from './ParamRealESRGANModel';
type Props = { imageDTO?: ImageDTO };
const ParamUpscalePopover = (props: Props) => {
const { imageDTO } = props;
const dispatch = useAppDispatch();
const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const { isOpen, onOpen, onClose } = useDisclosure();
const handleClickUpscale = useCallback(() => {
onClose();
if (!imageDTO) {
return;
}
dispatch(upscaleRequested({ image_name: imageDTO.image_name }));
}, [dispatch, imageDTO, onClose]);
return (
<IAIPopover
isOpen={isOpen}
onClose={onClose}
triggerComponent={
<IAIIconButton
onClick={onOpen}
icon={<FaExpandArrowsAlt />}
aria-label={t('parameters.upscale')}
/>
}
>
<Flex
sx={{
flexDirection: 'column',
gap: 4,
}}
>
<ParamESRGANModel />
<IAIButton
size="sm"
isDisabled={!imageDTO || isBusy}
onClick={handleClickUpscale}
>
{t('parameters.upscaleImage')}
</IAIButton>
</Flex>
</IAIPopover>
);
};
export default ParamUpscalePopover;

View File

@@ -1,36 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setUpscalingDenoising } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleDenoisingStrength() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingDenoising = useAppSelector(
(state: RootState) => state.postprocessing.upscalingDenoising
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
label={t('parameters.denoisingStrength')}
value={upscalingDenoising}
min={0}
max={1}
step={0.01}
onChange={(v) => {
dispatch(setUpscalingDenoising(v));
}}
handleReset={() => dispatch(setUpscalingDenoising(0.75))}
withSliderMarks
withInput
withReset
isDisabled={!isESRGANAvailable}
/>
);
}

View File

@@ -1,35 +0,0 @@
import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
UpscalingLevel,
setUpscalingLevel,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleScale() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingLevel = useAppSelector(
(state: RootState) => state.postprocessing.upscalingLevel
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleChangeLevel = (v: string) =>
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
return (
<IAIMantineSearchableSelect
disabled={!isESRGANAvailable}
label={t('parameters.scale')}
value={String(upscalingLevel)}
onChange={handleChangeLevel}
data={UPSCALING_LEVELS}
/>
);
}

View File

@@ -1,19 +0,0 @@
import { VStack } from '@chakra-ui/react';
import UpscaleDenoisingStrength from './UpscaleDenoisingStrength';
import UpscaleStrength from './UpscaleStrength';
import UpscaleScale from './UpscaleScale';
/**
* Displays upscaling/ESRGAN options (level and strength).
*/
const UpscaleSettings = () => {
return (
<VStack gap={2} alignItems="stretch">
<UpscaleScale />
<UpscaleDenoisingStrength />
<UpscaleStrength />
</VStack>
);
};
export default UpscaleSettings;

View File

@@ -1,33 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setUpscalingStrength } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleStrength() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingStrength = useAppSelector(
(state: RootState) => state.postprocessing.upscalingStrength
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
label={`${t('parameters.upscale')} ${t('parameters.strength')}`}
value={upscalingStrength}
min={0}
max={1}
step={0.05}
onChange={(v) => dispatch(setUpscalingStrength(v))}
handleReset={() => dispatch(setUpscalingStrength(0.75))}
withSliderMarks
withInput
withReset
isDisabled={!isESRGANAvailable}
/>
);
}

View File

@@ -1,26 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldRunESRGAN } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
export default function UpscaleToggle() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const shouldRunESRGAN = useAppSelector(
(state: RootState) => state.postprocessing.shouldRunESRGAN
);
const dispatch = useAppDispatch();
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunESRGAN(e.target.checked));
return (
<IAISwitch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={handleChangeShouldRunESRGAN}
/>
);
}

View File

@@ -1,10 +1,10 @@
import { createAction } from '@reduxjs/toolkit';
import { ImageDTO, MainModelField } from 'services/api/types';
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected'
);
export const modelSelected = createAction<MainModelField>(
export const modelSelected = createAction<MainModelField | OnnxModelField>(
'generation/modelSelected'
);

View File

@@ -8,7 +8,7 @@ import {
setShouldShowAdvancedOptions,
} from 'features/ui/store/uiSlice';
import { clamp } from 'lodash-es';
import { ImageDTO, MainModelField } from 'services/api/types';
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';
import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import {
CfgScaleParam,
@@ -54,7 +54,7 @@ export interface GenerationState {
shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: MainModelField | null;
model: MainModelField | OnnxModelField | null;
vae: VaeModelParam | null;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
@@ -227,7 +227,10 @@ export const generationSlice = createSlice({
const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height };
},
modelChanged: (state, action: PayloadAction<MainModelParam | null>) => {
modelChanged: (
state,
action: PayloadAction<MainModelParam | OnnxModelField | null>
) => {
state.model = action.payload;
if (state.model === null) {
@@ -259,6 +262,7 @@ export const generationSlice = createSlice({
const result = zMainModel.safeParse({
model_name,
base_model,
model_type,
});
if (result.success) {

View File

@@ -1,98 +1,27 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { FACETOOL_TYPES } from 'app/constants';
import { ESRGANInvocation } from 'services/api/types';
export type UpscalingLevel = 2 | 4;
export type FacetoolType = (typeof FACETOOL_TYPES)[number];
export type ESRGANModelName = NonNullable<ESRGANInvocation['model_name']>;
export interface PostprocessingState {
codeformerFidelity: number;
facetoolStrength: number;
facetoolType: FacetoolType;
hiresFix: boolean;
hiresStrength: number;
shouldLoopback: boolean;
shouldRunESRGAN: boolean;
shouldRunFacetool: boolean;
upscalingLevel: UpscalingLevel;
upscalingDenoising: number;
upscalingStrength: number;
esrganModelName: ESRGANModelName;
}
export const initialPostprocessingState: PostprocessingState = {
codeformerFidelity: 0.75,
facetoolStrength: 0.75,
facetoolType: 'gfpgan',
hiresFix: false,
hiresStrength: 0.75,
shouldLoopback: false,
shouldRunESRGAN: false,
shouldRunFacetool: false,
upscalingLevel: 4,
upscalingDenoising: 0.75,
upscalingStrength: 0.75,
esrganModelName: 'RealESRGAN_x4plus.pth',
};
export const postprocessingSlice = createSlice({
name: 'postprocessing',
initialState: initialPostprocessingState,
reducers: {
setFacetoolStrength: (state, action: PayloadAction<number>) => {
state.facetoolStrength = action.payload;
},
setCodeformerFidelity: (state, action: PayloadAction<number>) => {
state.codeformerFidelity = action.payload;
},
setUpscalingLevel: (state, action: PayloadAction<UpscalingLevel>) => {
state.upscalingLevel = action.payload;
},
setUpscalingDenoising: (state, action: PayloadAction<number>) => {
state.upscalingDenoising = action.payload;
},
setUpscalingStrength: (state, action: PayloadAction<number>) => {
state.upscalingStrength = action.payload;
},
setHiresFix: (state, action: PayloadAction<boolean>) => {
state.hiresFix = action.payload;
},
setHiresStrength: (state, action: PayloadAction<number>) => {
state.hiresStrength = action.payload;
},
resetPostprocessingState: (state) => {
return {
...state,
...initialPostprocessingState,
};
},
setShouldRunFacetool: (state, action: PayloadAction<boolean>) => {
state.shouldRunFacetool = action.payload;
},
setFacetoolType: (state, action: PayloadAction<FacetoolType>) => {
state.facetoolType = action.payload;
},
setShouldRunESRGAN: (state, action: PayloadAction<boolean>) => {
state.shouldRunESRGAN = action.payload;
},
setShouldLoopback: (state, action: PayloadAction<boolean>) => {
state.shouldLoopback = action.payload;
esrganModelNameChanged: (state, action: PayloadAction<ESRGANModelName>) => {
state.esrganModelName = action.payload;
},
},
});
export const {
resetPostprocessingState,
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
setHiresFix,
setHiresStrength,
setShouldLoopback,
setShouldRunESRGAN,
setShouldRunFacetool,
setUpscalingLevel,
setUpscalingDenoising,
setUpscalingStrength,
} = postprocessingSlice.actions;
export const { esrganModelNameChanged } = postprocessingSlice.actions;
export default postprocessingSlice.reducer;

View File

@@ -126,6 +126,14 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
const zModelType = z.enum([
'vae',
'lora',
'onnx',
'main',
'controlnet',
'embedding',
]);
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
export type BaseModelParam = z.infer<typeof zBaseModel>;
@@ -137,6 +145,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
export const zMainModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
model_type: zModelType,
});
/**

View File

@@ -14,6 +14,7 @@ export const modelIdToMainModelParam = (
const result = zMainModel.safeParse({
base_model,
model_name,
model_type,
});
if (!result.success) {

View File

@@ -5,7 +5,11 @@ import { useRef } from 'react';
import { useHoverDirty } from 'react-use';
import { useGetAppVersionQuery } from 'services/api/endpoints/appInfo';
const InvokeAILogoComponent = () => {
interface Props {
showVersion?: boolean;
}
const InvokeAILogoComponent = ({ showVersion = true }: Props) => {
const { data: appVersion } = useGetAppVersionQuery();
const ref = useRef(null);
const isHovered = useHoverDirty(ref);
@@ -28,7 +32,7 @@ const InvokeAILogoComponent = () => {
invoke <strong>ai</strong>
</Text>
<AnimatePresence>
{isHovered && appVersion && (
{showVersion && isHovered && appVersion && (
<motion.div
key="statusText"
initial={{

View File

@@ -0,0 +1,119 @@
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { modelIdToMainModelField } from 'features/nodes/util/modelIdToMainModelField';
import { modelSelected } from 'features/parameters/store/actions';
import { forEach } from 'lodash-es';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import { modelIdToOnnxModelField } from 'features/nodes/util/modelIdToOnnxModelField';
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
};
const selector = createSelector(
stateSelector,
(state) => ({ currentModel: state.generation.model }),
defaultSelectorOptions
);
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { currentModel } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const { data: onnxModels, isLoading: onnxLoading } = useGetOnnxModelsQuery();
const data = useMemo(() => {
if (!mainModels) {
return [];
}
const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
forEach(onnxModels?.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [mainModels, onnxModels]);
const selectedModel = useMemo(
() =>
mainModels?.entities[
`${currentModel?.base_model}/main/${currentModel?.model_name}`
] ||
onnxModels?.entities[
`${currentModel?.base_model}/onnx/${currentModel?.model_name}`
],
[mainModels?.entities, onnxModels?.entities, currentModel]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
let modelField = modelIdToMainModelField(v);
if (v.includes('onnx')) {
modelField = modelIdToOnnxModelField(v);
}
dispatch(modelSelected(modelField));
},
[dispatch]
);
return isLoading || onnxLoading ? (
<IAIMantineSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={t('modelManager.model')}
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
/>
);
};
export default memo(ModelSelect);

View File

@@ -0,0 +1,56 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion1ModelCheckpointConfig,
StableDiffusion1ModelDiffusersConfig,
} from 'services/api';
import { receivedModels } from 'services/thunks/model';
export type SD1PipelineModel = (
| StableDiffusion1ModelCheckpointConfig
| StableDiffusion1ModelDiffusersConfig
) & {
name: string;
};
export const sd1PipelineModelsAdapter = createEntityAdapter<SD1PipelineModel>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd1InitialPipelineModelsState =
sd1PipelineModelsAdapter.getInitialState();
export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState;
export const sd1PipelineModelsSlice = createSlice({
name: 'sd1PipelineModels',
initialState: sd1InitialPipelineModelsState,
reducers: {
modelAdded: sd1PipelineModelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-1') return;
sd1PipelineModelsAdapter.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD1PipelineModels,
selectById: selectByIdSD1PipelineModels,
selectEntities: selectEntitiesSD1PipelineModels,
selectIds: selectIdsSD1PipelineModels,
selectTotal: selectTotalSD1PipelineModels,
} = sd1PipelineModelsAdapter.getSelectors<RootState>(
(state) => state.sd1pipelinemodels
);
export const { modelAdded } = sd1PipelineModelsSlice.actions;
export default sd1PipelineModelsSlice.reducer;

View File

@@ -0,0 +1,56 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion2ModelCheckpointConfig,
StableDiffusion2ModelDiffusersConfig,
} from 'services/api';
import { receivedModels } from 'services/thunks/model';
export type SD2PipelineModel = (
| StableDiffusion2ModelCheckpointConfig
| StableDiffusion2ModelDiffusersConfig
) & {
name: string;
};
export const sd2PipelineModelsAdapater = createEntityAdapter<SD2PipelineModel>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd2InitialPipelineModelsState =
sd2PipelineModelsAdapater.getInitialState();
export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState;
export const sd2PipelineModelsSlice = createSlice({
name: 'sd2PipelineModels',
initialState: sd2InitialPipelineModelsState,
reducers: {
modelAdded: sd2PipelineModelsAdapater.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-2') return;
sd2PipelineModelsAdapater.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD2PipelineModels,
selectById: selectByIdSD2PipelineModels,
selectEntities: selectEntitiesSD2PipelineModels,
selectIds: selectIdsSD2PipelineModels,
selectTotal: selectTotalSD2PipelineModels,
} = sd2PipelineModelsAdapater.getSelectors<RootState>(
(state) => state.sd2pipelinemodels
);
export const { modelAdded } = sd2PipelineModelsSlice.actions;
export default sd2PipelineModelsSlice.reducer;

View File

@@ -17,7 +17,7 @@ type ModelListProps = {
setSelectedModelId: (name: string | undefined) => void;
};
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive';
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;

View File

@@ -4,7 +4,6 @@ import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Adv
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
@@ -26,7 +25,6 @@ const TextToImageTabParameters = () => {
<ParamVariationCollapse />
<ParamNoiseCollapse />
<ParamSymmetryCollapse />
<ParamHiresCollapse />
<ParamSeamlessCollapse />
<ParamAdvancedCollapse />
</>

View File

@@ -10,6 +10,7 @@ import {
ImportModelConfig,
LoRAModelConfig,
MainModelConfig,
OnnxModelConfig,
MergeModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
@@ -27,6 +28,8 @@ export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;
export type OnnxModelConfigEntity = OnnxModelConfig & { id: string };
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
@@ -41,6 +44,7 @@ export type VaeModelConfigEntity = VaeModelConfig & { id: string };
type AnyModelConfigEntity =
| MainModelConfigEntity
| OnnxModelConfigEntity
| LoRAModelConfigEntity
| ControlNetModelConfigEntity
| TextualInversionModelConfigEntity
@@ -104,6 +108,10 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@@ -141,6 +149,38 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getOnnxModels: build.query<EntityState<OnnxModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'onnx' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'OnnxModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'OnnxModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: OnnxModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<OnnxModelConfigEntity>(
response.models
);
return onnxModelsAdapter.setAll(
onnxModelsAdapter.getInitialState(),
entities
);
},
}),
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
providesTags: (result, error, arg) => {
@@ -413,6 +453,7 @@ export const modelsApi = api.injectEndpoints({
export const {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,

View File

@@ -8,6 +8,132 @@ import {
} from '@reduxjs/toolkit/query/react';
import { $authToken, $baseUrl } from 'services/api/client';
export type { AddInvocation } from './models/AddInvocation';
export type { BoardChanges } from './models/BoardChanges';
export type { BoardDTO } from './models/BoardDTO';
export type { Body_create_board_image } from './models/Body_create_board_image';
export type { Body_remove_board_image } from './models/Body_remove_board_image';
export type { Body_upload_image } from './models/Body_upload_image';
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
export type { CkptModelInfo } from './models/CkptModelInfo';
export type { ClipField } from './models/ClipField';
export type { CollectInvocation } from './models/CollectInvocation';
export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
export type { ColorField } from './models/ColorField';
export type { CompelInvocation } from './models/CompelInvocation';
export type { CompelOutput } from './models/CompelOutput';
export type { ConditioningField } from './models/ConditioningField';
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
export type { ControlField } from './models/ControlField';
export type { ControlNetInvocation } from './models/ControlNetInvocation';
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
export type { ControlOutput } from './models/ControlOutput';
export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation';
export type { DynamicPromptInvocation } from './models/DynamicPromptInvocation';
export type { Edge } from './models/Edge';
export type { EdgeConnection } from './models/EdgeConnection';
export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
export type { FloatLinearRangeInvocation } from './models/FloatLinearRangeInvocation';
export type { FloatOutput } from './models/FloatOutput';
export type { Graph } from './models/Graph';
export type { GraphExecutionState } from './models/GraphExecutionState';
export type { GraphInvocation } from './models/GraphInvocation';
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
export type { HTTPValidationError } from './models/HTTPValidationError';
export type { ImageBlurInvocation } from './models/ImageBlurInvocation';
export type { ImageChannelInvocation } from './models/ImageChannelInvocation';
export type { ImageConvertInvocation } from './models/ImageConvertInvocation';
export type { ImageCropInvocation } from './models/ImageCropInvocation';
export type { ImageDTO } from './models/ImageDTO';
export type { ImageField } from './models/ImageField';
export type { ImageInverseLerpInvocation } from './models/ImageInverseLerpInvocation';
export type { ImageLerpInvocation } from './models/ImageLerpInvocation';
export type { ImageMetadata } from './models/ImageMetadata';
export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation';
export type { ImageOutput } from './models/ImageOutput';
export type { ImagePasteInvocation } from './models/ImagePasteInvocation';
export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation';
export type { ImageRecordChanges } from './models/ImageRecordChanges';
export type { ImageResizeInvocation } from './models/ImageResizeInvocation';
export type { ImageScaleInvocation } from './models/ImageScaleInvocation';
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
export type { InfillColorInvocation } from './models/InfillColorInvocation';
export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation';
export type { InfillTileInvocation } from './models/InfillTileInvocation';
export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput';
export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField';
export type { LatentsOutput } from './models/LatentsOutput';
export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation';
export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation';
export type { LineartAnimeImageProcessorInvocation } from './models/LineartAnimeImageProcessorInvocation';
export type { LineartImageProcessorInvocation } from './models/LineartImageProcessorInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { LoraInfo } from './models/LoraInfo';
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
export type { MidasDepthImageProcessorInvocation } from './models/MidasDepthImageProcessorInvocation';
export type { MlsdImageProcessorInvocation } from './models/MlsdImageProcessorInvocation';
export type { ModelInfo } from './models/ModelInfo';
export type { ModelLoaderOutput } from './models/ModelLoaderOutput';
export type { ModelsList } from './models/ModelsList';
export type { ModelType } from './models/ModelType';
export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput';
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_';
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
export type { ONNXLatentsToImageInvocation } from './models/ONNXLatentsToImageInvocation';
export type { ONNXModelLoaderOutput } from './models/ONNXModelLoaderOutput';
export type { ONNXPromptInvocation } from './models/ONNXPromptInvocation';
export type { ONNXSD1ModelLoaderInvocation } from './models/ONNXSD1ModelLoaderInvocation';
export type { ONNXStableDiffusion1ModelConfig } from './models/ONNXStableDiffusion1ModelConfig';
export type { ONNXStableDiffusion2ModelConfig } from './models/ONNXStableDiffusion2ModelConfig';
export type { ONNXTextToLatentsInvocation } from './models/ONNXTextToLatentsInvocation';
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PipelineModelField } from './models/PipelineModelField';
export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
export type { RangeInvocation } from './models/RangeInvocation';
export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation';
export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation';
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
export type { SubModelType } from './models/SubModelType';
export type { SubtractInvocation } from './models/SubtractInvocation';
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
export type { UNetField } from './models/UNetField';
export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { VaeField } from './models/VaeField';
export type { VaeModelConfig } from './models/VaeModelConfig';
export type { VaeRepo } from './models/VaeRepo';
export type { ValidationError } from './models/ValidationError';
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
export const tagTypes = ['Board', 'Image', 'ImageMetadata', 'Model'];
export type ApiFullTagDescription = FullTagDescription<
(typeof tagTypes)[number]

View File

@@ -0,0 +1,26 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Adds two numbers
*/
export type AddInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'add';
/**
* The first number
*/
'a'?: number;
/**
* The second number
*/
'b'?: number;
};

View File

@@ -0,0 +1,14 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type BoardChanges = {
/**
* The board's new name.
*/
board_name?: string;
/**
* The name of the board's new cover image.
*/
cover_image_name?: string;
};

View File

@@ -0,0 +1,37 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Deserialized board record with cover image URL and image count.
*/
export type BoardDTO = {
/**
* The unique ID of the board.
*/
board_id: string;
/**
* The name of the board.
*/
board_name: string;
/**
* The created timestamp of the board.
*/
created_at: string;
/**
* The updated timestamp of the board.
*/
updated_at: string;
/**
* The deleted timestamp of the board.
*/
deleted_at?: string;
/**
* The name of the board's cover image.
*/
cover_image_name?: string;
/**
* The number of images in the board.
*/
image_count: number;
};

View File

@@ -0,0 +1,14 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Body_create_board_image = {
/**
* The id of the board to add to
*/
board_id: string;
/**
* The name of the image to add
*/
image_name: string;
};

View File

@@ -0,0 +1,14 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Body_remove_board_image = {
/**
* The id of the board
*/
board_id: string;
/**
* The name of the image to remove
*/
image_name: string;
};

View File

@@ -0,0 +1,7 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Body_upload_image = {
file: Blob;
};

View File

@@ -0,0 +1,32 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Canny edge detection for ControlNet
*/
export type CannyImageProcessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'canny_image_processor';
/**
* The image to process
*/
image?: ImageField;
/**
* The low threshold of the Canny pixel gradient (0-255)
*/
low_threshold?: number;
/**
* The high threshold of the Canny pixel gradient (0-255)
*/
high_threshold?: number;
};

View File

@@ -0,0 +1,39 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type CkptModelInfo = {
/**
* A description of the model
*/
description?: string;
/**
* The name of the model
*/
model_name: string;
/**
* The type of the model
*/
model_type: string;
format?: 'ckpt';
/**
* The path to the model config
*/
config: string;
/**
* The path to the model weights
*/
weights: string;
/**
* The path to the model VAE
*/
vae: string;
/**
* The width of the model
*/
width?: number;
/**
* The height of the model
*/
height?: number;
};

View File

@@ -0,0 +1,21 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { LoraInfo } from './LoraInfo';
import type { ModelInfo } from './ModelInfo';
export type ClipField = {
/**
* Info to load tokenizer submodel
*/
tokenizer: ModelInfo;
/**
* Info to load text_encoder submodel
*/
text_encoder: ModelInfo;
/**
* Loras to apply on model loading
*/
loras: Array<LoraInfo>;
};

View File

@@ -0,0 +1,26 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Collects values into a collection
*/
export type CollectInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'collect';
/**
* The item to collect (all inputs must be of the same type)
*/
item?: any;
/**
* The collection, will be provided on execution
*/
collection?: Array<any>;
};

View File

@@ -0,0 +1,14 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Base class for all invocation outputs
*/
export type CollectInvocationOutput = {
type: 'collect_output';
/**
* The collection of input items
*/
collection: Array<any>;
};

View File

@@ -0,0 +1,22 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type ColorField = {
/**
* The red component
*/
'r': number;
/**
* The green component
*/
'g': number;
/**
* The blue component
*/
'b': number;
/**
* The alpha component
*/
'a': number;
};

View File

@@ -0,0 +1,28 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ClipField } from './ClipField';
/**
* Parse prompt using compel package to conditioning.
*/
export type CompelInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'compel';
/**
* Prompt
*/
prompt?: string;
/**
* Clip to use
*/
clip?: ClipField;
};

View File

@@ -0,0 +1,16 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ConditioningField } from './ConditioningField';
/**
* Compel parser output
*/
export type CompelOutput = {
type?: 'compel_output';
/**
* Conditioning
*/
conditioning?: ConditioningField;
};

View File

@@ -0,0 +1,10 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type ConditioningField = {
/**
* The name of conditioning data
*/
conditioning_name: string;
};

View File

@@ -0,0 +1,44 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Applies content shuffle processing to image
*/
export type ContentShuffleImageProcessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'content_shuffle_image_processor';
/**
* The image to process
*/
image?: ImageField;
/**
* The pixel resolution for detection
*/
detect_resolution?: number;
/**
* The pixel resolution for the output image
*/
image_resolution?: number;
/**
* Content shuffle `h` parameter
*/
'h'?: number;
/**
* Content shuffle `w` parameter
*/
'w'?: number;
/**
* Content shuffle `f` parameter
*/
'f'?: number;
};

View File

@@ -0,0 +1,28 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
export type ControlField = {
/**
* The control image
*/
image: ImageField;
/**
* The ControlNet model to use
*/
control_model: string;
/**
* The weight given to the ControlNet
*/
control_weight: (number | Array<number>);
/**
* When the ControlNet is first applied (% of total steps)
*/
begin_step_percent: number;
/**
* When the ControlNet is last applied (% of total steps)
*/
end_step_percent: number;
};

View File

@@ -0,0 +1,40 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Collects ControlNet info to pass to other nodes
*/
export type ControlNetInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'controlnet';
/**
* The control image
*/
image?: ImageField;
/**
* control model used
*/
control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace';
/**
* The weight given to the ControlNet
*/
control_weight?: (number | Array<number>);
/**
* When the ControlNet is first applied (% of total steps)
*/
begin_step_percent?: number;
/**
* When the ControlNet is last applied (% of total steps)
*/
end_step_percent?: number;
};

View File

@@ -0,0 +1,17 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ControlNetModelFormat } from './ControlNetModelFormat';
import type { ModelError } from './ModelError';
export type ControlNetModelConfig = {
name: string;
base_model: BaseModelType;
type: 'controlnet';
path: string;
description?: string;
model_format: ControlNetModelFormat;
error?: ModelError;
};

View File

@@ -0,0 +1,16 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ControlField } from './ControlField';
/**
* node output for ControlNet info
*/
export type ControlOutput = {
type?: 'control_output';
/**
* The control info
*/
control?: ControlField;
};

View File

@@ -0,0 +1,17 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { CkptModelInfo } from './CkptModelInfo';
import type { DiffusersModelInfo } from './DiffusersModelInfo';
export type CreateModelRequest = {
/**
* The name of the model
*/
name: string;
/**
* The model info
*/
info: (CkptModelInfo | DiffusersModelInfo);
};

View File

@@ -0,0 +1,28 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Simple inpaint using opencv.
*/
export type CvInpaintInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'cv_inpaint';
/**
* The image to inpaint
*/
image?: ImageField;
/**
* The mask to use when inpainting
*/
mask?: ImageField;
};

View File

@@ -0,0 +1,33 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { VaeRepo } from './VaeRepo';
export type DiffusersModelInfo = {
/**
* A description of the model
*/
description?: string;
/**
* The name of the model
*/
model_name: string;
/**
* The type of the model
*/
model_type: string;
format?: 'folder';
/**
* The VAE repo to use for this model
*/
vae?: VaeRepo;
/**
* The repo ID to use for this model
*/
repo_id?: string;
/**
* The path to the model
*/
path?: string;
};

Some files were not shown because too many files have changed in this diff Show More