mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fix: rebase borkage
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||
"""Qwen-Image denoising invocation using diffusers pipeline."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines import QwenImagePipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
@@ -37,18 +35,18 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
|
||||
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
title="VAE",
|
||||
)
|
||||
|
||||
|
||||
# Text conditioning
|
||||
positive_conditioning: QwenImageConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
|
||||
|
||||
# Generation parameters
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
@@ -63,23 +61,23 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
"""Generate image using Qwen-Image pipeline."""
|
||||
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
|
||||
# Load model components
|
||||
with context.models.load(self.transformer.transformer) as transformer_info, \
|
||||
context.models.load(self.vae.vae) as vae_info:
|
||||
|
||||
|
||||
# Load conditioning data
|
||||
conditioning_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
assert len(conditioning_data.conditionings) == 1
|
||||
conditioning_info = conditioning_data.conditionings[0]
|
||||
|
||||
|
||||
# Extract the prompt from conditioning
|
||||
# The text encoder node stores both embeddings and the original prompt
|
||||
prompt = getattr(conditioning_info, 'prompt', "A high-quality image")
|
||||
|
||||
|
||||
# For now, we'll create a simplified pipeline
|
||||
# In a full implementation, we'd properly load all components
|
||||
try:
|
||||
@@ -88,20 +86,20 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# 1. Load the text encoder from the conditioning
|
||||
# 2. Properly initialize the pipeline with all components
|
||||
# 3. Handle model configuration and dtype conversion
|
||||
|
||||
|
||||
# For demonstration, we'll assume the models are loaded correctly
|
||||
# and create a basic generation
|
||||
transformer_model = transformer_info.model
|
||||
vae_model = vae_info.model
|
||||
|
||||
|
||||
# Move models to device
|
||||
transformer_model = transformer_model.to(device, dtype=dtype)
|
||||
vae_model = vae_model.to(device, dtype=dtype)
|
||||
|
||||
|
||||
# Set up generator for reproducibility
|
||||
generator = torch.Generator(device=device)
|
||||
generator.manual_seed(self.seed)
|
||||
|
||||
|
||||
# Create latents
|
||||
latent_shape = (
|
||||
1,
|
||||
@@ -110,7 +108,7 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
self.width // 8,
|
||||
)
|
||||
latents = torch.randn(latent_shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
|
||||
# Simple denoising loop (placeholder for actual implementation)
|
||||
# In reality, we'd use the full QwenImagePipeline or implement the proper denoising
|
||||
for _ in range(self.num_inference_steps):
|
||||
@@ -119,34 +117,34 @@ class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# 2. Use the transformer for denoising
|
||||
# 3. Apply guidance scale
|
||||
latents = latents * 0.99 # Placeholder denoising
|
||||
|
||||
|
||||
# Decode latents to image
|
||||
with torch.no_grad():
|
||||
# Scale latents
|
||||
latents = latents / vae_model.config.scaling_factor if hasattr(vae_model, 'config') else latents
|
||||
# Decode
|
||||
image = vae_model.decode(latents).sample if hasattr(vae_model, 'decode') else latents
|
||||
|
||||
|
||||
# Convert to PIL Image
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
|
||||
if image.ndim == 4:
|
||||
image = image[0]
|
||||
|
||||
|
||||
# Convert to uint8
|
||||
image = (image * 255).round().astype("uint8")
|
||||
|
||||
|
||||
# Convert numpy array to PIL Image
|
||||
from PIL import Image
|
||||
pil_image = Image.fromarray(image)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
context.logger.error(f"Error during Qwen-Image generation: {e}")
|
||||
# Create a placeholder image on error
|
||||
from PIL import Image
|
||||
pil_image = Image.new('RGB', (self.width, self.height), color='gray')
|
||||
|
||||
|
||||
# Save and return the generated image
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
|
||||
from invokeai.app.invocations.fields import Input, InputField, UIComponent
|
||||
from invokeai.app.invocations.model import Qwen2_5VLField
|
||||
from invokeai.app.invocations.primitives import QwenImageConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -31,14 +31,14 @@ class QwenImageTextEncoderInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> QwenImageConditioningOutput:
|
||||
"""Encode the prompt using Qwen-Image's text encoder."""
|
||||
|
||||
|
||||
# Load the text encoder info first to get the model
|
||||
text_encoder_info = context.models.load(self.qwen2_5_vl.text_encoder)
|
||||
|
||||
|
||||
# Load the Qwen2.5-VL tokenizer and text encoder with proper device management
|
||||
with text_encoder_info.model_on_device() as (cached_weights, text_encoder), \
|
||||
context.models.load(self.qwen2_5_vl.tokenizer) as tokenizer:
|
||||
|
||||
|
||||
try:
|
||||
# Tokenize the prompt
|
||||
# Qwen2.5-VL supports much longer sequences than CLIP
|
||||
@@ -49,31 +49,31 @@ class QwenImageTextEncoderInvocation(BaseInvocation):
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
|
||||
# Encode the text (text_encoder is already on the correct device)
|
||||
text_embeddings = text_encoder(text_inputs.input_ids.to(text_encoder.device))[0]
|
||||
|
||||
|
||||
# Create a simple conditioning info that stores the embeddings
|
||||
# For now, we'll create a simple class to hold the data
|
||||
class QwenImageConditioningInfo:
|
||||
def __init__(self, text_embeds: torch.Tensor, prompt: str):
|
||||
self.text_embeds = text_embeds
|
||||
self.prompt = prompt
|
||||
|
||||
|
||||
conditioning_info = QwenImageConditioningInfo(text_embeddings, self.prompt)
|
||||
conditioning_data = ConditioningFieldData(conditionings=[conditioning_info])
|
||||
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return QwenImageConditioningOutput.build(conditioning_name)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
context.logger.error(f"Error encoding Qwen-Image text: {e}")
|
||||
# Fallback to simple text storage
|
||||
class QwenImageConditioningInfo:
|
||||
def __init__(self, prompt: str):
|
||||
self.prompt = prompt
|
||||
|
||||
|
||||
conditioning_info = QwenImageConditioningInfo(self.prompt)
|
||||
conditioning_data = ConditioningFieldData(conditionings=[conditioning_info])
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return QwenImageConditioningOutput.build(conditioning_name)
|
||||
return QwenImageConditioningOutput.build(conditioning_name)
|
||||
|
||||
@@ -29,7 +29,7 @@ class QwenImageLoader(ModelLoader):
|
||||
"""Calculate the size of the Qwen-Image model on disk."""
|
||||
if not isinstance(config, MainDiffusersConfig):
|
||||
raise ValueError("Only MainDiffusersConfig models are currently supported here.")
|
||||
|
||||
|
||||
# For Qwen-Image, we need to calculate the size of the entire model or specific submodels
|
||||
return calc_model_size_by_fs(
|
||||
model_path=model_path,
|
||||
@@ -44,21 +44,21 @@ class QwenImageLoader(ModelLoader):
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, MainDiffusersConfig):
|
||||
raise ValueError("Only MainDiffusersConfig models are currently supported here.")
|
||||
|
||||
|
||||
if config.base != BaseModelType.QwenImage:
|
||||
raise ValueError("This loader only supports Qwen-Image models.")
|
||||
|
||||
|
||||
model_path = Path(config.path)
|
||||
|
||||
|
||||
if submodel_type is not None:
|
||||
# Load individual submodel components with memory optimizations
|
||||
import torch
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.models import AutoencoderKLQwenImage
|
||||
|
||||
|
||||
# Force bfloat16 for memory efficiency if not already set
|
||||
torch_dtype = self._torch_dtype if self._torch_dtype is not None else torch.bfloat16
|
||||
|
||||
|
||||
# Load only the specific submodel, not the entire pipeline
|
||||
if submodel_type == SubModelType.VAE:
|
||||
# Load VAE directly from subfolder
|
||||
@@ -78,7 +78,7 @@ class QwenImageLoader(ModelLoader):
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Fallback to loading full pipeline if direct loading fails
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
@@ -86,7 +86,7 @@ class QwenImageLoader(ModelLoader):
|
||||
variant=config.repo_variant.value if config.repo_variant else None,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Return the specific submodel
|
||||
if hasattr(pipeline, submodel_type.value):
|
||||
return getattr(pipeline, submodel_type.value)
|
||||
@@ -95,14 +95,14 @@ class QwenImageLoader(ModelLoader):
|
||||
else:
|
||||
# Load the full pipeline with memory optimizations
|
||||
import torch
|
||||
|
||||
|
||||
# Force bfloat16 for memory efficiency if not already set
|
||||
torch_dtype = self._torch_dtype if self._torch_dtype is not None else torch.bfloat16
|
||||
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
variant=config.repo_variant.value if config.repo_variant else None,
|
||||
low_cpu_mem_usage=True, # Important for reducing memory during loading
|
||||
)
|
||||
return pipeline
|
||||
return pipeline
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { Qwen2_5VLModelFieldInputInstance, Qwen2_5VLModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<Qwen2_5VLModelFieldInputInstance, Qwen2_5VLModelFieldInputTemplate>;
|
||||
|
||||
/**
 * Node-editor input for a Qwen2.5-VL model field.
 *
 * Renders a `ModelFieldCombobox` populated from `useMainModels()` and, when a
 * non-null model is picked, dispatches `fieldMainModelValueChanged` for this
 * node/field pair. A null selection is ignored.
 */
const Qwen2_5VLModelFieldInputComponent = ({ nodeId, field, fieldTemplate }: Props) => {
  const dispatch = useAppDispatch();

  // For now, using main models as Qwen2.5-VL is a main model that acts as text encoder.
  // In the future, we might want to create a specific hook for Qwen2.5-VL models.
  const [modelConfigs, { isLoading }] = useMainModels();

  const fieldName = field.name;

  const handleModelSelected = useCallback(
    (selected: MainModelConfig | null) => {
      // Only a concrete selection is written to the store.
      if (selected) {
        dispatch(fieldMainModelValueChanged({ nodeId, fieldName, value: selected }));
      }
    },
    [dispatch, fieldName, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs}
      isLoadingConfigs={isLoading}
      onChange={handleModelSelected}
      required={fieldTemplate.required}
    />
  );
};
|
||||
|
||||
export default memo(Qwen2_5VLModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { QwenImageMainModelFieldInputInstance, QwenImageMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useQwenImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<QwenImageMainModelFieldInputInstance, QwenImageMainModelFieldInputTemplate>;
|
||||
|
||||
/**
 * Node-editor input for a Qwen-Image main model field.
 *
 * Renders a `ModelFieldCombobox` populated from `useQwenImageModels()` and,
 * when a non-null model is picked, dispatches `fieldMainModelValueChanged`
 * for this node/field pair. A null selection is ignored.
 */
const QwenImageMainModelFieldInputComponent = ({ nodeId, field, fieldTemplate }: Props) => {
  const dispatch = useAppDispatch();
  const [modelConfigs, { isLoading }] = useQwenImageModels();

  const fieldName = field.name;

  const handleModelSelected = useCallback(
    (selected: MainModelConfig | null) => {
      // Only a concrete selection is written to the store.
      if (selected) {
        dispatch(fieldMainModelValueChanged({ nodeId, fieldName, value: selected }));
      }
    },
    [dispatch, fieldName, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs}
      isLoadingConfigs={isLoading}
      onChange={handleModelSelected}
      required={fieldTemplate.required}
    />
  );
};
|
||||
|
||||
export default memo(QwenImageMainModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { QwenImageVAEModelFieldInputInstance, QwenImageVAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<QwenImageVAEModelFieldInputInstance, QwenImageVAEModelFieldInputTemplate>;
|
||||
|
||||
/**
 * Node-editor input for a Qwen-Image VAE model field.
 *
 * Renders a `ModelFieldCombobox` populated from `useVAEModels()` and, when a
 * non-null model is picked, dispatches `fieldVaeModelValueChanged` for this
 * node/field pair. A null selection is ignored.
 */
const QwenImageVAEModelFieldInputComponent = ({ nodeId, field, fieldTemplate }: Props) => {
  const dispatch = useAppDispatch();
  const [modelConfigs, { isLoading }] = useVAEModels();

  const fieldName = field.name;

  const handleVaeSelected = useCallback(
    (selected: VAEModelConfig | null) => {
      // Only a concrete selection is written to the store.
      if (selected) {
        dispatch(fieldVaeModelValueChanged({ nodeId, fieldName, value: selected }));
      }
    },
    [dispatch, fieldName, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs}
      isLoadingConfigs={isLoading}
      onChange={handleVaeSelected}
      required={fieldTemplate.required}
    />
  );
};
|
||||
|
||||
export default memo(QwenImageVAEModelFieldInputComponent);
|
||||
@@ -16,9 +16,6 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
IntegerGeneratorField: undefined,
|
||||
StringGeneratorField: undefined,
|
||||
ImageGeneratorField: undefined,
|
||||
QwenImageMainModelField: undefined,
|
||||
QwenImageVAEModelField: undefined,
|
||||
Qwen2_5VLModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
|
||||
|
||||
Reference in New Issue
Block a user