add queue_id to all model load invocations

Author: Mary Hipp
Date: 2024-11-06 16:19:34 -05:00
parent 674f530501
commit c63fe5e9bb
20 changed files with 44 additions and 44 deletions
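
For orientation: every call that loads a model now receives the session queue id, obtained from context.util.get_queue_id(). Below is a minimal sketch of the presumed interface after this change; the class names, parameter names, and the str type for queue_id are illustrative assumptions where the diff does not show them, not the repository's exact definitions.

from typing import Any, Callable, Optional


class ModelsInterface:
    """Sketch of the invocation context's models API (assumed shape)."""

    def load(self, identifier: Any, queue_id: str) -> "LoadedModel":
        # queue_id is now passed as a second argument at every call site.
        ...

    def load_remote_model(
        self,
        source: str,
        queue_id: str,
        loader: Optional[Callable[[Any], Any]] = None,
    ) -> "LoadedModel":
        # Remote loads pass queue_id as a keyword, alongside source and loader.
        ...


class UtilInterface:
    def get_queue_id(self) -> str:
        # Returns the id of the queue the current invocation is running on.
        ...

Call sites then follow the pattern repeated throughout the diff, e.g. vae_info = context.models.load(self.vae.vae, context.util.get_queue_id()).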


@@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
-lora_info = context.models.load(lora.lora)
+lora_info = context.models.load(lora.lora, context.util.get_queue_id())
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)


@@ -649,7 +649,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
-source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
+source=DEPTH_ANYTHING_MODELS[self.model_size], queue_id=self._context.util.get_queue_id(), loader=load_depth_anything
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)


@@ -60,7 +60,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
-vae_info = context.models.load(self.vae.vae)
+vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)


@@ -124,7 +124,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
-vae_info: LoadedModel = context.models.load(self.vae.vae)
+vae_info: LoadedModel = context.models.load(self.vae.vae, context.util.get_queue_id())
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:


@@ -88,7 +88,7 @@ def get_scheduler(
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-orig_scheduler_info = context.models.load(scheduler_info)
+orig_scheduler_info = context.models.load(scheduler_info, context.util.get_queue_id())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -435,7 +435,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
-control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
+control_model = exit_stack.enter_context(context.models.load(control_info.control_model, context.util.get_queue_id()))
assert isinstance(control_model, ControlNetModel)
control_image_field = control_info.image
@@ -492,7 +492,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
for control_info in control_list:
-model = exit_stack.enter_context(context.models.load(control_info.control_model))
+model = exit_stack.enter_context(context.models.load(control_info.control_model, context.util.get_queue_id()))
ext_manager.add_extension(
ControlNetExt(
model=model,
@@ -545,9 +545,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
-with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
+with context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id()) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
-image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
+image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model, context.util.get_queue_id())
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -581,7 +581,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
-ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
+ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id()))
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
@@ -621,7 +621,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
-t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
+t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id())
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -926,7 +926,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
-unet_info = context.models.load(self.unet.unet)
+unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (cached_weights, unet),
@@ -989,13 +989,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
-lora_info = context.models.load(lora.lora)
+lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
-unet_info = context.models.load(self.unet.unet)
+unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,


@@ -183,7 +183,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
seed=self.seed,
)
-transformer_info = context.models.load(self.transformer.transformer)
+transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
@@ -468,7 +468,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
-controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
+controlnet_models = [context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
@@ -479,7 +479,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
-vae_info = context.models.load(self.controlnet_vae.vae)
+vae_info = context.models.load(self.controlnet_vae.vae, context.util.get_queue_id())
controlnet_conds.append(
InstantXControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
@@ -590,7 +590,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_images.append(pos_image)
neg_images.append(neg_image)
-with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
+with context.models.load(ip_adapter_field.image_encoder_model, context.util.get_queue_id()) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
@@ -620,7 +620,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
):
-ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
+ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id()))
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
if ip_adapter_field.mask is not None:
@@ -649,7 +649,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
-lora_info = context.models.load(lora.lora)
+lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info


@@ -57,8 +57,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
return FluxConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
-t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
-t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
+t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
+t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -77,8 +77,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
-clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+clip_tokenizer_info = context.models.load(self.clip.tokenizer, context.util.get_queue_id())
+clip_text_encoder_info = context.models.load(self.clip.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -118,7 +118,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
-lora_info = context.models.load(lora.lora)
+lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info


@@ -52,7 +52,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
-vae_info = context.models.load(self.vae.vae)
+vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()


@@ -54,7 +54,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
-vae_info = context.models.load(self.vae.vae)
+vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:


@@ -111,7 +111,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
-vae_info = context.models.load(self.vae.vae)
+vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:


@@ -57,7 +57,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
-vae_info = context.models.load(self.vae.vae)
+vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))


@@ -147,7 +147,7 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
-transformer_info = context.models.load(self.transformer.transformer)
+transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.


@@ -44,7 +44,7 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
-vae_info = context.models.load(self.vae.vae)
+vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL))


@@ -86,8 +86,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
-t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
-t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
+t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
+t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -127,8 +127,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
-clip_tokenizer_info = context.models.load(clip_model.tokenizer)
-clip_text_encoder_info = context.models.load(clip_model.text_encoder)
+clip_tokenizer_info = context.models.load(clip_model.tokenizer, context.util.get_queue_id())
+clip_text_encoder_info = context.models.load(clip_model.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -193,7 +193,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
-lora_info = context.models.load(lora.lora)
+lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info


@@ -125,7 +125,7 @@ class SegmentAnythingInvocation(BaseInvocation):
with (
context.models.load_remote_model(
-source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
+source=SEGMENT_ANYTHING_MODEL_IDS[self.model], queue_id=context.util.get_queue_id(), loader=SegmentAnythingInvocation._load_sam_model
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)


@@ -158,7 +158,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
-spandrel_model_info = context.models.load(self.image_to_image_model)
+spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
@@ -207,7 +207,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
-spandrel_model_info = context.models.load(self.image_to_image_model)
+spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.


@@ -196,13 +196,13 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
-lora_info = context.models.load(lora.lora)
+lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
-unet_info = context.models.load(self.unet.unet)
+unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
with (
ExitStack() as exit_stack,


@@ -90,7 +90,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError(msg)
loadnet = context.models.load_remote_model(
-source=ESRGAN_MODEL_URLS[self.model_name],
+source=ESRGAN_MODEL_URLS[self.model_name], queue_id=context.util.get_queue_id()
)
with loadnet as loadnet_model:


@@ -29,7 +29,7 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
-lora_model = self._node_context.models.load(self._model_id).model
+lora_model = self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()).model
assert isinstance(lora_model, LoRAModelRaw)
LoRAPatcher.apply_lora_patch(
model=unet,


@@ -54,7 +54,7 @@ class T2IAdapterExt(ExtensionBase):
@callback(ExtensionCallbackType.SETUP)
def setup(self, ctx: DenoiseContext):
t2i_model: T2IAdapter
-with self._node_context.models.load(self._model_id) as t2i_model:
+with self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()) as t2i_model:
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
self._adapter_state = self._run_model(