Compare commits


6 Commits

Author SHA1 Message Date
psychedelicious
3707c3b034 fix(ui): do not bake opacity when rasterizing layer adjustments 2025-09-22 11:43:08 +10:00
Mary Hipp
5885db4ab5 ruff 2025-09-19 11:07:36 -04:00
Mary Hipp
36ed9b750d restore list_queue_items method 2025-09-19 11:07:36 -04:00
psychedelicious
3cec06f86e chore(ui): typegen 2025-09-19 22:13:12 +10:00
psychedelicious
28b5f7a1c5 feat(nodes): better deprecation handling for ui_type
- Move migration of model-specific ui_types into BaseInvocation. This
gives us access to the node and field names, so the warnings are more
useful to the end user.
- Ensure we serialize the fields' json_schema_extra with enum values.
This wasn't a problem until now, when it started interfering with migrating
ui_type cleanly. It's a transparent change.
- Improve warnings when validating fields (which includes the ui_type
migration logic)
2025-09-19 22:13:12 +10:00
psychedelicious
22cbb23ae0 fix(ui): ref images for flux kontext & api models not parsed correctly 2025-09-19 21:40:17 +10:00
25 changed files with 281 additions and 1001 deletions

View File

@@ -1,128 +0,0 @@
# Qwen-Image Implementation for InvokeAI
## Overview
This implementation adds support for the Qwen-Image family of models to InvokeAI. Qwen-Image is a 20B parameter Multimodal Diffusion Transformer (MMDiT) model that excels at complex text rendering and precise image editing.
## Model Setup
### 1. Download the Qwen-Image Model
```bash
# Option 1: Using git (recommended for large models)
git clone https://huggingface.co/Qwen/Qwen-Image invokeai/models/qwen-image/Qwen-Image
# Option 2: Using huggingface-cli
huggingface-cli download Qwen/Qwen-Image --local-dir invokeai/models/qwen-image/Qwen-Image
```
### 2. Download Qwen2.5-VL Text Encoder
Qwen-Image uses Qwen2.5-VL-7B as its text encoder (not CLIP):
```bash
git clone https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct invokeai/models/qwen-image/Qwen2.5-VL-7B-Instruct
```
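If you prefer to script the downloads, the same repositories can be fetched with the `huggingface_hub` Python API. This is a minimal sketch; the `local_dir` paths simply mirror the directory layout used above.
```python
from huggingface_hub import snapshot_download

# Main Qwen-Image model (transformer, VAE, scheduler, configs)
snapshot_download(
    repo_id="Qwen/Qwen-Image",
    local_dir="invokeai/models/qwen-image/Qwen-Image",
)

# Qwen2.5-VL text encoder used in place of CLIP
snapshot_download(
    repo_id="Qwen/Qwen2.5-VL-7B-Instruct",
    local_dir="invokeai/models/qwen-image/Qwen2.5-VL-7B-Instruct",
)
```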
## Model Architecture
### Components
1. **Transformer**: QwenImageTransformer2DModel (MMDiT architecture, 20B parameters)
2. **Text Encoder**: Qwen2.5-VL-7B-Instruct (7B parameter vision-language model)
3. **VAE**: AutoencoderKLQwenImage (bundled with main model in `/vae` subdirectory)
4. **Scheduler**: FlowMatchEulerDiscreteScheduler
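The component layout can be verified outside InvokeAI by loading the pipeline directly with diffusers. This is a sketch that assumes diffusers 0.35.0+ and a local copy of the model at the path used in the setup steps; the expected class names come from the component list above.
```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "invokeai/models/qwen-image/Qwen-Image",
    torch_dtype=torch.bfloat16,
)

print(type(pipe.transformer).__name__)  # QwenImageTransformer2DModel
print(type(pipe.text_encoder).__name__)  # Qwen2.5-VL text encoder
print(type(pipe.vae).__name__)  # AutoencoderKLQwenImage
print(type(pipe.scheduler).__name__)  # FlowMatchEulerDiscreteScheduler
```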
### Key Features
- **Complex Text Rendering**: Superior ability to render text accurately in images
- **Bundled VAE**: The model includes its own custom VAE (no separate download needed)
- **Large Text Encoder**: Uses a 7B parameter VLM instead of traditional CLIP
- **Optional VAE Override**: Can use custom VAE models if desired
## Components Implemented
### Backend Components
1. **Model Taxonomy** (`taxonomy.py`): Added `QwenImage = "qwen-image"` base model type
2. **Model Configuration** (`config.py`): Uses MainDiffusersConfig for Qwen-Image models
3. **Model Loader** (`qwen_image.py`): Loads models and submodels via diffusers
4. **Model Loader Node** (`qwen_image_model_loader.py`): Loads transformer, text encoder, and VAE
5. **Text Encoder Node** (`qwen_image_text_encoder.py`): Encodes prompts using Qwen2.5-VL
6. **Denoising Node** (`qwen_image_denoise.py`): Generates images using QwenImagePipeline
### Frontend Components
1. **UI Types**: Added QwenImageMainModel, Qwen2_5VLModel field types
2. **Field Components**: Created input components for model selection
3. **Type Guards**: Added model detection and filtering functions
4. **Hooks**: Model loading hooks for UI dropdowns
## Dependencies Updated
- Updated `pyproject.toml` to use `diffusers[torch]==0.35.0` (from 0.33.0) to support Qwen-Image models
## Usage in InvokeAI
### Node Graph Setup
1. Add a **"Main Model - Qwen-Image"** loader node
2. Select your Qwen-Image model from the dropdown
3. Select the Qwen2.5-VL model for text encoding
4. VAE field is optional (uses bundled VAE if left empty)
5. Connect to **Qwen-Image Text Encoder** node
6. Connect to **Qwen-Image Denoise** node
7. Add **VAE Decode** node to convert latents to images
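For reference, the wiring above can be expressed as a graph definition using the node types added in this implementation (`qwen_image_model_loader`, `qwen_image_text_encoder`, `qwen_image_denoise`). This is an illustrative sketch only: node IDs are arbitrary, model identifiers are omitted, and the nodes/edges shape is assumed from other InvokeAI graphs rather than a tested workflow.
```python
# Edges cover only the Qwen-specific nodes; model fields would be filled in via the UI dropdowns.
qwen_graph = {
    "nodes": {
        "model_loader": {"id": "model_loader", "type": "qwen_image_model_loader"},
        "text_encoder": {
            "id": "text_encoder",
            "type": "qwen_image_text_encoder",
            "prompt": "A storefront with 'Qwen-Image' painted on the window",
        },
        "denoise": {
            "id": "denoise",
            "type": "qwen_image_denoise",
            "width": 1024,
            "height": 1024,
            "num_inference_steps": 50,
            "guidance_scale": 7.5,
            "seed": 0,
        },
    },
    "edges": [
        {"source": {"node_id": "model_loader", "field": "qwen2_5_vl"},
         "destination": {"node_id": "text_encoder", "field": "qwen2_5_vl"}},
        {"source": {"node_id": "model_loader", "field": "transformer"},
         "destination": {"node_id": "denoise", "field": "transformer"}},
        {"source": {"node_id": "model_loader", "field": "vae"},
         "destination": {"node_id": "denoise", "field": "vae"}},
        {"source": {"node_id": "text_encoder", "field": "conditioning"},
         "destination": {"node_id": "denoise", "field": "positive_conditioning"}},
    ],
}
```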
### Model Selection
- **Main Model**: Select from models with base type "qwen-image"
- **Text Encoder**: Select Qwen2.5-VL-7B-Instruct
- **VAE**: Optional - leave empty to use bundled VAE, or select a custom VAE
## Troubleshooting
### VAE Not Showing Up
The Qwen-Image VAE is bundled with the main model. You don't need to download or select a separate VAE - just leave the VAE field empty to use the bundled one.
### Memory Issues
Qwen-Image is a large model (20B parameters), and Qwen2.5-VL adds another 7B. Together they require significant resources:
**Memory Requirements:**
- **Minimum**: 24GB VRAM (with optimizations)
- **Recommended**: 32GB+ VRAM for smooth operation
- **System RAM**: 32GB+ recommended
**Optimization Tips** (a combined sketch follows this list):
1. **Use bfloat16 precision**: Reduces memory by ~50%
```python
torch_dtype=torch.bfloat16
```
2. **Enable CPU offloading**: Move unused models to system RAM
- InvokeAI's model manager handles this automatically when configured
3. **Use quantized versions**:
- Try `diffusers/qwen-image-nf4` for 4-bit quantization
- Reduces memory usage by ~75% with minimal quality loss
4. **Adjust cache settings**: In InvokeAI settings:
- Reduce `ram_cache_size` if running out of system RAM
- Reduce `vram_cache_size` if getting CUDA OOM errors
5. **Load models sequentially**: Don't load all models at once
- The model manager now properly calculates sizes for better memory management
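For a quick standalone sanity check outside InvokeAI, tips 1-3 can be combined with plain diffusers as below. The `diffusers/qwen-image-nf4` repository name comes from tip 3 and is an assumption about availability; `enable_model_cpu_offload()` is a standard diffusers call that keeps only the active component on the GPU.
```python
import torch
from diffusers import DiffusionPipeline

# Tips 1 and 3: quantized checkpoint loaded in bfloat16
pipe = DiffusionPipeline.from_pretrained(
    "diffusers/qwen-image-nf4",
    torch_dtype=torch.bfloat16,
)

# Tip 2: offload idle components to system RAM
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="A neon sign that reads 'InvokeAI'",
    num_inference_steps=30,
).images[0]
image.save("qwen_test.png")
```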
### Model Not Loading
- Ensure the model is in the correct directory structure
- Check that both Qwen-Image and Qwen2.5-VL models are downloaded
- Verify diffusers version is 0.35.0 or higher
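To confirm the last point, a quick check from the same Python environment that runs InvokeAI looks like this (the 0.35.0 floor matches the dependency bump above):
```python
import diffusers
from packaging.version import Version

# Qwen-Image pipeline classes require diffusers 0.35.0 or newer
assert Version(diffusers.__version__) >= Version("0.35.0"), diffusers.__version__
```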
## Future Enhancements
1. **Image Editing**: Support for Qwen-Image-Edit variant
2. **LoRA Support**: Fine-tuning capabilities
3. **Optimizations**: Quantization and speed improvements (Qwen-Image-Lightning)
4. **Advanced Features**: Image-to-image, inpainting, controlnet support
## Files Modified/Created
- `/invokeai/backend/model_manager/taxonomy.py` (modified)
- `/invokeai/backend/model_manager/config.py` (modified)
- `/invokeai/backend/model_manager/load/model_loaders/qwen_image.py` (created)
- `/invokeai/app/invocations/fields.py` (modified)
- `/invokeai/app/invocations/primitives.py` (modified)
- `/invokeai/app/invocations/qwen_image_text_encoder.py` (created)
- `/invokeai/app/invocations/qwen_image_denoise.py` (created)
- `/pyproject.toml` (modified)

View File

@@ -36,6 +36,9 @@ from pydantic_core import PydanticUndefined
from invokeai.app.invocations.fields import (
FieldKind,
Input,
InputFieldJSONSchemaExtra,
UIType,
migrate_model_ui_type,
)
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -256,7 +259,9 @@ class BaseInvocation(ABC, BaseModel):
is_intermediate: bool = Field(
default=False,
description="Whether or not this is an intermediate invocation.",
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
json_schema_extra=InputFieldJSONSchemaExtra(
input=Input.Direct, field_kind=FieldKind.NodeAttribute, ui_type=UIType._IsIntermediate
).model_dump(exclude_none=True),
)
use_cache: bool = Field(
default=True,
@@ -445,6 +450,15 @@ with warnings.catch_warnings():
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
def is_enum_member(value: Any, enum_class: type[Enum]) -> bool:
"""Checks if a value is a member of an enum class."""
try:
enum_class(value)
return True
except ValueError:
return False
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
"""
Validates the fields of an invocation or invocation output:
@@ -456,51 +470,99 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
"""
for name, field in model_fields.items():
if name in RESERVED_PYDANTIC_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved by pydantic)")
if not field.annotation:
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
raise InvalidFieldError(f"{model_type}.{name}: Invalid field type (missing annotation)")
if not isinstance(field.json_schema_extra, dict):
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
)
raise InvalidFieldError(f"{model_type}.{name}: Invalid field definition (missing json_schema_extra dict)")
field_kind = field.json_schema_extra.get("field_kind", None)
# must have a field_kind
if not isinstance(field_kind, FieldKind):
if not is_enum_member(field_kind, FieldKind):
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
f"{model_type}.{name}: Invalid field definition for (maybe it's not an InputField or OutputField?)"
)
if field_kind is FieldKind.Input and (
if field_kind == FieldKind.Input.value and (
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
):
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved input field name)")
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
if field_kind == FieldKind.Output.value and name in RESERVED_OUTPUT_FIELD_NAMES:
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved output field name)")
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
)
if field_kind == FieldKind.Internal.value and name not in RESERVED_INPUT_FIELD_NAMES:
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (internal field without reserved name)")
# node attribute fields *must* be in the reserved list
if (
field_kind is FieldKind.NodeAttribute
field_kind == FieldKind.NodeAttribute.value
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
and name not in RESERVED_OUTPUT_FIELD_NAMES
):
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
f"{model_type}.{name}: Invalid field name (node attribute field without reserved name)"
)
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
field.json_schema_extra.pop("ui_type")
ui_model_base = field.json_schema_extra.get("ui_model_base", None)
ui_model_type = field.json_schema_extra.get("ui_model_type", None)
ui_model_variant = field.json_schema_extra.get("ui_model_variant", None)
ui_model_format = field.json_schema_extra.get("ui_model_format", None)
if ui_type is not None:
# There are 3 cases where we may need to take action:
#
# 1. The ui_type is a migratable, deprecated value. For example, ui_type=UIType.MainModel value is
# deprecated and should be migrated to:
# - ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
# - ui_model_type=[ModelType.Main]
#
# 2. ui_type was set in conjunction with any of the new ui_model_[base|type|variant|format] fields, which
# is not allowed (they are mutually exclusive). In this case, we ignore ui_type and log a warning.
#
# 3. ui_type is a deprecated value that is not migratable. For example, ui_type=UIType.Image is deprecated;
# Image fields are now automatically detected based on the field's type annotation. In this case, we
# ignore ui_type and log a warning.
#
# The cases must be checked in this order to ensure proper handling.
# Easier to work with as an enum
ui_type = UIType(ui_type)
# The enum member values are not always the same as their names - we want to log the name so the user can
# easily review their code and see where the deprecated enum member is used.
human_readable_name = f"UIType.{ui_type.name}"
# Case 1: migratable deprecated value
did_migrate = migrate_model_ui_type(ui_type, field.json_schema_extra)
if did_migrate:
logger.warning(
f'{model_type}.{name}: Migrated deprecated "ui_type" "{human_readable_name}" to new ui_model_[base|type|variant|format] fields'
)
field.json_schema_extra.pop("ui_type")
# Case 2: mutually exclusive with new fields
elif (
ui_model_base is not None
or ui_model_type is not None
or ui_model_variant is not None
or ui_model_format is not None
):
logger.warning(
f'{model_type}.{name}: "ui_type" is mutually exclusive with "ui_model_[base|type|format|variant]", ignoring "ui_type"'
)
field.json_schema_extra.pop("ui_type")
# Case 3: deprecated value that is not migratable
elif ui_type.startswith("DEPRECATED_"):
logger.warning(f'{model_type}.{name}: Deprecated "ui_type" "{human_readable_name}", ignoring')
field.json_schema_extra.pop("ui_type")
return None

View File

@@ -54,6 +54,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Internal Field Types
_Collection = "CollectionField"
_CollectionItem = "CollectionItemField"
_IsIntermediate = "IsIntermediate"
# endregion
# region DEPRECATED
@@ -91,7 +92,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
CollectionItem = "DEPRECATED_CollectionItem"
Enum = "DEPRECATED_Enum"
WorkflowField = "DEPRECATED_WorkflowField"
IsIntermediate = "DEPRECATED_IsIntermediate"
BoardField = "DEPRECATED_BoardField"
MetadataItem = "DEPRECATED_MetadataItem"
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
@@ -327,12 +327,6 @@ class CogView4ConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
class QwenImageConditioningField(BaseModel):
"""A conditioning tensor primitive value for Qwen-Image"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
@@ -429,6 +423,7 @@ class InputFieldJSONSchemaExtra(BaseModel):
model_config = ConfigDict(
validate_assignment=True,
json_schema_serialization_defaults_required=True,
use_enum_values=True,
)
@@ -488,9 +483,114 @@ class OutputFieldJSONSchemaExtra(BaseModel):
model_config = ConfigDict(
validate_assignment=True,
json_schema_serialization_defaults_required=True,
use_enum_values=True,
)
def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, Any]) -> bool:
"""Migrate deprecated model-specifier ui_type values to new-style ui_model_[base|type|variant|format] in json_schema_extra."""
if not isinstance(ui_type, UIType):
ui_type = UIType(ui_type)
ui_model_type: list[ModelType] | None = None
ui_model_base: list[BaseModelType] | None = None
ui_model_format: list[ModelFormat] | None = None
ui_model_variant: list[ClipVariantType | ModelVariantType] | None = None
match ui_type:
case UIType.MainModel:
ui_model_base = [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
ui_model_type = [ModelType.Main]
case UIType.CogView4MainModel:
ui_model_base = [BaseModelType.CogView4]
ui_model_type = [ModelType.Main]
case UIType.FluxMainModel:
ui_model_base = [BaseModelType.Flux]
ui_model_type = [ModelType.Main]
case UIType.SD3MainModel:
ui_model_base = [BaseModelType.StableDiffusion3]
ui_model_type = [ModelType.Main]
case UIType.SDXLMainModel:
ui_model_base = [BaseModelType.StableDiffusionXL]
ui_model_type = [ModelType.Main]
case UIType.SDXLRefinerModel:
ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
ui_model_type = [ModelType.Main]
case UIType.VAEModel:
ui_model_type = [ModelType.VAE]
case UIType.FluxVAEModel:
ui_model_base = [BaseModelType.Flux]
ui_model_type = [ModelType.VAE]
case UIType.LoRAModel:
ui_model_type = [ModelType.LoRA]
case UIType.ControlNetModel:
ui_model_type = [ModelType.ControlNet]
case UIType.IPAdapterModel:
ui_model_type = [ModelType.IPAdapter]
case UIType.T2IAdapterModel:
ui_model_type = [ModelType.T2IAdapter]
case UIType.T5EncoderModel:
ui_model_type = [ModelType.T5Encoder]
case UIType.CLIPEmbedModel:
ui_model_type = [ModelType.CLIPEmbed]
case UIType.CLIPLEmbedModel:
ui_model_type = [ModelType.CLIPEmbed]
ui_model_variant = [ClipVariantType.L]
case UIType.CLIPGEmbedModel:
ui_model_type = [ModelType.CLIPEmbed]
ui_model_variant = [ClipVariantType.G]
case UIType.SpandrelImageToImageModel:
ui_model_type = [ModelType.SpandrelImageToImage]
case UIType.ControlLoRAModel:
ui_model_type = [ModelType.ControlLoRa]
case UIType.SigLipModel:
ui_model_type = [ModelType.SigLIP]
case UIType.FluxReduxModel:
ui_model_type = [ModelType.FluxRedux]
case UIType.LlavaOnevisionModel:
ui_model_type = [ModelType.LlavaOnevision]
case UIType.Imagen3Model:
ui_model_base = [BaseModelType.Imagen3]
ui_model_type = [ModelType.Main]
case UIType.Imagen4Model:
ui_model_base = [BaseModelType.Imagen4]
ui_model_type = [ModelType.Main]
case UIType.ChatGPT4oModel:
ui_model_base = [BaseModelType.ChatGPT4o]
ui_model_type = [ModelType.Main]
case UIType.Gemini2_5Model:
ui_model_base = [BaseModelType.Gemini2_5]
ui_model_type = [ModelType.Main]
case UIType.FluxKontextModel:
ui_model_base = [BaseModelType.FluxKontext]
ui_model_type = [ModelType.Main]
case UIType.Veo3Model:
ui_model_base = [BaseModelType.Veo3]
ui_model_type = [ModelType.Video]
case UIType.RunwayModel:
ui_model_base = [BaseModelType.Runway]
ui_model_type = [ModelType.Video]
case _:
pass
did_migrate = False
if ui_model_type is not None:
json_schema_extra["ui_model_type"] = [m.value for m in ui_model_type]
did_migrate = True
if ui_model_base is not None:
json_schema_extra["ui_model_base"] = [m.value for m in ui_model_base]
did_migrate = True
if ui_model_format is not None:
json_schema_extra["ui_model_format"] = [m.value for m in ui_model_format]
did_migrate = True
if ui_model_variant is not None:
json_schema_extra["ui_model_variant"] = [m.value for m in ui_model_variant]
did_migrate = True
return did_migrate
def InputField(
# copied from pydantic's Field
# TODO: Can we support default_factory?
@@ -581,93 +681,6 @@ def InputField(
field_kind=FieldKind.Input,
)
if ui_type is not None:
if (
ui_model_base is not None
or ui_model_type is not None
or ui_model_variant is not None
or ui_model_format is not None
):
logger.warning("InputField: Use either ui_type or ui_model_[base|type|variant|format]. Ignoring ui_type.")
# Map old-style UIType to new-style ui_model_[base|type|variant|format]
elif ui_type is UIType.MainModel:
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.CogView4MainModel:
json_schema_extra_.ui_model_base = [BaseModelType.CogView4]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.FluxMainModel:
json_schema_extra_.ui_model_base = [BaseModelType.Flux]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.SD3MainModel:
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusion3]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.SDXLMainModel:
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusionXL]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.SDXLRefinerModel:
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
json_schema_extra_.ui_model_type = [ModelType.Main]
# Think this UIType is unused...?
# elif ui_type is UIType.ONNXModel:
# json_schema_extra_.ui_model_base =
# json_schema_extra_.ui_model_type =
elif ui_type is UIType.VAEModel:
json_schema_extra_.ui_model_type = [ModelType.VAE]
elif ui_type is UIType.FluxVAEModel:
json_schema_extra_.ui_model_base = [BaseModelType.Flux]
json_schema_extra_.ui_model_type = [ModelType.VAE]
elif ui_type is UIType.LoRAModel:
json_schema_extra_.ui_model_type = [ModelType.LoRA]
elif ui_type is UIType.ControlNetModel:
json_schema_extra_.ui_model_type = [ModelType.ControlNet]
elif ui_type is UIType.IPAdapterModel:
json_schema_extra_.ui_model_type = [ModelType.IPAdapter]
elif ui_type is UIType.T2IAdapterModel:
json_schema_extra_.ui_model_type = [ModelType.T2IAdapter]
elif ui_type is UIType.T5EncoderModel:
json_schema_extra_.ui_model_type = [ModelType.T5Encoder]
elif ui_type is UIType.CLIPEmbedModel:
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
elif ui_type is UIType.CLIPLEmbedModel:
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
json_schema_extra_.ui_model_variant = [ClipVariantType.L]
elif ui_type is UIType.CLIPGEmbedModel:
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
json_schema_extra_.ui_model_variant = [ClipVariantType.G]
elif ui_type is UIType.SpandrelImageToImageModel:
json_schema_extra_.ui_model_type = [ModelType.SpandrelImageToImage]
elif ui_type is UIType.ControlLoRAModel:
json_schema_extra_.ui_model_type = [ModelType.ControlLoRa]
elif ui_type is UIType.SigLipModel:
json_schema_extra_.ui_model_type = [ModelType.SigLIP]
elif ui_type is UIType.FluxReduxModel:
json_schema_extra_.ui_model_type = [ModelType.FluxRedux]
elif ui_type is UIType.LlavaOnevisionModel:
json_schema_extra_.ui_model_type = [ModelType.LlavaOnevision]
elif ui_type is UIType.Imagen3Model:
json_schema_extra_.ui_model_base = [BaseModelType.Imagen3]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.Imagen4Model:
json_schema_extra_.ui_model_base = [BaseModelType.Imagen4]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.ChatGPT4oModel:
json_schema_extra_.ui_model_base = [BaseModelType.ChatGPT4o]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.Gemini2_5Model:
json_schema_extra_.ui_model_base = [BaseModelType.Gemini2_5]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.FluxKontextModel:
json_schema_extra_.ui_model_base = [BaseModelType.FluxKontext]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.Veo3Model:
json_schema_extra_.ui_model_base = [BaseModelType.Veo3]
json_schema_extra_.ui_model_type = [ModelType.Video]
elif ui_type is UIType.RunwayModel:
json_schema_extra_.ui_model_base = [BaseModelType.Runway]
json_schema_extra_.ui_model_type = [ModelType.Video]
else:
json_schema_extra_.ui_type = ui_type
if ui_component is not None:
json_schema_extra_.ui_component = ui_component
if ui_hidden is not None:
@@ -696,6 +709,8 @@ def InputField(
json_schema_extra_.ui_model_format = ui_model_format
else:
json_schema_extra_.ui_model_format = [ui_model_format]
if ui_type is not None:
json_schema_extra_.ui_type = ui_type
"""
There is a conflict between the typing of invocation definitions and the typing of an invocation's

View File

@@ -73,12 +73,6 @@ class GlmEncoderField(BaseModel):
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class Qwen2_5VLField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load Qwen2.5-VL tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load Qwen2.5-VL text encoder submodel")
loras: List[LoRAField] = Field(default_factory=list, description="LoRAs to apply on model loading")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')

View File

@@ -24,7 +24,6 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
QwenImageConditioningField,
SD3ConditioningField,
TensorField,
UIComponent,
@@ -487,17 +486,6 @@ class CogView4ConditioningOutput(BaseInvocationOutput):
return cls(conditioning=CogView4ConditioningField(conditioning_name=conditioning_name))
@invocation_output("qwen_image_conditioning_output")
class QwenImageConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a Qwen-Image conditioning tensor."""
conditioning: QwenImageConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "QwenImageConditioningOutput":
return cls(conditioning=QwenImageConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -1,150 +0,0 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Qwen-Image denoising invocation using diffusers pipeline."""
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
QwenImageConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
@invocation(
"qwen_image_denoise",
title="Qwen-Image Denoise",
tags=["image", "qwen"],
category="image",
version="1.0.0",
)
class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run text-to-image generation with a Qwen-Image diffusion model."""
# Model components
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
title="VAE",
)
# Text conditioning
positive_conditioning: QwenImageConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
# Generation parameters
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_inference_steps: int = InputField(
default=50, gt=0, description="Number of denoising steps."
)
guidance_scale: float = InputField(
default=7.5, gt=1.0, description="Classifier-free guidance scale."
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
"""Generate image using Qwen-Image pipeline."""
device = TorchDevice.choose_torch_device()
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Load model components
with context.models.load(self.transformer.transformer) as transformer_info, \
context.models.load(self.vae.vae) as vae_info:
# Load conditioning data
conditioning_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
assert len(conditioning_data.conditionings) == 1
conditioning_info = conditioning_data.conditionings[0]
# Extract the prompt from conditioning
# The text encoder node stores both embeddings and the original prompt
prompt = getattr(conditioning_info, 'prompt', "A high-quality image")
# For now, we'll create a simplified pipeline
# In a full implementation, we'd properly load all components
try:
# Create the Qwen-Image pipeline with loaded components
# Note: This is a simplified approach. In production, we'd need to:
# 1. Load the text encoder from the conditioning
# 2. Properly initialize the pipeline with all components
# 3. Handle model configuration and dtype conversion
# For demonstration, we'll assume the models are loaded correctly
# and create a basic generation
transformer_model = transformer_info.model
vae_model = vae_info.model
# Move models to device
transformer_model = transformer_model.to(device, dtype=dtype)
vae_model = vae_model.to(device, dtype=dtype)
# Set up generator for reproducibility
generator = torch.Generator(device=device)
generator.manual_seed(self.seed)
# Create latents
latent_shape = (
1,
vae_model.config.latent_channels if hasattr(vae_model, 'config') else 4,
self.height // 8,
self.width // 8,
)
latents = torch.randn(latent_shape, generator=generator, device=device, dtype=dtype)
# Simple denoising loop (placeholder for actual implementation)
# In reality, we'd use the full QwenImagePipeline or implement the proper denoising
for _ in range(self.num_inference_steps):
# This is a placeholder - actual implementation would:
# 1. Apply noise scheduling
# 2. Use the transformer for denoising
# 3. Apply guidance scale
latents = latents * 0.99 # Placeholder denoising
# Decode latents to image
with torch.no_grad():
# Scale latents
latents = latents / vae_model.config.scaling_factor if hasattr(vae_model, 'config') else latents
# Decode
image = vae_model.decode(latents).sample if hasattr(vae_model, 'decode') else latents
# Convert to PIL Image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if image.ndim == 4:
image = image[0]
# Convert to uint8
image = (image * 255).round().astype("uint8")
# Convert numpy array to PIL Image
from PIL import Image
pil_image = Image.fromarray(image)
except Exception as e:
context.logger.error(f"Error during Qwen-Image generation: {e}")
# Create a placeholder image on error
from PIL import Image
pil_image = Image.new('RGB', (self.width, self.height), color='gray')
# Save and return the generated image
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)

View File

@@ -1,83 +0,0 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import Input, InputField, OutputField
from invokeai.app.invocations.model import ModelIdentifierField, Qwen2_5VLField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@invocation_output("qwen_image_model_loader_output")
class QwenImageModelLoaderOutput(BaseInvocationOutput):
"""Qwen-Image base model loader output"""
transformer: TransformerField = OutputField(description="Qwen-Image transformer model", title="Transformer")
qwen2_5_vl: Qwen2_5VLField = OutputField(description="Qwen2.5-VL text encoder for Qwen-Image", title="Text Encoder")
vae: VAEField = OutputField(description="Qwen-Image VAE", title="VAE")
@invocation(
"qwen_image_model_loader",
title="Main Model - Qwen-Image",
tags=["model", "qwen-image"],
category="model",
version="1.0.0",
)
class QwenImageModelLoaderInvocation(BaseInvocation):
"""Loads a Qwen-Image base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description="Qwen-Image main model",
input=Input.Direct,
ui_model_base=BaseModelType.QwenImage,
ui_model_type=ModelType.Main,
)
qwen2_5_vl_model: ModelIdentifierField = InputField(
description="Qwen2.5-VL vision-language model",
input=Input.Direct,
title="Qwen2.5-VL Model",
ui_model_base=BaseModelType.QwenImage,
# ui_model_type=ModelType.VL
)
vae_model: ModelIdentifierField | None = InputField(
description="VAE model for Qwen-Image",
title="VAE",
ui_model_base=BaseModelType.QwenImage,
ui_model_type=ModelType.VAE,
default=None,
)
def invoke(self, context: InvocationContext) -> QwenImageModelLoaderOutput:
# Validate that required models exist
for key in [self.model.key, self.qwen2_5_vl_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
# Validate optional VAE model if provided
if self.vae_model and not context.models.exists(self.vae_model.key):
raise ValueError(f"Unknown model: {self.vae_model.key}")
# Create submodel references
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
# Use provided VAE or extract from main model
if self.vae_model:
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
else:
# Use the VAE bundled with the Qwen-Image model
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
# For Qwen-Image, we use Qwen2.5-VL as the text encoder
tokenizer = self.qwen2_5_vl_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.qwen2_5_vl_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
return QwenImageModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
qwen2_5_vl=Qwen2_5VLField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]),
vae=VAEField(vae=vae),
)

View File

@@ -1,79 +0,0 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Qwen-Image text encoding invocation."""
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import Input, InputField, UIComponent
from invokeai.app.invocations.model import Qwen2_5VLField
from invokeai.app.invocations.primitives import QwenImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@invocation(
"qwen_image_text_encoder",
title="Prompt - Qwen-Image",
tags=["prompt", "conditioning", "qwen"],
category="conditioning",
version="1.0.0",
)
class QwenImageTextEncoderInvocation(BaseInvocation):
"""Encodes a text prompt for Qwen-Image generation."""
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
qwen2_5_vl: Qwen2_5VLField = InputField(
title="Qwen2.5-VL",
description="Qwen2.5-VL vision-language model for text encoding",
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> QwenImageConditioningOutput:
"""Encode the prompt using Qwen-Image's text encoder."""
# Load the text encoder info first to get the model
text_encoder_info = context.models.load(self.qwen2_5_vl.text_encoder)
# Load the Qwen2.5-VL tokenizer and text encoder with proper device management
with text_encoder_info.model_on_device() as (cached_weights, text_encoder), \
context.models.load(self.qwen2_5_vl.tokenizer) as tokenizer:
try:
# Tokenize the prompt
# Qwen2.5-VL supports much longer sequences than CLIP
text_inputs = tokenizer(
self.prompt,
padding="max_length",
max_length=1024, # Qwen2.5-VL supports much longer sequences
truncation=True,
return_tensors="pt",
)
# Encode the text (text_encoder is already on the correct device)
text_embeddings = text_encoder(text_inputs.input_ids.to(text_encoder.device))[0]
# Create a simple conditioning info that stores the embeddings
# For now, we'll create a simple class to hold the data
class QwenImageConditioningInfo:
def __init__(self, text_embeds: torch.Tensor, prompt: str):
self.text_embeds = text_embeds
self.prompt = prompt
conditioning_info = QwenImageConditioningInfo(text_embeddings, self.prompt)
conditioning_data = ConditioningFieldData(conditionings=[conditioning_info])
conditioning_name = context.conditioning.save(conditioning_data)
return QwenImageConditioningOutput.build(conditioning_name)
except Exception as e:
context.logger.error(f"Error encoding Qwen-Image text: {e}")
# Fallback to simple text storage
class QwenImageConditioningInfo:
def __init__(self, prompt: str):
self.prompt = prompt
conditioning_info = QwenImageConditioningInfo(self.prompt)
conditioning_data = ConditioningFieldData(conditionings=[conditioning_info])
conditioning_name = context.conditioning.save(conditioning_data)
return QwenImageConditioningOutput.build(conditioning_name)

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import Any, Coroutine, Optional
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
@@ -22,6 +23,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -135,6 +137,19 @@ class SessionQueueBase(ABC):
"""Deletes all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets a page of session queue items. Do not remove."""
pass
@abstractmethod
def list_all_queue_items(
self,

View File

@@ -34,6 +34,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -588,6 +589,59 @@ class SqliteSessionQueue(SessionQueueBase):
)
return self.get_queue_item(item_id)
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
if item_id is not None:
query += """--sql
AND ((priority < ?) OR (priority = ? AND item_id > ?))
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def list_all_queue_items(
self,
queue_id: str,

View File

@@ -651,8 +651,6 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
}
class ApiModelConfig(MainConfigBase, ModelConfigBase):
"""Model config for API-based models."""

View File

@@ -1,108 +0,0 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Qwen-Image model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from diffusers import DiffusionPipeline
from invokeai.backend.model_manager.config import AnyModelConfig, MainDiffusersConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.Diffusers)
class QwenImageLoader(ModelLoader):
"""Class to load Qwen-Image models."""
def get_size_fs(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> int:
"""Calculate the size of the Qwen-Image model on disk."""
if not isinstance(config, MainDiffusersConfig):
raise ValueError("Only MainDiffusersConfig models are currently supported here.")
# For Qwen-Image, we need to calculate the size of the entire model or specific submodels
return calc_model_size_by_fs(
model_path=model_path,
subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant.value if config.repo_variant else None,
)
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, MainDiffusersConfig):
raise ValueError("Only MainDiffusersConfig models are currently supported here.")
if config.base != BaseModelType.QwenImage:
raise ValueError("This loader only supports Qwen-Image models.")
model_path = Path(config.path)
if submodel_type is not None:
# Load individual submodel components with memory optimizations
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models import AutoencoderKLQwenImage
# Force bfloat16 for memory efficiency if not already set
torch_dtype = self._torch_dtype if self._torch_dtype is not None else torch.bfloat16
# Load only the specific submodel, not the entire pipeline
if submodel_type == SubModelType.VAE:
# Load VAE directly from subfolder
vae_path = model_path / "vae"
if vae_path.exists():
return AutoencoderKLQwenImage.from_pretrained(
vae_path,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
)
elif submodel_type == SubModelType.Transformer:
# Load transformer directly from subfolder
transformer_path = model_path / "transformer"
if transformer_path.exists():
return QwenImageTransformer2DModel.from_pretrained(
transformer_path,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
)
# Fallback to loading full pipeline if direct loading fails
pipeline = DiffusionPipeline.from_pretrained(
model_path,
torch_dtype=torch_dtype,
variant=config.repo_variant.value if config.repo_variant else None,
low_cpu_mem_usage=True,
)
# Return the specific submodel
if hasattr(pipeline, submodel_type.value):
return getattr(pipeline, submodel_type.value)
else:
raise ValueError(f"Submodel {submodel_type} not found in Qwen-Image pipeline.")
else:
# Load the full pipeline with memory optimizations
import torch
# Force bfloat16 for memory efficiency if not already set
torch_dtype = self._torch_dtype if self._torch_dtype is not None else torch.bfloat16
pipeline = DiffusionPipeline.from_pretrained(
model_path,
torch_dtype=torch_dtype,
variant=config.repo_variant.value if config.repo_variant else None,
low_cpu_mem_usage=True, # Important for reducing memory during loading
)
return pipeline

View File

@@ -33,7 +33,6 @@ class BaseModelType(str, Enum):
FluxKontext = "flux-kontext"
Veo3 = "veo3"
Runway = "runway"
QwenImage = "qwen-image"
class ModelType(str, Enum):

View File

@@ -90,7 +90,7 @@ export const RasterLayerAdjustmentsPanel = memo(() => {
}
const rect = adapter.transformer.getRelativeRect();
try {
await adapter.renderer.rasterize({ rect, replaceObjects: true });
await adapter.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1 } });
// Clear adjustments after baking
dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: null }));
} catch {

View File

@@ -16,7 +16,6 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
'sdxl-refiner': 'invokeBlue',
flux: 'gold',
cogview4: 'red',
'qwen-image': 'cyan',
imagen3: 'pink',
imagen4: 'pink',
'chatgpt-4o': 'pink',

View File

@@ -82,7 +82,6 @@ export const zBaseModelType = z.enum([
'sdxl-refiner',
'flux',
'cogview4',
'qwen-image',
'imagen3',
'imagen4',
'chatgpt-4o',
@@ -99,7 +98,6 @@ export const zMainModelBase = z.enum([
'sdxl',
'flux',
'cogview4',
'qwen-image',
'imagen3',
'imagen4',
'chatgpt-4o',

View File

@@ -5,7 +5,7 @@ import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlic
import { selectCanvasMetadata } from 'features/controlLayers/store/selectors';
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common';
import { type ImageField, zImageField, zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
getOriginalAndScaledSizesForOtherModes,
@@ -49,9 +49,7 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
reference_images = [];
for (const entity of validRefImages) {
assert(entity.config.image, 'Image is required for reference image');
reference_images.push({
image_name: entity.config.image.crop?.image.image_name ?? entity.config.image.original.image.image_name,
});
reference_images.push(zImageField.parse(entity.config.image.crop?.image ?? entity.config.image.original.image));
}
}

View File

@@ -164,7 +164,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
const kontextImagePrep = g.addNode({
id: getPrefixedId('flux_kontext_image_prep'),
type: 'flux_kontext_image_prep',
images: [zImageField.parse(config.image)],
images: [zImageField.parse(config.image?.crop?.image ?? config.image?.original.image)],
});
const kontextConditioning = g.addNode({
type: 'flux_kontext',

View File

@@ -60,9 +60,7 @@ export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn
model: zModelIdentifierField.parse(model),
aspect_ratio: aspectRatio.id,
prompt_upsampling: true,
input_image: {
image_name: firstImage.crop?.image.image_name ?? firstImage.original.image.image_name,
},
input_image: zImageField.parse(firstImage.crop?.image ?? firstImage.original.image),
...selectCanvasOutputFields(state),
});
} else {
@@ -70,7 +68,9 @@ export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn
const kontextConcatenator = g.addNode({
id: getPrefixedId('flux_kontext_image_prep'),
type: 'flux_kontext_image_prep',
images: validRefImages.map(({ config }) => zImageField.parse(config.image)),
images: validRefImages.map(({ config }) =>
zImageField.parse(config.image?.crop?.image ?? config.image?.original.image)
),
});
fluxKontextImage = g.addNode({

View File

@@ -4,7 +4,7 @@ import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice'
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { isGemini2_5ReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { ImageField } from 'features/nodes/types/common';
import { type ImageField, zImageField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
@@ -44,9 +44,7 @@ export const buildGemini2_5Graph = (arg: GraphBuilderArg): GraphBuilderReturn =>
reference_images = [];
for (const entity of validRefImages) {
assert(entity.config.image, 'Image is required for reference image');
reference_images.push({
image_name: entity.config.image.crop?.image.image_name ?? entity.config.image.original.image.image_name,
});
reference_images.push(zImageField.parse(entity.config.image.crop?.image ?? entity.config.image.original.image));
}
}

View File

@@ -13,7 +13,6 @@ export const MODEL_TYPE_MAP: Record<BaseModelType, string> = {
'sdxl-refiner': 'Stable Diffusion XL Refiner',
flux: 'FLUX',
cogview4: 'CogView4',
'qwen-image': 'Qwen-Image',
imagen3: 'Imagen3',
imagen4: 'Imagen4',
'chatgpt-4o': 'ChatGPT 4o',
@@ -35,7 +34,6 @@ export const MODEL_TYPE_SHORT_MAP: Record<BaseModelType, string> = {
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
cogview4: 'CogView4',
'qwen-image': 'Qwen',
imagen3: 'Imagen3',
imagen4: 'Imagen4',
'chatgpt-4o': 'ChatGPT 4o',

File diff suppressed because one or more lines are too long

View File

@@ -36,7 +36,7 @@ dependencies = [
"accelerate",
"bitsandbytes; sys_platform!='darwin'",
"compel==2.1.1",
"diffusers[torch]==0.35.0",
"diffusers[torch]==0.33.0",
"gguf",
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
"numpy<2.0.0",

View File

@@ -1,26 +0,0 @@
# Qwen-Image Test Configuration with Memory Optimizations
# This config helps test Qwen-Image with limited VRAM
# Model Cache Settings - Adjust based on your system
# These settings enable CPU offloading for large models
Model:
# Reduce VRAM cache to force CPU offloading
vram_cache_size: 8.0 # GB - Keep only essential models in VRAM
# Increase RAM cache for CPU offloading
ram_cache_size: 32.0 # GB - Adjust based on available system RAM
# Enable sequential offloading
sequential_offload: true
# Use bfloat16 by default for all models
precision: bfloat16
# Recommended workflow for testing:
# 1. Load only the Qwen-Image model first (not Qwen2.5-VL)
# 2. Use a simple text prompt without the text encoder
# 3. Test with smaller image sizes (512x512) initially
# Alternative: Use quantized models
# Download: huggingface-cli download diffusers/qwen-image-nf4
# This reduces memory usage by ~75%

View File

@@ -1,26 +0,0 @@
#!/bin/bash
# Run InvokeAI with optimized settings for Qwen-Image models
echo "Starting InvokeAI with Qwen-Image memory optimizations..."
echo "----------------------------------------"
echo "Recommendations for 24GB VRAM systems:"
echo "1. Set VRAM cache to 8-10GB in InvokeAI settings"
echo "2. Set RAM cache to 20-30GB (based on available system RAM)"
echo "3. Use bfloat16 precision (default in our loader)"
echo "----------------------------------------"
# Set environment variables for better memory management
export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:512"
export CUDA_LAUNCH_BLOCKING=0
# Optional: Limit CPU threads to prevent memory thrashing
export OMP_NUM_THREADS=8
# Run InvokeAI with your root directory
invokeai-web --root ~/invokeai/ \
--precision bfloat16 \
--max_cache_size 8.0 \
--max_vram_cache_size 8.0
# Alternative: Use with config file
# invokeai-web --root ~/invokeai/ --config qwen_test_config.yaml