fix: rebase borkage

This commit is contained in:
psychedelicious
2025-09-18 14:19:31 +10:00
parent 14b335d42f
commit ac40cd47d4
7 changed files with 43 additions and 182 deletions

View File

@@ -1,10 +1,8 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Qwen-Image denoising invocation using diffusers pipeline."""
from typing import Optional
import torch
from diffusers.pipelines import QwenImagePipeline
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
@@ -37,18 +35,18 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
title="Transformer",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
title="VAE",
)
# Text conditioning
positive_conditioning: QwenImageConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
# Generation parameters
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
@@ -63,23 +61,23 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
"""Generate image using Qwen-Image pipeline."""
device = TorchDevice.choose_torch_device()
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Load model components
with context.models.load(self.transformer.transformer) as transformer_info, \
context.models.load(self.vae.vae) as vae_info:
# Load conditioning data
conditioning_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
assert len(conditioning_data.conditionings) == 1
conditioning_info = conditioning_data.conditionings[0]
# Extract the prompt from conditioning
# The text encoder node stores both embeddings and the original prompt
prompt = getattr(conditioning_info, 'prompt', "A high-quality image")
# For now, we'll create a simplified pipeline
# In a full implementation, we'd properly load all components
try:
@@ -88,20 +86,20 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# 1. Load the text encoder from the conditioning
# 2. Properly initialize the pipeline with all components
# 3. Handle model configuration and dtype conversion
# For demonstration, we'll assume the models are loaded correctly
# and create a basic generation
transformer_model = transformer_info.model
vae_model = vae_info.model
# Move models to device
transformer_model = transformer_model.to(device, dtype=dtype)
vae_model = vae_model.to(device, dtype=dtype)
# Set up generator for reproducibility
generator = torch.Generator(device=device)
generator.manual_seed(self.seed)
# Create latents
latent_shape = (
1,
@@ -110,7 +108,7 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
self.width // 8,
)
latents = torch.randn(latent_shape, generator=generator, device=device, dtype=dtype)
# Simple denoising loop (placeholder for actual implementation)
# In reality, we'd use the full QwenImagePipeline or implement the proper denoising
for _ in range(self.num_inference_steps):
@@ -119,34 +117,34 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# 2. Use the transformer for denoising
# 3. Apply guidance scale
latents = latents * 0.99 # Placeholder denoising
# Decode latents to image
with torch.no_grad():
# Scale latents
latents = latents / vae_model.config.scaling_factor if hasattr(vae_model, 'config') else latents
# Decode
image = vae_model.decode(latents).sample if hasattr(vae_model, 'decode') else latents
# Convert to PIL Image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if image.ndim == 4:
image = image[0]
# Convert to uint8
image = (image * 255).round().astype("uint8")
# Convert numpy array to PIL Image
from PIL import Image
pil_image = Image.fromarray(image)
except Exception as e:
context.logger.error(f"Error during Qwen-Image generation: {e}")
# Create a placeholder image on error
from PIL import Image
pil_image = Image.new('RGB', (self.width, self.height), color='gray')
# Save and return the generated image
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
return ImageOutput.build(image_dto)

View File

@@ -4,7 +4,7 @@
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
from invokeai.app.invocations.fields import Input, InputField, UIComponent
from invokeai.app.invocations.model import Qwen2_5VLField
from invokeai.app.invocations.primitives import QwenImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -31,14 +31,14 @@ class QwenImageTextEncoderInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> QwenImageConditioningOutput:
"""Encode the prompt using Qwen-Image's text encoder."""
# Load the text encoder info first to get the model
text_encoder_info = context.models.load(self.qwen2_5_vl.text_encoder)
# Load the Qwen2.5-VL tokenizer and text encoder with proper device management
with text_encoder_info.model_on_device() as (cached_weights, text_encoder), \
context.models.load(self.qwen2_5_vl.tokenizer) as tokenizer:
try:
# Tokenize the prompt
# Qwen2.5-VL supports much longer sequences than CLIP
@@ -49,31 +49,31 @@ class QwenImageTextEncoderInvocation(BaseInvocation):
truncation=True,
return_tensors="pt",
)
# Encode the text (text_encoder is already on the correct device)
text_embeddings = text_encoder(text_inputs.input_ids.to(text_encoder.device))[0]
# Create a simple conditioning info that stores the embeddings
# For now, we'll create a simple class to hold the data
class QwenImageConditioningInfo:
def __init__(self, text_embeds: torch.Tensor, prompt: str):
self.text_embeds = text_embeds
self.prompt = prompt
conditioning_info = QwenImageConditioningInfo(text_embeddings, self.prompt)
conditioning_data = ConditioningFieldData(conditionings=[conditioning_info])
conditioning_name = context.conditioning.save(conditioning_data)
return QwenImageConditioningOutput.build(conditioning_name)
except Exception as e:
context.logger.error(f"Error encoding Qwen-Image text: {e}")
# Fallback to simple text storage
class QwenImageConditioningInfo:
def __init__(self, prompt: str):
self.prompt = prompt
conditioning_info = QwenImageConditioningInfo(self.prompt)
conditioning_data = ConditioningFieldData(conditionings=[conditioning_info])
conditioning_name = context.conditioning.save(conditioning_data)
return QwenImageConditioningOutput.build(conditioning_name)
return QwenImageConditioningOutput.build(conditioning_name)

