mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-17 09:57:58 -05:00)

Compare commits: v3.0.2 ... chore/pyda (8 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 5a532d00ec |  |
|  | a67d8376c7 |  |
|  | 0b11f309ca |  |
|  | 6a8eb392b2 |  |
|  | 824ca92760 |  |
|  | 80fd4c2176 |  |
|  | 3b6e425e17 |  |
|  | 50415450d8 |  |
@@ -8,16 +8,13 @@ Preparations:
    to work. Instructions are given here:
    https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
 
-   NOTE: At this time we do not recommend Python 3.11. We recommend
-   Version 3.10.9, which has been extensively tested with InvokeAI.
-
    Before you start the installer, please open up your system's command
    line window (Terminal or Command) and type the commands:
 
    python --version
 
    If all is well, it will print "Python 3.X.X", where the version number
-   is at least 3.9.1, and less than 3.11.
+   is at least 3.9.*, and not higher than 3.11.*.
 
    If this works, check the version of the Python package manager, pip:
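The check above can also be scripted. A minimal sketch under the same constraint the docs now state (at least 3.9, not higher than 3.11); the helper name is illustrative, not part of InvokeAI:

```python
import sys

def python_version_ok() -> bool:
    # Accept 3.9.x through 3.11.x, mirroring the documented range.
    return (3, 9) <= sys.version_info[:2] <= (3, 11)

if not python_version_ok():
    sys.exit(f"Unsupported Python {sys.version.split()[0]}; need 3.9.x-3.11.x")
```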
@@ -105,7 +105,7 @@ async def update_model(
         info.path = new_info.get("path")
 
     # replace empty string values with None/null to avoid phenomenon of vae: ''
-    info_dict = info.dict()
+    info_dict = info.model_dump()
     info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
 
     ApiDependencies.invoker.services.model_manager.update_model(

@@ -203,7 +203,7 @@ async def add_model(
 
     try:
         ApiDependencies.invoker.services.model_manager.add_model(
-            info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
+            info.model_name, info.base_model, info.model_type, model_attributes=info.model_dump()
         )
         logger.info(f"Successfully added {info.model_name}")
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(

@@ -348,7 +348,7 @@ async def sync_to_config() -> bool:
 )
 async def merge_models(
     base_model: BaseModelType = Path(description="Base model"),
-    model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
+    model_names: List[str] = Body(description="model name", min_length=2, max_length=3),
     merged_model_name: Optional[str] = Body(description="Name of destination model"),
     alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
     interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
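The changes in these three hunks are the two most mechanical parts of the Pydantic v1-to-v2 migration: `BaseModel.dict()` becomes `model_dump()` (and `.json()` becomes `model_dump_json()`), while the list constraints `min_items`/`max_items` become `min_length`/`max_length`. A minimal sketch of both renames; `MergeRequest` is a hypothetical stand-in for the endpoint's body, not a class from this codebase:

```python
from typing import List
from pydantic import BaseModel, Field

class MergeRequest(BaseModel):
    # v1 spelling was Field(..., min_items=2, max_items=3)
    model_names: List[str] = Field(min_length=2, max_length=3)

req = MergeRequest(model_names=["modelA", "modelB"])
print(req.model_dump())       # v2 rename of req.dict()
print(req.model_dump_json())  # v2 rename of req.json()
```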
@@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.middleware import EventHandlerASGIMiddleware
 from pathlib import Path
-from pydantic.schema import schema
+#from pydantic.schema import schema
 
 # This should come early so that modules can log their initialization properly
 from .services.config import InvokeAIAppConfig

@@ -126,13 +126,13 @@ def custom_openapi():
         output_type = signature(invoker.invoke).return_annotation
         output_types.add(output_type)
 
-    output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
-    for schema_key, output_schema in output_schemas["definitions"].items():
-        openapi_schema["components"]["schemas"][schema_key] = output_schema
+    # output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
+    # for schema_key, output_schema in output_schemas["definitions"].items():
+    #     openapi_schema["components"]["schemas"][schema_key] = output_schema
 
-        # TODO: note that we assume the schema_key here is the TYPE.__name__
-        # This could break in some cases, figure out a better way to do it
-        output_type_titles[schema_key] = output_schema["title"]
+    #     # TODO: note that we assume the schema_key here is the TYPE.__name__
+    #     # This could break in some cases, figure out a better way to do it
+    #     output_type_titles[schema_key] = output_schema["title"]
 
     # Add a reference to the output type to additionalProperties of the invoker schema
     for invoker in all_invocations:
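The OpenAPI post-processing is commented out rather than ported because `pydantic.schema.schema` no longer exists in v2. The closest v2 replacement, if this block were to be restored, is `pydantic.json_schema.models_json_schema`; a hedged sketch (the `WidthOutput` model is hypothetical):

```python
from pydantic import BaseModel
from pydantic.json_schema import models_json_schema

class WidthOutput(BaseModel):
    width: int

# v2 replacement for pydantic.schema.schema(...); returns (key map, definitions)
_, definitions = models_json_schema(
    [(WidthOutput, "validation")], ref_template="#/components/schemas/{model}"
)
print(definitions["$defs"]["WidthOutput"]["title"])  # -> "WidthOutput"
```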
@@ -67,7 +67,7 @@ def add_parsers(
             add_arguments(command_parser)
 
         # Convert all fields to arguments
-        fields = command.__fields__  # type: ignore
+        fields = command.model_fields  # type: ignore
         for name, field in fields.items():
             if name in exclude_fields:
                 continue

@@ -87,7 +87,7 @@ def add_graph_parsers(
         # Add arguments for inputs
         for exposed_input in graph.exposed_inputs:
             node = graph.graph.get_node(exposed_input.node_path)
-            field = node.__fields__[exposed_input.field]
+            field = node.model_fields[exposed_input.field]
             default_override = getattr(node, exposed_input.field)
             add_field_argument(command_parser, exposed_input.alias, field, default_override)
 

@@ -194,7 +194,7 @@ def get_graph_execution_history(
 
 
 def get_invocation_command(invocation) -> str:
-    fields = invocation.__fields__.items()
+    fields = invocation.model_fields.items()
     type_hints = get_type_hints(type(invocation))
     command = [invocation.type]
     for name, field in fields:
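These hunks track the rename of the fields table: v1's `Model.__fields__` was a dict of `ModelField` wrappers (metadata under `.field_info`), while v2's `Model.model_fields` maps names directly to `FieldInfo`. A small sketch with a hypothetical command model:

```python
from pydantic import BaseModel, Field

class AddCommand(BaseModel):  # hypothetical CLI command model
    a: int = Field(default=0, description="left operand")
    b: int = Field(default=0, description="right operand")

# v1: AddCommand.__fields__["a"].field_info.description
# v2: the FieldInfo is the mapping value itself
for name, field in AddCommand.model_fields.items():
    print(name, field.annotation, field.default, field.description)
```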
@@ -118,7 +118,7 @@ class CustomisedSchemaExtra(TypedDict):
 class InvocationConfig(BaseConfig):
     """Customizes pydantic's BaseModel.Config class for use by Invocations.
 
-    Provide `schema_extra` a `ui` dict to add hints for generated UIs.
+    Provide `json_schema_extra` a `ui` dict to add hints for generated UIs.
 
     `tags`
     - A list of strings, used to categorise invocations.

@@ -131,7 +131,7 @@ class InvocationConfig(BaseConfig):
 
     ```python
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "tags": ["stable-diffusion", "image"],
                 "type_hints": {

@@ -142,4 +142,4 @@ class InvocationConfig(BaseConfig):
     ```
     """
 
-    schema_extra: CustomisedSchemaExtra
+    json_schema_extra: CustomisedSchemaExtra
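`schema_extra` to `json_schema_extra` is the v2 rename of this `Config` key. Note that v2 also deprecates class-based `Config` altogether in favour of `model_config = ConfigDict(...)`; this commit keeps the class-based form, but the fully idiomatic v2 spelling would look roughly like the following sketch (`RangeNode` is hypothetical):

```python
from pydantic import BaseModel, ConfigDict

class RangeNode(BaseModel):
    start: int = 0
    stop: int = 10

    # v2-idiomatic replacement for `class Config: json_schema_extra = {...}`
    model_config = ConfigDict(
        json_schema_extra={"ui": {"title": "Range", "tags": ["range", "collection"]}}
    )

print(RangeNode.model_json_schema()["ui"]["title"])  # -> "Range"
```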
@@ -3,7 +3,7 @@
 from typing import Literal
 
 import numpy as np
-from pydantic import Field, validator
+from pydantic import Field, field_validator
 
 from invokeai.app.models.image import ImageField
 from invokeai.app.util.misc import SEED_MAX, get_random_seed

@@ -38,7 +38,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
     collection: list[ImageField] = Field(default=[], description="The output images")
 
     class Config:
-        schema_extra = {"required": ["type", "collection"]}
+        json_schema_extra = {"required": ["type", "collection"]}
 
 
 class RangeInvocation(BaseInvocation):

@@ -52,11 +52,11 @@ class RangeInvocation(BaseInvocation):
     step: int = Field(default=1, description="The step of the range")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
         }
 
-    @validator("stop")
+    @field_validator("stop")
     def stop_gt_start(cls, v, values):
         if "start" in values and v <= values["start"]:
             raise ValueError("stop must be greater than start")

@@ -77,7 +77,7 @@ class RangeOfSizeInvocation(BaseInvocation):
     step: int = Field(default=1, description="The step of the range")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
         }
 

@@ -102,7 +102,7 @@ class RandomRangeInvocation(BaseInvocation):
     )
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
         }
 

@@ -127,7 +127,7 @@ class ImageCollectionInvocation(BaseInvocation):
         return ImageCollectionOutput(collection=self.images)
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "type_hints": {
                     "title": "Image Collection",
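Swapping `@validator` for `@field_validator` is only half of the migration for `stop_gt_start`: under v2 the second argument is no longer a plain `values` dict but a `ValidationInfo`, with previously validated fields under `info.data`, so the body shown above still needs a follow-up change. A hedged sketch of the fully ported validator:

```python
from pydantic import BaseModel, ValidationInfo, field_validator

class Range(BaseModel):  # hypothetical, mirrors RangeInvocation's fields
    start: int = 0
    stop: int = 10

    @field_validator("stop")
    @classmethod
    def stop_gt_start(cls, v: int, info: ValidationInfo) -> int:
        # v1 signature was (cls, v, values) with a dict of prior fields
        if "start" in info.data and v <= info.data["start"]:
            raise ValueError("stop must be greater than start")
        return v
```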
@@ -26,7 +26,7 @@ class ConditioningField(BaseModel):
     conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
 
     class Config:
-        schema_extra = {"required": ["conditioning_name"]}
+        json_schema_extra = {"required": ["conditioning_name"]}
 
 
 @dataclass

@@ -80,18 +80,18 @@ class CompelInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
         }
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
         tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.dict(),
+            **self.clip.tokenizer.model_dump(),
             context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.dict(),
+            **self.clip.text_encoder.model_dump(),
             context=context,
         )

@@ -178,11 +178,11 @@ class CompelInvocation(BaseInvocation):
 class SDXLPromptInvocationBase:
     def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
         tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.dict(),
+            **clip_field.tokenizer.model_dump(),
             context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.dict(),
+            **clip_field.text_encoder.model_dump(),
             context=context,
         )

@@ -255,11 +255,11 @@ class SDXLPromptInvocationBase:
 
     def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
         tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.dict(),
+            **clip_field.tokenizer.model_dump(),
             context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.dict(),
+            **clip_field.text_encoder.model_dump(),
             context=context,
         )

@@ -360,7 +360,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
         }
 

@@ -414,7 +414,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "SDXL Refiner Prompt (Compel)",
                 "tags": ["prompt", "compel"],

@@ -471,7 +471,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
         }
 

@@ -525,7 +525,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "SDXL Refiner Prompt (Raw)",
                 "tags": ["prompt", "compel"],

@@ -580,7 +580,7 @@ class ClipSkipInvocation(BaseInvocation):
     skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
         }
 
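All of these call sites splat a small pydantic model into keyword arguments, and the pattern survives the rename unchanged: `model_dump()` accepts the same `exclude`/`include` selectors that `.dict()` did. A minimal sketch with hypothetical names (`SubModelRef` and `get_model` stand in for the real services):

```python
from pydantic import BaseModel

class SubModelRef(BaseModel):
    model_name: str
    base_model: str
    submodel: str
    weight: float = 1.0

def get_model(model_name: str, base_model: str, submodel: str, context=None):
    return f"{base_model}/{model_name}/{submodel}"

ref = SubModelRef(model_name="clip-vit-l", base_model="sd-1", submodel="tokenizer")
# exclude=... drops fields the callee does not accept
# (v1 spelling: **ref.dict(exclude={"weight"}))
print(get_model(**ref.model_dump(exclude={"weight"}), context=None))
```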
@@ -24,7 +24,7 @@ from controlnet_aux import (
 )
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
 
 from ...backend.model_management import BaseModelType, ModelType
 from ..models.image import ImageCategory, ImageField, ResourceOrigin

@@ -123,7 +123,7 @@ class ControlField(BaseModel):
     control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
     resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
 
-    @validator("control_weight")
+    @field_validator("control_weight")
     def validate_control_weight(cls, v):
         """Validate that all control weights in the valid range"""
         if isinstance(v, list):

@@ -136,7 +136,7 @@ class ControlField(BaseModel):
         return v
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
             "ui": {
                 "type_hints": {

@@ -176,7 +176,7 @@ class ControlNetInvocation(BaseInvocation):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "ControlNet",
                 "tags": ["controlnet", "latents"],

@@ -214,7 +214,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Image Processor", "tags": ["image", "processor"]},
         }
 

@@ -266,7 +266,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
         }
 

@@ -290,7 +290,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
         }
 

@@ -319,7 +319,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
         }
 

@@ -342,7 +342,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Lineart Anime Processor",
                 "tags": ["controlnet", "lineart", "anime", "image", "processor"],

@@ -371,7 +371,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
         }
 

@@ -399,7 +399,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
         }
 

@@ -426,7 +426,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
         }
 

@@ -451,7 +451,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
         }
 

@@ -480,7 +480,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
         }
 

@@ -510,7 +510,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Content Shuffle Processor",
                 "tags": ["controlnet", "contentshuffle", "image", "processor"],

@@ -539,7 +539,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
         }
 

@@ -560,7 +560,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
         }
 

@@ -588,7 +588,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
         }
 

@@ -614,7 +614,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Tile Resample Processor",
                 "tags": ["controlnet", "tile", "resample", "image", "processor"],

@@ -656,7 +656,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocation
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Segment Anything Processor",
                 "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
@@ -17,7 +17,7 @@ class CvInvocationConfig(BaseModel):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "tags": ["cv", "image"],
             },

@@ -36,7 +36,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
         }
 
@@ -129,7 +129,7 @@ class InpaintInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
         }
 

@@ -142,7 +142,7 @@ class InpaintInvocation(BaseInvocation):
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
-            node=self.dict(),
+            node=self.model_dump(),
             source_node_id=source_node_id,
         )
 

@@ -169,11 +169,11 @@ class InpaintInvocation(BaseInvocation):
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(),
+            **self.unet.unet.model_dump(),
             context=context,
         )
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
+            **self.vae.vae.model_dump(),
             context=context,
         )
@@ -39,7 +39,7 @@ class LoadImageInvocation(BaseInvocation):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Load Image", "tags": ["image", "load"]},
         }
 

@@ -62,7 +62,7 @@ class ShowImageInvocation(BaseInvocation):
     image: Optional[ImageField] = Field(default=None, description="The image to show")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Show Image", "tags": ["image", "show"]},
         }
 

@@ -95,7 +95,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Crop Image", "tags": ["image", "crop"]},
         }
 

@@ -136,7 +136,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Paste Image", "tags": ["image", "paste"]},
         }
 

@@ -185,7 +185,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
         }
 

@@ -224,7 +224,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
         }
 

@@ -265,7 +265,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Image Channel", "tags": ["image", "channel"]},
         }
 

@@ -305,7 +305,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Convert Image", "tags": ["image", "convert"]},
         }
 

@@ -343,7 +343,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Blur Image", "tags": ["image", "blur"]},
         }
 

@@ -405,7 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Resize Image", "tags": ["image", "resize"]},
         }
 

@@ -448,7 +448,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Scale Image", "tags": ["image", "scale"]},
         }
 

@@ -493,7 +493,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
         }
 

@@ -534,7 +534,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Image Inverse Linear Interpolation",
                 "tags": ["image", "linear", "interpolation", "inverse"],

@@ -577,7 +577,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
         }
 

@@ -600,7 +600,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.dict() if self.metadata else None,
+            metadata=self.metadata.model_dump() if self.metadata else None,
         )
 
         return ImageOutput(

@@ -629,7 +629,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
         }
 

@@ -643,7 +643,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.dict() if self.metadata else None,
+            metadata=self.metadata.model_dump() if self.metadata else None,
         )
 
         return ImageOutput(
@@ -125,7 +125,7 @@ class InfillColorInvocation(BaseInvocation):
     )
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
         }
 

@@ -168,7 +168,7 @@ class InfillTileInvocation(BaseInvocation):
     )
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
         }
 

@@ -202,7 +202,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
     image: Optional[ImageField] = Field(default=None, description="The image to infill")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
         }
 
@@ -13,7 +13,7 @@ from diffusers.models.attention_processor import (
     XFormersAttnProcessor,
 )
 from diffusers.schedulers import SchedulerMixin as Scheduler
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
 
 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.util.controlnet_utils import prepare_control_image

@@ -46,7 +46,7 @@ class LatentsField(BaseModel):
     latents_name: Optional[str] = Field(default=None, description="The name of the latents")
 
     class Config:
-        schema_extra = {"required": ["latents_name"]}
+        json_schema_extra = {"required": ["latents_name"]}
 
 
 class LatentsOutput(BaseInvocationOutput):

@@ -80,7 +80,7 @@ def get_scheduler(
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
     orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.dict(),
+        **scheduler_info.model_dump(),
         context=context,
     )
     with orig_scheduler_info as orig_scheduler:

@@ -121,7 +121,7 @@ class TextToLatentsInvocation(BaseInvocation):
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on
 
-    @validator("cfg_scale")
+    @field_validator("cfg_scale")
     def ge_one(cls, v):
         """validate that all cfg_scale values are >= 1"""
         if isinstance(v, list):

@@ -135,7 +135,7 @@ class TextToLatentsInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Text To Latents",
                 "tags": ["latents"],

@@ -158,7 +158,7 @@ class TextToLatentsInvocation(BaseInvocation):
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
-            node=self.dict(),
+            node=self.model_dump(),
             source_node_id=source_node_id,
         )
 

@@ -327,7 +327,7 @@ class TextToLatentsInvocation(BaseInvocation):
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(),
+            **self.unet.unet.model_dump(),
             context=context,
         )
         with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(

@@ -384,7 +384,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Latent To Latents",
                 "tags": ["latents"],

@@ -420,7 +420,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(),
+            **self.unet.unet.model_dump(),
             context=context,
         )
         with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(

@@ -495,7 +495,7 @@ class LatentsToImageInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Latents To Image",
                 "tags": ["latents", "image"],

@@ -507,7 +507,7 @@ class LatentsToImageInvocation(BaseInvocation):
         latents = context.services.latents.get(self.latents.latents_name)
 
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
+            **self.vae.vae.model_dump(),
             context=context,
         )
 

@@ -565,7 +565,7 @@ class LatentsToImageInvocation(BaseInvocation):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.dict() if self.metadata else None,
+            metadata=self.metadata.model_dump() if self.metadata else None,
         )
 
         return ImageOutput(

@@ -593,7 +593,7 @@ class ResizeLatentsInvocation(BaseInvocation):
     )
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
         }
 

@@ -634,7 +634,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     )
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
         }
 

@@ -675,7 +675,7 @@ class ImageToLatentsInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
         }
 

@@ -686,9 +686,9 @@ class ImageToLatentsInvocation(BaseInvocation):
         # )
         image = context.services.images.get_pil_image(self.image.image_name)
 
-        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+        # vae_info = context.services.model_manager.get_model(**self.vae.vae.model_dump())
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
+            **self.vae.vae.model_dump(),
             context=context,
         )
 
@@ -18,7 +18,7 @@ class MathInvocationConfig(BaseModel):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "tags": ["math"],
             }

@@ -53,7 +53,7 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Add", "tags": ["math", "add"]},
         }
 

@@ -71,7 +71,7 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Subtract", "tags": ["math", "subtract"]},
         }
 

@@ -89,7 +89,7 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Multiply", "tags": ["math", "multiply"]},
         }
 

@@ -107,7 +107,7 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Divide", "tags": ["math", "divide"]},
         }
 

@@ -127,7 +127,7 @@ class RandomIntInvocation(BaseInvocation):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
         }
 
@@ -142,7 +142,7 @@ class MetadataAccumulatorInvocation(BaseInvocation):
     refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Metadata Accumulator",
                 "tags": ["image", "metadata", "generation"],

@@ -152,4 +152,4 @@ class MetadataAccumulatorInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
         """Collects and outputs a CoreMetadata object"""
 
-        return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
+        return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.model_dump()))
@@ -73,7 +73,7 @@ class MainModelLoaderInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Model Loader",
                 "tags": ["model", "loader"],

@@ -205,7 +205,7 @@ class LoraLoaderInvocation(BaseInvocation):
     clip: Optional[ClipField] = Field(description="Clip model for applying lora")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Lora Loader",
                 "tags": ["lora", "loader"],

@@ -287,7 +287,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
     clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "SDXL Lora Loader",
                 "tags": ["lora", "loader"],

@@ -385,7 +385,7 @@ class VaeLoaderInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "VAE Loader",
                 "tags": ["vae", "loader"],
@@ -3,7 +3,7 @@
 import math
 from typing import Literal
 
-from pydantic import Field, validator
+from pydantic import Field, field_validator
 import torch
 from invokeai.app.invocations.latent import LatentsField
 

@@ -110,14 +110,14 @@ class NoiseInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Noise",
                 "tags": ["latents", "noise"],
             },
         }
 
-    @validator("seed", pre=True)
+    @field_validator("seed", pre=True)
     def modulo_seed(cls, v):
         """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
         return v % (SEED_MAX + 1)
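One subtlety this hunk leaves unresolved: v2's `field_validator` does not accept v1's `pre=True` keyword, so the decorator as written would still fail; the v2 equivalent is `mode="before"`. A hedged sketch of the fully ported validator (the `SEED_MAX` value is assumed, mirroring `invokeai.app.util.misc`):

```python
from pydantic import BaseModel, field_validator

SEED_MAX = 2**32 - 1  # assumed value

class Noise(BaseModel):  # hypothetical, mirrors NoiseInvocation's seed field
    seed: int = 0

    @field_validator("seed", mode="before")  # v1: @validator("seed", pre=True)
    @classmethod
    def modulo_seed(cls, v: int) -> int:
        """Return the seed modulo (SEED_MAX + 1) so it stays in range."""
        return v % (SEED_MAX + 1)

print(Noise(seed=SEED_MAX + 5).seed)  # -> 4
```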
@@ -6,7 +6,7 @@ from typing import List, Literal, Optional, Union
 import re
 import inspect
 
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
 import torch
 import numpy as np
 from diffusers import ControlNetModel, DPMSolverMultistepScheduler

@@ -59,10 +59,10 @@ class ONNXPromptInvocation(BaseInvocation):
 
     def invoke(self, context: InvocationContext) -> CompelOutput:
         tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.dict(),
+            **self.clip.tokenizer.model_dump(),
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.dict(),
+            **self.clip.text_encoder.model_dump(),
         )
         with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
             loras = [

@@ -154,7 +154,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on
 
-    @validator("cfg_scale")
+    @field_validator("cfg_scale")
     def ge_one(cls, v):
         """validate that all cfg_scale values are >= 1"""
         if isinstance(v, list):

@@ -168,7 +168,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "tags": ["latents"],
                 "type_hints": {

@@ -226,7 +226,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
-            node=self.dict(),
+            node=self.model_dump(),
             source_node_id=source_node_id,
         )
 

@@ -239,7 +239,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             eta=0.0,
         )
 
-        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
 
         with unet_info as unet, ExitStack() as stack:
             # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]

@@ -314,7 +314,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "tags": ["latents", "image"],
             },

@@ -327,7 +327,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
             raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
 
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
+            **self.vae.vae.model_dump(),
         )
 
         # clear memory as vae decode can request a lot

@@ -356,7 +356,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.dict() if self.metadata else None,
+            metadata=self.metadata.model_dump() if self.metadata else None,
         )
 
         return ImageOutput(

@@ -389,7 +389,7 @@ class ONNXSD1ModelLoaderInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}},  # TODO: rename to model_name?
         }
 

@@ -472,7 +472,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "Onnx Model Loader",
                 "tags": ["model", "loader"],
@@ -65,7 +65,7 @@ class FloatLinearRangeInvocation(BaseInvocation):
     steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
         }
 

@@ -136,7 +136,7 @@ class StepParamEasingInvocation(BaseInvocation):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
         }
 

@@ -21,7 +21,7 @@ class ParamIntInvocation(BaseInvocation):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
         }
 

@@ -38,7 +38,7 @@ class ParamFloatInvocation(BaseInvocation):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"tags": ["param", "float"], "title": "Float Parameter"},
         }
 

@@ -60,7 +60,7 @@ class ParamStringInvocation(BaseInvocation):
     text: str = Field(default="", description="The string value")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"tags": ["param", "string"], "title": "String Parameter"},
         }
 

@@ -75,7 +75,7 @@ class ParamPromptInvocation(BaseInvocation):
     prompt: str = Field(default="", description="The prompt value")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"tags": ["param", "prompt"], "title": "Prompt"},
         }
 
@@ -2,7 +2,7 @@ from os.path import exists
 from typing import Literal, Optional
 
 import numpy as np
-from pydantic import Field, validator
+from pydantic import Field, field_validator
 
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator

@@ -18,7 +18,7 @@ class PromptOutput(BaseInvocationOutput):
     # fmt: on
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "required": [
                 "type",
                 "prompt",

@@ -37,7 +37,7 @@ class PromptCollectionOutput(BaseInvocationOutput):
     # fmt: on
 
     class Config:
-        schema_extra = {"required": ["type", "prompt_collection", "count"]}
+        json_schema_extra = {"required": ["type", "prompt_collection", "count"]}
 
 
 class DynamicPromptInvocation(BaseInvocation):

@@ -49,7 +49,7 @@ class DynamicPromptInvocation(BaseInvocation):
     combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
         }
 

@@ -79,11 +79,11 @@ class PromptsFromFileInvocation(BaseInvocation):
     # fmt: on
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
         }
 
-    @validator("file_path")
+    @field_validator("file_path")
     def file_path_exists(cls, v):
         if not exists(v):
             raise ValueError(FileNotFoundError)
@@ -3,7 +3,7 @@ import inspect
 from tqdm import tqdm
 from typing import List, Literal, Optional, Union
 
-from pydantic import Field, validator
+from pydantic import Field, field_validator
 
 from ...backend.model_management import ModelType, SubModelType, ModelPatcher
 from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback

@@ -49,7 +49,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "SDXL Model Loader",
                 "tags": ["model", "loader", "sdxl"],

@@ -139,7 +139,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "SDXL Refiner Model Loader",
                 "tags": ["model", "loader", "sdxl_refiner"],

@@ -224,7 +224,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on
 
-    @validator("cfg_scale")
+    @field_validator("cfg_scale")
     def ge_one(cls, v):
         """validate that all cfg_scale values are >= 1"""
         if isinstance(v, list):

@@ -238,7 +238,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "SDXL Text To Latents",
                 "tags": ["latents"],

@@ -260,7 +260,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
     ) -> None:
         stable_diffusion_xl_step_callback(
             context=context,
-            node=self.dict(),
+            node=self.model_dump(),
             source_node_id=source_node_id,
             sample=sample,
             step=step,

@@ -303,7 +303,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                 del lora_info
             return
 
-        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump(), context=context)
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
         with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:

@@ -481,7 +481,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on
 
-    @validator("cfg_scale")
+    @field_validator("cfg_scale")
     def ge_one(cls, v):
         """validate that all cfg_scale values are >= 1"""
         if isinstance(v, list):

@@ -495,7 +495,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
 
     # Schema customisation
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "title": "SDXL Latents to Latents",
                 "tags": ["latents"],

@@ -517,7 +517,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
     ) -> None:
         stable_diffusion_xl_step_callback(
             context=context,
-            node=self.dict(),
+            node=self.model_dump(),
             source_node_id=source_node_id,
             sample=sample,
             step=step,

@@ -549,7 +549,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
         )
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(),
+            **self.unet.unet.model_dump(),
             context=context,
         )
 
@@ -32,7 +32,7 @@ class ESRGANInvocation(BaseInvocation):
     model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
         }
 
@@ -15,7 +15,7 @@ class ImageField(BaseModel):
     image_name: Optional[str] = Field(default=None, description="The name of the image")
 
     class Config:
-        schema_extra = {"required": ["image_name"]}
+        json_schema_extra = {"required": ["image_name"]}
 
 
 class ColorField(BaseModel):

@@ -40,7 +40,7 @@ class PILInvocationConfig(BaseModel):
     """Helper class to provide all PIL invocations with additional config"""
 
     class Config(InvocationConfig):
-        schema_extra = {
+        json_schema_extra = {
             "ui": {
                 "tags": ["PIL", "image"],
             },

@@ -58,7 +58,7 @@ class ImageOutput(BaseInvocationOutput):
     # fmt: on
 
     class Config:
-        schema_extra = {"required": ["type", "image", "width", "height"]}
+        json_schema_extra = {"required": ["type", "image", "width", "height"]}
 
 
 class MaskOutput(BaseInvocationOutput):

@@ -72,7 +72,7 @@ class MaskOutput(BaseInvocationOutput):
     # fmt: on
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "required": [
                 "type",
                 "mask",
@@ -166,7 +166,8 @@ import sys
 from argparse import ArgumentParser
 from omegaconf import OmegaConf, DictConfig, ListConfig
 from pathlib import Path
-from pydantic import BaseSettings, Field, parse_obj_as
+from pydantic import Field, parse_obj_as
+from pydantic_settings import BaseSettings
 from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
 
 INIT_FILE = Path("invokeai.yaml")

@@ -187,7 +188,7 @@ class InvokeAISettings(BaseSettings):
     def parse_args(self, argv: list = sys.argv[1:]):
         parser = self.get_parser()
         opt = parser.parse_args(argv)
-        for name in self.__fields__:
+        for name in self.model_fields:
             if name not in self._excluded():
                 value = getattr(opt, name)
                 if isinstance(value, ListConfig):

@@ -204,7 +205,7 @@ class InvokeAISettings(BaseSettings):
         cls = self.__class__
         type = get_args(get_type_hints(cls)["type"])[0]
         field_dict = dict({type: dict()})
-        for name, field in self.__fields__.items():
+        for name, field in self.model_fields.items():
             if name in cls._excluded_from_yaml():
                 continue
             category = field.field_info.extra.get("category") or "Uncategorized"

@@ -238,7 +239,7 @@ class InvokeAISettings(BaseSettings):
         for key, value in os.environ.items():
             upcase_environ[key.upper()] = value
 
-        fields = cls.__fields__
+        fields = cls.model_fields
         cls.argparse_groups = {}
 
         for name, field in fields.items():
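Two things are happening in this file. First, `BaseSettings` moved out of pydantic core into the separate `pydantic-settings` package, hence the split import. Second, note that `field.field_info.extra.get("category")` in the middle hunk is still v1 API (v2's `FieldInfo` has no `.field_info`; extras live under `json_schema_extra`), so this hunk is an incomplete port. A minimal sketch of the new settings import, assuming `pip install pydantic-settings`:

```python
import os
from pydantic_settings import BaseSettings  # v1: from pydantic import BaseSettings

class AppSettings(BaseSettings):  # hypothetical, far smaller than InvokeAIAppConfig
    host: str = "127.0.0.1"
    port: int = 9090

os.environ["PORT"] = "8080"  # fields are read from the environment by default
print(AppSettings().port)    # -> 8080
```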
@@ -44,7 +44,7 @@ class EventServiceBase:
                 graph_execution_state_id=graph_execution_state_id,
                 node=node,
                 source_node_id=source_node_id,
-                progress_image=progress_image.dict() if progress_image is not None else None,
+                progress_image=progress_image.model_dump() if progress_image is not None else None,
                 step=step,
                 total_steps=total_steps,
             ),
@@ -15,7 +15,7 @@ from typing import (
 )
 
 import networkx as nx
-from pydantic import BaseModel, root_validator, validator
+from pydantic import BaseModel, model_validator, field_validator
 from pydantic.fields import Field
 
 from ..invocations import *

@@ -156,7 +156,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
     type: Literal["graph_output"] = "graph_output"
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "required": [
                 "type",
                 "image",

@@ -186,7 +186,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
     item: Any = Field(description="The item being iterated over")
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "required": [
                 "type",
                 "item",

@@ -214,7 +214,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
     collection: list[Any] = Field(description="The collection of input items")
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "required": [
                 "type",
                 "collection",

@@ -755,7 +755,7 @@ class GraphExecutionState(BaseModel):
     )
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "required": [
                 "id",
                 "graph",

@@ -1110,13 +1110,13 @@ class LibraryGraph(BaseModel):
         description="The outputs exposed by this graph", default_factory=list
     )
 
-    @validator("exposed_inputs", "exposed_outputs")
+    @field_validator("exposed_inputs", "exposed_outputs")
     def validate_exposed_aliases(cls, v):
         if len(v) != len(set(i.alias for i in v)):
             raise ValueError("Duplicate exposed alias")
         return v
 
-    @root_validator
+    @model_validator
     def validate_exposed_nodes(cls, values):
         graph = values["graph"]
 
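The bare `@model_validator` in the last hunk is another incomplete port: v2 requires an explicit `mode` argument, and in `mode="after"` the method receives the model instance rather than a `values` dict. A hedged sketch of what the fully ported validator would look like:

```python
from pydantic import BaseModel, model_validator

class LibraryGraphSketch(BaseModel):  # hypothetical, trimmed to the relevant shape
    graph: dict = {}
    exposed_inputs: list = []

    @model_validator(mode="after")  # v1 spelling: @root_validator
    def validate_exposed_nodes(self):
        graph = self.graph  # v1 received values["graph"] instead
        return self

print(LibraryGraphSketch(graph={"nodes": {}}).graph)
```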
@@ -2,12 +2,11 @@ from abc import ABC, abstractmethod
 from typing import Callable, Generic, Optional, TypeVar
 
 from pydantic import BaseModel, Field
-from pydantic.generics import GenericModel
 
 T = TypeVar("T", bound=BaseModel)
 
 
-class PaginatedResults(GenericModel, Generic[T]):
+class PaginatedResults(BaseModel, Generic[T]):
     """Paginated results"""
 
     # fmt: off
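Pydantic v2 drops `pydantic.generics.GenericModel` entirely; generic models now subclass `BaseModel` together with `typing.Generic`, exactly as the hunk shows. A minimal sketch:

```python
from typing import Generic, List, TypeVar
from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)

class Item(BaseModel):
    name: str

class PaginatedResults(BaseModel, Generic[T]):  # v1: (GenericModel, Generic[T])
    items: List[T]
    page: int

page = PaginatedResults[Item](items=[Item(name="first")], page=0)
print(page.items[0].name)  # -> "first"
```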
@@ -238,7 +238,7 @@ class ModelManagerServiceBase(ABC):
     def merge_models(
         self,
         model_names: List[str] = Field(
-            default=None, min_items=2, max_items=3, description="List of model names to merge"
+            default=None, min_length=2, max_length=3, description="List of model names to merge"
         ),
         base_model: Union[BaseModelType, str] = Field(
             default=None, description="Base model shared by all models to be merged"

@@ -568,7 +568,7 @@ class ModelManagerService(ModelManagerServiceBase):
     def merge_models(
         self,
         model_names: List[str] = Field(
-            default=None, min_items=2, max_items=3, description="List of model names to merge"
+            default=None, min_length=2, max_length=3, description="List of model names to merge"
        ),
         base_model: Union[BaseModelType, str] = Field(
             default=None, description="Base model shared by all models to be merged"
@@ -89,7 +89,7 @@ def image_record_to_dto(
 ) -> ImageDTO:
     """Converts an image record to an image DTO."""
     return ImageDTO(
-        **image_record.dict(),
+        **image_record.model_dump(),
         image_url=image_url,
         thumbnail_url=thumbnail_url,
         board_id=board_id,
@@ -80,7 +80,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
@@ -107,9 +107,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     # Send complete event
                     self.__invoker.services.events.emit_invocation_complete(
                         graph_execution_state_id=graph_execution_state.id,
-                        node=invocation.dict(),
+                        node=invocation.model_dump(),
                         source_node_id=source_node_id,
-                        result=outputs.dict(),
+                        result=outputs.model_dump(),
                     )
                     statistics.log_stats()

@@ -134,7 +134,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     # Send error event
                     self.__invoker.services.events.emit_invocation_error(
                         graph_execution_state_id=graph_execution_state.id,
-                        node=invocation.dict(),
+                        node=invocation.model_dump(),
                         source_node_id=source_node_id,
                         error_type=e.__class__.__name__,
                         error=error,

@@ -155,7 +155,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                 self.__invoker.services.events.emit_invocation_error(
                     graph_execution_state_id=graph_execution_state.id,
-                    node=invocation.dict(),
+                    node=invocation.model_dump(),
                     source_node_id=source_node_id,
                     error_type=e.__class__.__name__,
                     error=traceback.format_exc(),

@@ -716,7 +716,7 @@ def migrate_init_file(legacy_format: Path):
     old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
     new = InvokeAIAppConfig.get_config()

-    fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
+    fields = [x for x, y in InvokeAIAppConfig.model_fields.items() if y.field_info.extra.get("category") != "DEPRECATED"]
     for attr in fields:
         if hasattr(old, attr):
             try:

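Aside: `__fields__` to `model_fields` is the right rename, but the surviving `y.field_info.extra` is v1 API; v2's `FieldInfo` has no `.field_info` attribute, and extra keyword data now lands in `json_schema_extra`. If that is where the category ends up, the lookup would read roughly like this (`ToyConfig` is hypothetical; the real InvokeAIAppConfig fields differ):

from pydantic import BaseModel, Field

class ToyConfig(BaseModel):
    # Hypothetical config storing a category as a Field extra.
    outdir: str = Field("outputs", json_schema_extra={"category": "Paths"})
    always_use_cpu: bool = Field(False, json_schema_extra={"category": "DEPRECATED"})

fields = [
    name
    for name, f in ToyConfig.model_fields.items()
    if (f.json_schema_extra or {}).get("category") != "DEPRECATED"
]
print(fields)  # ['outdir']
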
@@ -897,7 +897,7 @@ class ModelManager(object):
         Write current configuration out to the indicated file.
         """
         data_to_save = dict()
-        data_to_save["__metadata__"] = self.config_meta.dict()
+        data_to_save["__metadata__"] = self.config_meta.model_dump()

         for model_key, model_config in self.models.items():
             model_name, base_model, model_type = self.parse_key(model_key)

@@ -1,18 +1,21 @@
+import bisect
 import os
-import torch
 from enum import Enum
-from typing import Optional, Dict, Union, Literal, Any
 from pathlib import Path
+from typing import Dict, Optional, Union
+
+import torch
 from safetensors.torch import load_file

 from .base import (
+    BaseModelType,
+    InvalidModelException,
     ModelBase,
     ModelConfigBase,
-    BaseModelType,
+    ModelNotFoundException,
     ModelType,
     SubModelType,
     classproperty,
-    InvalidModelException,
-    ModelNotFoundException,
 )

@@ -482,30 +485,61 @@ class LoRAModelRaw: # (torch.nn.Module):
         return model_size

     @classmethod
-    def _convert_sdxl_compvis_keys(cls, state_dict):
+    def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
+        """Convert the keys of an SDXL LoRA state_dict to diffusers format.
+
+        The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
+        diffusers format, then this function will have no effect.
+
+        This function is adapted from:
+        https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
+
+        Args:
+            state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
+
+        Raises:
+            ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
+
+        Returns:
+            Dict[str, Tensor]: The diffusers-format state_dict.
+        """
+        converted_count = 0  # The number of Stability AI keys converted to diffusers format.
+        not_converted_count = 0  # The number of keys that were not converted.
+
+        # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
+        # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
+        # `input_blocks_4_1_proj_in`.
+        stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
+        stability_unet_keys.sort()
+
         new_state_dict = dict()
         for full_key, value in state_dict.items():
-            if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
-                continue  # clip same
-
-            if not full_key.startswith("lora_unet_"):
-                raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
-            src_key = full_key.replace("lora_unet_", "")
-            try:
-                dst_key = None
-                while "_" in src_key:
-                    if src_key in SDXL_UNET_COMPVIS_MAP:
-                        dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
-                        break
-                    src_key = "_".join(src_key.split("_")[:-1])
-
-                if dst_key is None:
-                    raise Exception(f"Unknown sdxl lora key - {full_key}")
-                new_key = full_key.replace(src_key, dst_key)
-            except:
-                print(SDXL_UNET_COMPVIS_MAP)
-                raise
-            new_state_dict[new_key] = value
+            if full_key.startswith("lora_unet_"):
+                search_key = full_key.replace("lora_unet_", "")
+                # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
+                position = bisect.bisect_right(stability_unet_keys, search_key)
+                map_key = stability_unet_keys[position - 1]
+                # Now, check if the map_key *actually* matches the search_key.
+                if search_key.startswith(map_key):
+                    new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
+                    new_state_dict[new_key] = value
+                    converted_count += 1
+                else:
+                    new_state_dict[full_key] = value
+                    not_converted_count += 1
+            elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
+                # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
+                new_state_dict[full_key] = value
+                continue
+            else:
+                raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
+
+        if converted_count > 0 and not_converted_count > 0:
+            raise ValueError(
+                f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
+                f" not_converted={not_converted_count}"
+            )
+
         return new_state_dict

     @classmethod

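Aside: the bisect trick above leans on lexicographic order. In a sorted key list, the only candidate prefix of `search_key` sits immediately to the left of `bisect_right`'s insertion point, so one binary search plus one `startswith` check replaces the old loop that repeatedly stripped `_`-suffixes. A self-contained sketch (the mapping entry is a toy subset, not the real generated table):

import bisect

# Toy subset of SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP.
mapping = {"input_blocks_4_1": "down_blocks_1_attentions_0"}
stability_keys = sorted(mapping)

search_key = "input_blocks_4_1_proj_in"
position = bisect.bisect_right(stability_keys, search_key)
candidate = stability_keys[position - 1]  # the only possible prefix match
if search_key.startswith(candidate):
    print(search_key.replace(candidate, mapping[candidate]))
    # down_blocks_1_attentions_0_proj_in
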
@@ -537,7 +571,7 @@ class LoRAModelRaw: # (torch.nn.Module):
         state_dict = cls._group_state(state_dict)

         if base_model == BaseModelType.StableDiffusionXL:
-            state_dict = cls._convert_sdxl_compvis_keys(state_dict)
+            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

         for layer_key, values in state_dict.items():
             # lora and locon

@@ -588,6 +622,7 @@ class LoRAModelRaw: # (torch.nn.Module):
 # code from
 # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
 def make_sdxl_unet_conversion_map():
+    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
     unet_conversion_map_layer = []

     for i in range(3):  # num_blocks is 3 in sdxl

@@ -671,7 +706,6 @@ def make_sdxl_unet_conversion_map():
     return unet_conversion_map


-SDXL_UNET_COMPVIS_MAP = {
-    f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
-    for sd, hf in make_sdxl_unet_conversion_map()
+SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
+    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
 }

@@ -1,6 +1,5 @@
 import os
 import json
-import invokeai.backend.util.logging as logger
 from enum import Enum
 from pydantic import Field
 from typing import Literal, Optional

@@ -12,6 +11,7 @@ from .base import (
     DiffusersModel,
     read_checkpoint_meta,
     classproperty,
+    InvalidModelException,
 )
 from omegaconf import OmegaConf

@@ -65,7 +65,7 @@ class StableDiffusionXLModel(DiffusersModel):
                 in_channels = unet_config["in_channels"]

             else:
-                raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
+                raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")

         else:
             raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")

@@ -20,7 +20,7 @@ export const addStagingAreaImageSavedListener = () => {
       // we may need to add it to the autoadd board
       const { autoAddBoardId } = getState().gallery;

-      if (autoAddBoardId) {
+      if (autoAddBoardId && autoAddBoardId !== 'none') {
         await dispatch(
           imagesApi.endpoints.addImageToBoard.initiate({
             imageDTO: newImageDTO,

@@ -365,12 +365,19 @@ export const systemSlice = createSlice({
       state.statusTranslationKey = 'common.statusConnected';
       state.progressImage = null;

+      let errorDescription = undefined;
+
+      if (action.payload?.status === 422) {
+        errorDescription = 'Validation Error';
+      } else if (action.payload?.error) {
+        errorDescription = action.payload?.error as string;
+      }
+
       state.toastQueue.push(
         makeToast({
           title: t('toast.serverError'),
           status: 'error',
-          description:
-            action.payload?.status === 422 ? 'Validation Error' : undefined,
+          description: errorDescription,
         })
       );
     });

@@ -60,6 +60,9 @@ type InvokedSessionThunkConfig = {
 const isErrorWithStatus = (error: unknown): error is { status: number } =>
   isObject(error) && 'status' in error;

+const isErrorWithDetail = (error: unknown): error is { detail: string } =>
+  isObject(error) && 'detail' in error;
+
 /**
  * `SessionsService.invokeSession()` thunk
  */

@@ -85,7 +88,15 @@ export const sessionInvoked = createAsyncThunk<
         error: (error as any).body.detail,
       });
     }
-    return rejectWithValue({ arg, status: response.status, error });
+    if (isErrorWithDetail(error) && response.status === 403) {
+      return rejectWithValue({
+        arg,
+        status: response.status,
+        error: error.detail
+      });
+    }
+    if (error)
+      return rejectWithValue({ arg, status: response.status, error });
   }
 });

@@ -47,7 +47,7 @@ dependencies = [
     "einops",
     "eventlet",
     "facexlib",
-    "fastapi==0.88.0",
+    "fastapi==0.101.0",
     "fastapi-events==0.8.0",
     "fastapi-socketio==0.0.10",
     "flask==2.1.3",

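Aside: the FastAPI bump is not incidental. FastAPI only gained Pydantic v2 support in 0.100.0, so moving past 0.88.0 is what lets the `pydantic==2.1.1` pin in the next hunk resolve alongside it.
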
@@ -64,7 +64,8 @@ dependencies = [
     "onnx",
     "onnxruntime",
     "opencv-python",
-    "pydantic==1.*",
+    "pydantic==2.1.1",
+    "pydantic-settings",
     "picklescan",
     "pillow",
     "prompt-toolkit",

@@ -584,7 +584,7 @@ def test_graph_can_serialize():
     g.add_edge(e)

     # Not throwing on this line is sufficient
-    json = g.json()
+    json = g.model_dump_json()


 def test_graph_can_deserialize():

@@ -596,8 +596,8 @@ def test_graph_can_deserialize():
     e = create_edge(n1.id, "image", n2.id, "image")
     g.add_edge(e)

-    json = g.json()
-    g2 = Graph.parse_raw(json)
+    json = g.model_dump_json()
+    g2 = Graph.model_validate_json(json)

     assert g2 is not None
     assert g2.nodes["1"] is not None

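Aside: the serialization pair maps v1 `json()`/`parse_raw()` to v2 `model_dump_json()`/`model_validate_json()`. A round-trip sketch with a toy graph (`Node` and `ToyGraph` are stand-ins for the real classes):

from pydantic import BaseModel

class Node(BaseModel):
    id: str

class ToyGraph(BaseModel):
    # Just enough structure for a round trip.
    nodes: dict[str, Node] = {}

g = ToyGraph(nodes={"1": Node(id="1")})
payload = g.model_dump_json()               # v1: g.json()
g2 = ToyGraph.model_validate_json(payload)  # v1: ToyGraph.parse_raw(payload)
assert g2.nodes["1"].id == "1"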