View File

@@ -29,7 +29,7 @@ class QwenImageLoader(ModelLoader):
"""Calculate the size of the Qwen-Image model on disk."""
if not isinstance(config, MainDiffusersConfig):
raise ValueError("Only MainDiffusersConfig models are currently supported here.")
# For Qwen-Image, we need to calculate the size of the entire model or specific submodels
return calc_model_size_by_fs(
model_path=model_path,
@@ -44,21 +44,21 @@ class QwenImageLoader(ModelLoader):
) -> AnyModel:
if not isinstance(config, MainDiffusersConfig):
raise ValueError("Only MainDiffusersConfig models are currently supported here.")
if config.base != BaseModelType.QwenImage:
raise ValueError("This loader only supports Qwen-Image models.")
model_path = Path(config.path)
if submodel_type is not None:
# Load individual submodel components with memory optimizations
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models import AutoencoderKLQwenImage
# Force bfloat16 for memory efficiency if not already set
torch_dtype = self._torch_dtype if self._torch_dtype is not None else torch.bfloat16
# Load only the specific submodel, not the entire pipeline
if submodel_type == SubModelType.VAE:
# Load VAE directly from subfolder
@@ -78,7 +78,7 @@ class QwenImageLoader(ModelLoader):
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
)
# Fallback to loading full pipeline if direct loading fails
pipeline = DiffusionPipeline.from_pretrained(
model_path,
@@ -86,7 +86,7 @@ class QwenImageLoader(ModelLoader):
variant=config.repo_variant.value if config.repo_variant else None,
low_cpu_mem_usage=True,
)
# Return the specific submodel
if hasattr(pipeline, submodel_type.value):
return getattr(pipeline, submodel_type.value)
@@ -95,14 +95,14 @@ class QwenImageLoader(ModelLoader):
else:
# Load the full pipeline with memory optimizations
import torch
# Force bfloat16 for memory efficiency if not already set
torch_dtype = self._torch_dtype if self._torch_dtype is not None else torch.bfloat16
pipeline = DiffusionPipeline.from_pretrained(
model_path,
torch_dtype=torch_dtype,
variant=config.repo_variant.value if config.repo_variant else None,
low_cpu_mem_usage=True, # Important for reducing memory during loading
)
return pipeline
return pipeline

View File

@@ -1,46 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Qwen2_5VLModelFieldInputInstance, Qwen2_5VLModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<Qwen2_5VLModelFieldInputInstance, Qwen2_5VLModelFieldInputTemplate>;

/**
 * Combobox input for a Qwen2.5-VL model field on a node.
 *
 * Offers the configs returned by `useMainModels()`; picking one dispatches
 * `fieldMainModelValueChanged` for this node and field. A null selection is
 * ignored, so the stored value is never cleared from here.
 */
const Qwen2_5VLModelFieldInputComponent = ({ nodeId, field, fieldTemplate }: Props) => {
  const dispatch = useAppDispatch();

  // Qwen2.5-VL is a main model that acts as a text encoder, so the main-model
  // list is used for now. A dedicated Qwen2.5-VL hook may replace this later.
  const [mainModelConfigs, { isLoading: isLoadingConfigs }] = useMainModels();

  const fieldName = field.name;

  const handleChange = useCallback(
    (selected: MainModelConfig | null) => {
      // Only a concrete selection is written to the store.
      if (selected) {
        dispatch(fieldMainModelValueChanged({ nodeId, fieldName, value: selected }));
      }
    },
    [dispatch, fieldName, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={mainModelConfigs}
      isLoadingConfigs={isLoadingConfigs}
      onChange={handleChange}
      required={fieldTemplate.required}
    />
  );
};

export default memo(Qwen2_5VLModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { QwenImageMainModelFieldInputInstance, QwenImageMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useQwenImageModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<QwenImageMainModelFieldInputInstance, QwenImageMainModelFieldInputTemplate>;

/**
 * Combobox input for a Qwen-Image main model field on a node.
 *
 * Offers the configs returned by `useQwenImageModels()`; picking one dispatches
 * `fieldMainModelValueChanged` for this node and field. A null selection is
 * ignored, so the stored value is never cleared from here.
 */
const QwenImageMainModelFieldInputComponent = ({ nodeId, field, fieldTemplate }: Props) => {
  const dispatch = useAppDispatch();
  const [qwenImageConfigs, { isLoading: isLoadingConfigs }] = useQwenImageModels();

  const fieldName = field.name;

  const handleChange = useCallback(
    (selected: MainModelConfig | null) => {
      // Only a concrete selection is written to the store.
      if (selected) {
        dispatch(fieldMainModelValueChanged({ nodeId, fieldName, value: selected }));
      }
    },
    [dispatch, fieldName, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={qwenImageConfigs}
      isLoadingConfigs={isLoadingConfigs}
      onChange={handleChange}
      required={fieldTemplate.required}
    />
  );
};

export default memo(QwenImageMainModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { QwenImageVAEModelFieldInputInstance, QwenImageVAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<QwenImageVAEModelFieldInputInstance, QwenImageVAEModelFieldInputTemplate>;

/**
 * Combobox input for a Qwen-Image VAE model field on a node.
 *
 * Offers the configs returned by `useVAEModels()`; picking one dispatches
 * `fieldVaeModelValueChanged` for this node and field. A null selection is
 * ignored, so the stored value is never cleared from here.
 */
const QwenImageVAEModelFieldInputComponent = ({ nodeId, field, fieldTemplate }: Props) => {
  const dispatch = useAppDispatch();
  const [vaeConfigs, { isLoading: isLoadingConfigs }] = useVAEModels();

  const fieldName = field.name;

  const handleChange = useCallback(
    (selected: VAEModelConfig | null) => {
      // Only a concrete selection is written to the store.
      if (selected) {
        dispatch(fieldVaeModelValueChanged({ nodeId, fieldName, value: selected }));
      }
    },
    [dispatch, fieldName, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={vaeConfigs}
      isLoadingConfigs={isLoadingConfigs}
      onChange={handleChange}
      required={fieldTemplate.required}
    />
  );
};

export default memo(QwenImageVAEModelFieldInputComponent);

View File

@@ -16,9 +16,6 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
IntegerGeneratorField: undefined,
StringGeneratorField: undefined,
ImageGeneratorField: undefined,
QwenImageMainModelField: undefined,
QwenImageVAEModelField: undefined,
Qwen2_5VLModelField: undefined,
};
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {