Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-15 12:08:23 -05:00

Comparing 4 commits: psyche/fea... → maryhipp/m...
| SHA1 |
|---|
| 9bd1f4a4f4 |
| 28864f6d7f |
| c63fe5e9bb |
| 674f530501 |

Together, these commits thread a `queue_id` through the model-loading stack so that `model_load_started` and `model_load_complete` events are scoped to a session queue and emitted only to that queue's socket.io room, instead of being broadcast to every model-event subscriber.
@@ -751,7 +751,7 @@ async def convert_model(
     with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
         convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
-        converted_model = loader.load_model(model_config)
+        converted_model = loader.load_model(model_config, queue_id="default")
         # write the converted file to the convert path
         raw_model = converted_model.model
         assert hasattr(raw_model, "save_pretrained")
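Note where the queue id comes from in each case: the convert route above runs outside any session queue, so it passes the literal `"default"` queue id, while invocation code throughout the hunks below derives the id from the running queue item. Both lines in this sketch are taken verbatim from the changeset:

```python
# Route handler: no queue item is in scope, so load events for the
# conversion are emitted to the "default" queue's socket.io room.
converted_model = loader.load_model(model_config, queue_id="default")

# Invocation: the id comes from the session's queue item via the new
# context.util.get_queue_id() helper.
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
```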
@@ -31,6 +31,7 @@ from invokeai.app.services.events.events_common import (
     ModelInstallErrorEvent,
     ModelInstallStartedEvent,
     ModelLoadCompleteEvent,
+    ModelLoadEventBase,
     ModelLoadStartedEvent,
     QueueClearedEvent,
     QueueEventBase,
@@ -53,6 +54,13 @@ class BulkDownloadSubscriptionEvent(BaseModel):
     bulk_download_id: str
 
 
+class ModelLoadSubscriptionEvent(BaseModel):
+    """Event data for subscribing to the socket.io model loading room.
+    This is a pydantic model to ensure the data is in the correct format."""
+
+    queue_id: str
+
+
 QUEUE_EVENTS = {
     InvocationStartedEvent,
     InvocationProgressEvent,
@@ -69,8 +77,6 @@ MODEL_EVENTS = {
     DownloadErrorEvent,
     DownloadProgressEvent,
     DownloadStartedEvent,
-    ModelLoadStartedEvent,
-    ModelLoadCompleteEvent,
     ModelInstallDownloadProgressEvent,
     ModelInstallDownloadsCompleteEvent,
     ModelInstallStartedEvent,
@@ -79,6 +85,11 @@ MODEL_EVENTS = {
     ModelInstallErrorEvent,
 }
 
+MODEL_LOAD_EVENTS = {
+    ModelLoadStartedEvent,
+    ModelLoadCompleteEvent,
+}
+
 BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
@@ -101,6 +112,7 @@ class SocketIO:
 
         register_events(QUEUE_EVENTS, self._handle_queue_event)
         register_events(MODEL_EVENTS, self._handle_model_event)
+        register_events(MODEL_LOAD_EVENTS, self._handle_model_load_event)
         register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
 
     async def _handle_sub_queue(self, sid: str, data: Any) -> None:
@@ -115,9 +127,18 @@ class SocketIO:
     async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
         await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
 
+    async def _handle_sub_model_load(self, sid: str, data: Any) -> None:
+        await self._sio.enter_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
+
+    async def _handle_unsub_model_load(self, sid: str, data: Any) -> None:
+        await self._sio.leave_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
+
     async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
         await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
 
+    async def _handle_model_load_event(self, event: FastAPIEvent[ModelLoadEventBase]) -> None:
+        await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
+
     async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
         await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
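For orientation, a minimal client-side sketch of how these handlers would be exercised. It assumes the server wires `_handle_sub_model_load` to an event named `"subscribe_model_load"` — that registration is not shown in this diff, so the event name, URL, and socket.io path below are all assumptions:

```python
import socketio  # python-socketio client

sio = socketio.AsyncClient()

@sio.on("model_load_started")
async def on_model_load_started(data: dict) -> None:
    # Received only after joining the room keyed by our queue_id.
    print(f"model load started on queue {data['queue_id']}")

async def main() -> None:
    # URL and socketio_path are illustrative, not taken from this diff.
    await sio.connect("http://localhost:9090", socketio_path="/ws/socket.io")
    # Server side: _handle_sub_model_load calls enter_room(sid, queue_id).
    await sio.emit("subscribe_model_load", {"queue_id": "default"})  # event name assumed
    await sio.wait()
```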
@@ -63,12 +63,12 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.models.load(self.clip.tokenizer)
-        text_encoder_info = context.models.load(self.clip.text_encoder)
+        tokenizer_info = context.models.load(self.clip.tokenizer, queue_id=context.util.get_queue_id())
+        text_encoder_info = context.models.load(self.clip.text_encoder, queue_id=context.util.get_queue_id())
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.models.load(lora.lora)
+                lora_info = context.models.load(lora.lora, queue_id=context.util.get_queue_id())
                 assert isinstance(lora_info.model, LoRAModelRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
@@ -137,8 +137,8 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        tokenizer_info = context.models.load(clip_field.tokenizer)
-        text_encoder_info = context.models.load(clip_field.text_encoder)
+        tokenizer_info = context.models.load(clip_field.tokenizer, queue_id=context.util.get_queue_id())
+        text_encoder_info = context.models.load(clip_field.text_encoder, queue_id=context.util.get_queue_id())
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.models.load(lora.lora)
+                lora_info = context.models.load(lora.lora, context.util.get_queue_id())
                 lora_model = lora_info.model
                 assert isinstance(lora_model, LoRAModelRaw)
                 yield (lora_model, lora.weight)
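The remaining invocation hunks repeat the one mechanical change shown above; the canonical before/after, for skimming the rest of the diff:

```python
# Before: loads were not attributed to a queue, so load events were
# broadcast to every connected client.
lora_info = context.models.load(lora.lora)

# After: the session's queue id is threaded through, scoping the resulting
# model_load_started/model_load_complete events to that queue's room.
lora_info = context.models.load(lora.lora, queue_id=context.util.get_queue_id())
```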
@@ -649,7 +649,9 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
             return DepthAnythingPipeline(depth_anything_pipeline)
 
         with self._context.models.load_remote_model(
-            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
+            source=DEPTH_ANYTHING_MODELS[self.model_size],
+            queue_id=self._context.util.get_queue_id(),
+            loader=load_depth_anything,
         ) as depth_anything_detector:
             assert isinstance(depth_anything_detector, DepthAnythingPipeline)
             depth_map = depth_anything_detector.generate_depth(image)
@@ -60,7 +60,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         )
 
         if image_tensor is not None:
-            vae_info = context.models.load(self.vae.vae)
+            vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
 
             img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
             masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@@ -124,7 +124,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
             assert isinstance(main_model_config, MainConfigBase)
             if main_model_config.variant is ModelVariantType.Inpaint:
                 mask = blur_tensor
-            vae_info: LoadedModel = context.models.load(self.vae.vae)
+            vae_info: LoadedModel = context.models.load(self.vae.vae, context.util.get_queue_id())
             image = context.images.get_pil(self.image.image_name)
             image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
             if image_tensor.dim() == 3:
@@ -88,7 +88,7 @@ def get_scheduler(
     # TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
     # possible.
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.models.load(scheduler_info)
+    orig_scheduler_info = context.models.load(scheduler_info, context.util.get_queue_id())
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
@@ -435,7 +435,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         controlnet_data: list[ControlNetData] = []
         for control_info in control_list:
-            control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            control_model = exit_stack.enter_context(
+                context.models.load(control_info.control_model, context.util.get_queue_id())
+            )
             assert isinstance(control_model, ControlNetModel)
 
             control_image_field = control_info.image
@@ -492,7 +494,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 raise ValueError(f"Unexpected control_input type: {type(control_input)}")
 
         for control_info in control_list:
-            model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            model = exit_stack.enter_context(
+                context.models.load(control_info.control_model, context.util.get_queue_id())
+            )
             ext_manager.add_extension(
                 ControlNetExt(
                     model=model,
@@ -545,9 +549,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
         """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
         image_prompts = []
         for single_ip_adapter in ip_adapters:
-            with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
+            with context.models.load(
+                single_ip_adapter.ip_adapter_model, context.util.get_queue_id()
+            ) as ip_adapter_model:
                 assert isinstance(ip_adapter_model, IPAdapter)
-                image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
+                image_encoder_model_info = context.models.load(
+                    single_ip_adapter.image_encoder_model, context.util.get_queue_id()
+                )
             # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
             single_ipa_image_fields = single_ip_adapter.image
             if not isinstance(single_ipa_image_fields, list):
@@ -581,7 +589,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
             ip_adapters, image_prompts, strict=True
         ):
-            ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
+            ip_adapter_model = exit_stack.enter_context(
+                context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id())
+            )
 
             mask_field = single_ip_adapter.mask
             mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
@@ -621,7 +631,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
             t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
-            t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
+            t2i_adapter_loaded_model = context.models.load(
+                t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id()
+            )
             image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
 
             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -926,7 +938,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # ext: t2i/ip adapter
             ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
 
-            unet_info = context.models.load(self.unet.unet)
+            unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
             assert isinstance(unet_info.model, UNet2DConditionModel)
             with (
                 unet_info.model_on_device() as (cached_weights, unet),
@@ -989,13 +1001,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.unet.loras:
-                lora_info = context.models.load(lora.lora)
+                lora_info = context.models.load(lora.lora, context.util.get_queue_id())
                 assert isinstance(lora_info.model, LoRAModelRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
 
-        unet_info = context.models.load(self.unet.unet)
+        unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
@@ -35,7 +35,9 @@ class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard):
         model_url = DEPTH_ANYTHING_MODELS[self.model_size]
         image = context.images.get_pil(self.image.image_name, "RGB")
 
-        loaded_model = context.models.load_remote_model(model_url, DepthAnythingPipeline.load_model)
+        loaded_model = context.models.load_remote_model(
+            model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model
+        )
 
         with loaded_model as depth_anything_detector:
             assert isinstance(depth_anything_detector, DepthAnythingPipeline)
@@ -29,10 +29,10 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
         onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
 
         loaded_session_det = context.models.load_local_model(
-            onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
+            onnx_det_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
         )
         loaded_session_pose = context.models.load_local_model(
-            onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
+            onnx_pose_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
         )
 
         with loaded_session_det as session_det, loaded_session_pose as session_pose:
@@ -183,7 +183,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )
 
-        transformer_info = context.models.load(self.transformer.transformer)
+        transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
         is_schnell = "schnell" in transformer_info.config.config_path
 
         # Calculate the timestep schedule.
@@ -468,7 +468,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         # minimize peak memory.
 
         # First, load the ControlNet models so that we can determine the ControlNet types.
-        controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
+        controlnet_models = [
+            context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets
+        ]
 
         # Calculate the controlnet conditioning tensors.
         # We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
@@ -479,7 +481,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             if isinstance(controlnet_model.model, InstantXControlNetFlux):
                 if self.controlnet_vae is None:
                     raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
-                vae_info = context.models.load(self.controlnet_vae.vae)
+                vae_info = context.models.load(self.controlnet_vae.vae, context.util.get_queue_id())
                 controlnet_conds.append(
                     InstantXControlNetExtension.prepare_controlnet_cond(
                         controlnet_image=image,
@@ -590,7 +592,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             pos_images.append(pos_image)
             neg_images.append(neg_image)
 
-        with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
+        with context.models.load(
+            ip_adapter_field.image_encoder_model, context.util.get_queue_id()
+        ) as image_encoder_model:
             assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
 
             clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
@@ -620,7 +624,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
             ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
         ):
-            ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
+            ip_adapter_model = exit_stack.enter_context(
+                context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id())
+            )
             assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
             ip_adapter_model = ip_adapter_model.to(dtype=dtype)
             if ip_adapter_field.mask is not None:
@@ -649,7 +655,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
         for lora in self.transformer.loras:
-            lora_info = context.models.load(lora.lora)
+            lora_info = context.models.load(lora.lora, context.util.get_queue_id())
             assert isinstance(lora_info.model, LoRAModelRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
@@ -57,8 +57,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
         return FluxConditioningOutput.build(conditioning_name)
 
     def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
-        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
-        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
+        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
+        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
 
         prompt = [self.prompt]
 
@@ -77,8 +77,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
         return prompt_embeds
 
     def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
-        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer, context.util.get_queue_id())
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder, context.util.get_queue_id())
 
         prompt = [self.prompt]
 
@@ -118,7 +118,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
 
     def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
         for lora in self.clip.loras:
-            lora_info = context.models.load(lora.lora)
+            lora_info = context.models.load(lora.lora, context.util.get_queue_id())
             assert isinstance(lora_info.model, LoRAModelRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
@@ -52,7 +52,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
-        vae_info = context.models.load(self.vae.vae)
+        vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
         image = self._vae_decode(vae_info=vae_info, latents=latents)
 
         TorchDevice.empty_cache()
@@ -54,7 +54,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
 
-        vae_info = context.models.load(self.vae.vae)
+        vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
@@ -94,7 +94,9 @@ class GroundingDinoInvocation(BaseInvocation):
         labels = [label if label.endswith(".") else label + "." for label in labels]
 
         with context.models.load_remote_model(
-            source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
+            source=GROUNDING_DINO_MODEL_IDS[self.model],
+            queue_id=context.util.get_queue_id(),
+            loader=GroundingDinoInvocation._load_grounding_dino,
         ) as detector:
             assert isinstance(detector, GroundingDinoPipeline)
             return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
@@ -22,7 +22,9 @@ class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name, "RGB")
-        loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model)
+        loaded_model = context.models.load_remote_model(
+            HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model
+        )
 
         with loaded_model as model:
             assert isinstance(model, ControlNetHED_Apache2)
@@ -111,7 +111,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
 
-        vae_info = context.models.load(self.vae.vae)
+        vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
@@ -36,7 +36,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to process")
 
     @abstractmethod
-    def infill(self, image: Image.Image) -> Image.Image:
+    def infill(self, image: Image.Image, queue_id: str) -> Image.Image:
         """Infill the image with the specified method"""
         pass
 
@@ -56,7 +56,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
             return ImageOutput.build(context.images.get_dto(self.image.image_name))
 
         # Perform Infill action
-        infilled_image = self.infill(input_image)
+        infilled_image = self.infill(input_image, context.util.get_queue_id())
 
         # Create ImageDTO for Infilled Image
         infilled_image_dto = context.images.save(image=infilled_image)
@@ -74,7 +74,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
         description="The color to use to infill",
     )
 
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, queue_id: str):
         solid_bg = Image.new("RGBA", image.size, self.color.tuple())
         infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
         infilled.paste(image, (0, 0), image.split()[-1])
@@ -93,7 +93,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
         description="The seed to use for tile generation (omit for random)",
     )
 
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, queue_id: str):
         output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
         return output.infilled
 
@@ -107,7 +107,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
     downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
 
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, queue_id: str):
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
 
         width = int(image.width / self.downscale)
@@ -131,9 +131,10 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
 class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""
 
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, queue_id: str):
         with self._context.models.load_remote_model(
             source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+            queue_id=queue_id,
             loader=LaMA.load_jit_model,
         ) as model:
             lama = LaMA(model)
@@ -144,7 +145,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
 class CV2InfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using OpenCV Inpainting"""
 
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, queue_id: str):
         return cv2_inpaint(image)
 
 
@@ -166,5 +167,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
         description="The max threshold for color",
     )
 
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, queue_id: str):
         return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
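The `infill` signature change above is plumbing rather than behavior: only `LaMaInfillInvocation` consumes the new `queue_id` (for its `load_remote_model` call); the other subclasses accept and ignore it so the abstract signature stays uniform. A hedged illustration with a hypothetical subclass:

```python
class MyInfillInvocation(InfillImageProcessorInvocation):  # hypothetical example
    """Illustration only: a subclass that loads no model simply ignores queue_id."""

    def infill(self, image: Image.Image, queue_id: str) -> Image.Image:
        return image  # queue_id is unused unless a model load is involved
```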
@@ -57,7 +57,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
-        vae_info = context.models.load(self.vae.vae)
+        vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
         with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
@@ -23,7 +23,9 @@ class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name, "RGB")
         model_url = LineartEdgeDetector.get_model_url(self.coarse)
-        loaded_model = context.models.load_remote_model(model_url, LineartEdgeDetector.load_model)
+        loaded_model = context.models.load_remote_model(
+            model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model
+        )
 
         with loaded_model as model:
             assert isinstance(model, Generator)
@@ -20,7 +20,9 @@ class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name, "RGB")
         model_url = LineartAnimeEdgeDetector.get_model_url()
-        loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
+        loaded_model = context.models.load_remote_model(
+            model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model
+        )
 
         with loaded_model as model:
             assert isinstance(model, UnetGenerator)
@@ -28,7 +28,9 @@ class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name, "RGB")
-        loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(), MLSDDetector.load_model)
+        loaded_model = context.models.load_remote_model(
+            MLSDDetector.get_model_url(), context.util.get_queue_id(), MLSDDetector.load_model
+        )
 
         with loaded_model as model:
             assert isinstance(model, MobileV2_MLSD_Large)
@@ -20,7 +20,9 @@ class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name, "RGB")
-        loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)
+        loaded_model = context.models.load_remote_model(
+            NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model
+        )
 
         with loaded_model as model:
             assert isinstance(model, NNET)
@@ -22,7 +22,9 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name, "RGB")
-        loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(), PIDINetDetector.load_model)
+        loaded_model = context.models.load_remote_model(
+            PIDINetDetector.get_model_url(), context.util.get_queue_id(), PIDINetDetector.load_model
+        )
 
         with loaded_model as model:
             assert isinstance(model, PiDiNet)
@@ -147,7 +147,7 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         inference_dtype = TorchDevice.choose_torch_dtype()
         device = TorchDevice.choose_torch_device()
 
-        transformer_info = context.models.load(self.transformer.transformer)
+        transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
 
         # Load/process the conditioning data.
         # TODO(ryand): Make CFG optional.
@@ -44,7 +44,7 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
-        vae_info = context.models.load(self.vae.vae)
+        vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
         assert isinstance(vae_info.model, (AutoencoderKL))
         with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, (AutoencoderKL))
@@ -86,8 +86,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
 
     def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
         assert self.t5_encoder is not None
-        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
-        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
+        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
+        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
 
         prompt = [self.prompt]
 
@@ -127,8 +127,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
     def _clip_encode(
         self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        clip_tokenizer_info = context.models.load(clip_model.tokenizer)
-        clip_text_encoder_info = context.models.load(clip_model.text_encoder)
+        clip_tokenizer_info = context.models.load(clip_model.tokenizer, context.util.get_queue_id())
+        clip_text_encoder_info = context.models.load(clip_model.text_encoder, context.util.get_queue_id())
 
         prompt = [self.prompt]
 
@@ -193,7 +193,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
         self, context: InvocationContext, clip_model: CLIPField
     ) -> Iterator[Tuple[LoRAModelRaw, float]]:
         for lora in clip_model.loras:
-            lora_info = context.models.load(lora.lora)
+            lora_info = context.models.load(lora.lora, context.util.get_queue_id())
             assert isinstance(lora_info.model, LoRAModelRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
@@ -125,7 +125,9 @@ class SegmentAnythingInvocation(BaseInvocation):
 
         with (
             context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
+                source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
+                queue_id=context.util.get_queue_id(),
+                loader=SegmentAnythingInvocation._load_sam_model,
             ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingPipeline)
@@ -158,7 +158,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         image = context.images.get_pil(self.image.image_name, mode="RGB")
 
         # Load the model.
-        spandrel_model_info = context.models.load(self.image_to_image_model)
+        spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
 
         def step_callback(step: int, total_steps: int) -> None:
             context.util.signal_progress(
@@ -207,7 +207,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
         image = context.images.get_pil(self.image.image_name, mode="RGB")
 
         # Load the model.
-        spandrel_model_info = context.models.load(self.image_to_image_model)
+        spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
 
         # The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
         # Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
@@ -196,13 +196,13 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         # Prepare an iterator that yields the UNet's LoRA models and their weights.
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.unet.loras:
-                lora_info = context.models.load(lora.lora)
+                lora_info = context.models.load(lora.lora, context.util.get_queue_id())
                 assert isinstance(lora_info.model, LoRAModelRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
 
         # Load the UNet model.
-        unet_info = context.models.load(self.unet.unet)
+        unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
 
         with (
             ExitStack() as exit_stack,
@@ -90,7 +90,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
             raise ValueError(msg)
 
         loadnet = context.models.load_remote_model(
-            source=ESRGAN_MODEL_URLS[self.model_name],
+            source=ESRGAN_MODEL_URLS[self.model_name], queue_id=context.util.get_queue_id()
         )
 
         with loadnet as loadnet_model:
@@ -131,15 +131,17 @@ class EventServiceBase:
 
     # region Model loading
 
-    def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
+    def emit_model_load_started(
+        self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
+    ) -> None:
         """Emitted when a model load is started."""
-        self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
+        self.dispatch(ModelLoadStartedEvent.build(config, queue_id, submodel_type))
 
     def emit_model_load_complete(
-        self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
+        self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
    ) -> None:
         """Emitted when a model load is complete."""
-        self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
+        self.dispatch(ModelLoadCompleteEvent.build(config, queue_id, submodel_type))
 
     # endregion
@@ -383,12 +383,14 @@ class DownloadErrorEvent(DownloadEventBase):
         return cls(source=str(job.source), error_type=job.error_type, error=job.error)
 
 
-class ModelEventBase(EventBase):
-    """Base class for events associated with a model"""
+class ModelLoadEventBase(EventBase):
+    """Base class for model load events"""
+
+    queue_id: str = Field(description="The ID of the queue")
 
 
 @payload_schema.register
-class ModelLoadStartedEvent(ModelEventBase):
+class ModelLoadStartedEvent(ModelLoadEventBase):
     """Event model for model_load_started"""
 
     __event_name__ = "model_load_started"
@@ -397,12 +399,14 @@ class ModelLoadStartedEvent(ModelLoadEventBase):
     submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
 
     @classmethod
-    def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
-        return cls(config=config, submodel_type=submodel_type)
+    def build(
+        cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
+    ) -> "ModelLoadStartedEvent":
+        return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
 
 
 @payload_schema.register
-class ModelLoadCompleteEvent(ModelEventBase):
+class ModelLoadCompleteEvent(ModelLoadEventBase):
     """Event model for model_load_complete"""
 
     __event_name__ = "model_load_complete"
@@ -411,8 +415,14 @@ class ModelLoadCompleteEvent(ModelLoadEventBase):
     submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
 
     @classmethod
-    def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
-        return cls(config=config, submodel_type=submodel_type)
+    def build(
+        cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
+    ) -> "ModelLoadCompleteEvent":
+        return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
+
+
+class ModelEventBase(EventBase):
+    """Base class for model events"""
 
 
 @payload_schema.register
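A hedged sketch of the new event shape: `build` now requires the queue id, which `SocketIO._handle_model_load_event` uses as the room to emit to. The `model_config` value here is a placeholder for a real `AnyModelConfig`:

```python
# queue_id rides along in the payload and doubles as the socket.io room name.
event = ModelLoadStartedEvent.build(config=model_config, queue_id="default")
assert event.queue_id == "default"
payload = event.model_dump(mode="json")  # what gets emitted to the room
```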
@@ -14,7 +14,9 @@ class ModelLoadServiceBase(ABC):
     """Wrapper around AnyModelLoader."""
 
     @abstractmethod
-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load_model(
+        self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
+    ) -> LoadedModel:
         """
         Given a model's configuration, load it and return the LoadedModel object.
 
@@ -29,7 +31,7 @@ class ModelLoadServiceBase(ABC):
 
     @abstractmethod
     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+        self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         """
         Load the model file or directory located at the indicated Path.
@@ -49,7 +49,9 @@ class ModelLoadService(ModelLoadServiceBase):
         """Return the RAM cache used by this loader."""
         return self._ram_cache
 
-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load_model(
+        self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
+    ) -> LoadedModel:
         """
         Given a model's configuration, load it and return the LoadedModel object.
 
@@ -60,7 +62,7 @@ class ModelLoadService(ModelLoadServiceBase):
         # We don't have an invoker during testing
         # TODO(psyche): Mock this method on the invoker in the tests
         if hasattr(self, "_invoker"):
-            self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
+            self._invoker.services.events.emit_model_load_started(model_config, queue_id, submodel_type)
 
         implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type)  # type: ignore
         loaded_model: LoadedModel = implementation(
@@ -70,12 +72,12 @@ class ModelLoadService(ModelLoadServiceBase):
         ).load_model(model_config, submodel_type)
 
         if hasattr(self, "_invoker"):
-            self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
+            self._invoker.services.events.emit_model_load_complete(model_config, queue_id, submodel_type)
 
         return loaded_model
 
     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+        self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         cache_key = str(model_path)
         ram_cache = self.ram_cache
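With these changes, `ModelLoadService.load_model` brackets every load with queue-scoped events, whether or not the model is served from the RAM cache. A hedged sketch of the sequence an observer in the queue's room sees; the `services` handle is illustrative:

```python
# Hedged usage sketch; "default" stands in for a real session queue id.
loaded = services.model_manager.load.load_model(model_config, "default")
# -> model_load_started emitted to room "default"
# -> implementation loads the model (or returns the cached copy)
# -> model_load_complete emitted to room "default"
with loaded as model:
    ...  # use the raw model
```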
@@ -351,7 +351,10 @@ class ModelsInterface(InvocationContextInterface):
         return self._services.model_manager.store.exists(identifier.key)
 
     def load(
-        self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
+        self,
+        identifier: Union[str, "ModelIdentifierField"],
+        queue_id: str,
+        submodel_type: Optional[SubModelType] = None,
     ) -> LoadedModel:
         """Load a model.
 
@@ -368,14 +371,19 @@ class ModelsInterface(InvocationContextInterface):
 
         if isinstance(identifier, str):
             model = self._services.model_manager.store.get_model(identifier)
-            return self._services.model_manager.load.load_model(model, submodel_type)
+            return self._services.model_manager.load.load_model(model, queue_id, submodel_type)
         else:
             _submodel_type = submodel_type or identifier.submodel_type
             model = self._services.model_manager.store.get_model(identifier.key)
-            return self._services.model_manager.load.load_model(model, _submodel_type)
+            return self._services.model_manager.load.load_model(model, queue_id, _submodel_type)
 
     def load_by_attrs(
-        self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
+        self,
+        name: str,
+        base: BaseModelType,
+        type: ModelType,
+        queue_id: str,
+        submodel_type: Optional[SubModelType] = None,
     ) -> LoadedModel:
         """Load a model by its attributes.
 
@@ -397,7 +405,7 @@ class ModelsInterface(InvocationContextInterface):
         if len(configs) > 1:
             raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
 
-        return self._services.model_manager.load.load_model(configs[0], submodel_type)
+        return self._services.model_manager.load.load_model(configs[0], queue_id, submodel_type)
 
     def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
         """Get a model's config.
@@ -472,6 +480,7 @@ class ModelsInterface(InvocationContextInterface):
     def load_local_model(
         self,
         model_path: Path,
+        queue_id: str,
         loader: Optional[Callable[[Path], AnyModel]] = None,
     ) -> LoadedModelWithoutConfig:
         """
@@ -489,11 +498,14 @@ class ModelsInterface(InvocationContextInterface):
         Returns:
             A LoadedModelWithoutConfig object.
         """
-        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+        return self._services.model_manager.load.load_model_from_path(
+            model_path=model_path, queue_id=queue_id, loader=loader
+        )
 
     def load_remote_model(
         self,
         source: str | AnyHttpUrl,
+        queue_id: str,
         loader: Optional[Callable[[Path], AnyModel]] = None,
     ) -> LoadedModelWithoutConfig:
         """
@@ -514,7 +526,9 @@ class ModelsInterface(InvocationContextInterface):
             A LoadedModelWithoutConfig object.
         """
         model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
-        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+        return self._services.model_manager.load.load_model_from_path(
+            model_path=model_path, queue_id=queue_id, loader=loader
+        )
 
 
 class ConfigInterface(InvocationContextInterface):
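The `ModelsInterface` methods now all take the queue id immediately after the model identifier, path, or source. A compact hedged sketch of the updated call shapes inside an invocation; the model name, path, and URL are illustrative only:

```python
qid = context.util.get_queue_id()

info = context.models.load(self.vae.vae, qid)  # installed model
info = context.models.load_by_attrs(
    "sd-vae", BaseModelType.StableDiffusion1, ModelType.VAE, qid  # name is hypothetical
)
raw = context.models.load_local_model(Path("/models/foo.onnx"), qid)  # local file
raw = context.models.load_remote_model("https://example.com/m.pt", qid)  # remote source
```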
@@ -535,6 +549,14 @@ class UtilInterface(InvocationContextInterface):
         super().__init__(services, data)
         self._is_canceled = is_canceled
 
+    def get_queue_id(self) -> str:
+        """Gets the ID of the queue associated with the current session.
+
+        Returns:
+            The queue ID of the current session's queue item.
+        """
+        return self._data.queue_item.queue_id
+
     def is_canceled(self) -> bool:
         """Checks if the current session has been canceled.
 
@@ -22,7 +22,7 @@ def generate_ti_list(
     for trigger in extract_ti_triggers_from_prompt(prompt):
         name_or_key = trigger[1:-1]
         try:
-            loaded_model = context.models.load(name_or_key)
+            loaded_model = context.models.load(name_or_key, queue_id=context.util.get_queue_id())
             model = loaded_model.model
             assert isinstance(model, TextualInversionModelRaw)
             assert loaded_model.config.base == base
@@ -30,7 +30,7 @@ def generate_ti_list(
         except UnknownModelException:
             try:
                 loaded_model = context.models.load_by_attrs(
-                    name=name_or_key, base=base, type=ModelType.TextualInversion
+                    name=name_or_key, base=base, type=ModelType.TextualInversion, queue_id=context.util.get_queue_id()
                 )
                 model = loaded_model.model
                 assert isinstance(model, TextualInversionModelRaw)
@@ -29,7 +29,7 @@ class LoRAExt(ExtensionBase):
 
     @contextmanager
     def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
-        lora_model = self._node_context.models.load(self._model_id).model
+        lora_model = self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()).model
         assert isinstance(lora_model, LoRAModelRaw)
         LoRAPatcher.apply_lora_patch(
             model=unet,
@@ -54,7 +54,7 @@ class T2IAdapterExt(ExtensionBase):
     @callback(ExtensionCallbackType.SETUP)
     def setup(self, ctx: DenoiseContext):
         t2i_model: T2IAdapter
-        with self._node_context.models.load(self._model_id) as t2i_model:
+        with self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()) as t2i_model:
             _, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
 
             self._adapter_state = self._run_model(
@@ -43,14 +43,14 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
     downloaded_path = mock_context.models.download_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
-    loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
+    loaded_model_1 = mock_context.models.load_local_model(downloaded_path, mock_context.util.get_queue_id())
     assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
 
-    loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
+    loaded_model_2 = mock_context.models.load_local_model(downloaded_path, mock_context.util.get_queue_id())
     assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
     assert loaded_model_1.model is loaded_model_2.model
 
-    loaded_model_3 = mock_context.models.load_local_model(embedding_file)
+    loaded_model_3 = mock_context.models.load_local_model(embedding_file, mock_context.util.get_queue_id())
     assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
     assert loaded_model_1.model is not loaded_model_3.model
     assert isinstance(loaded_model_1.model, dict)
@@ -60,16 +60,20 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
 
 @pytest.mark.skip(reason="This requires a test model to load")
 def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
-    loaded_model = mock_context.models.load_local_model(vae_directory)
+    loaded_model = mock_context.models.load_local_model(vae_directory, mock_context.util.get_queue_id())
     assert isinstance(loaded_model, LoadedModelWithoutConfig)
     assert isinstance(loaded_model.model, AutoencoderTiny)
 
 
 def test_download_and_load(mock_context: InvocationContext) -> None:
-    loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
+    loaded_model_1 = mock_context.models.load_remote_model(
+        "https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id()
+    )
     assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
 
-    loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
+    loaded_model_2 = mock_context.models.load_remote_model(
+        "https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id()
+    )
     assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
     assert loaded_model_1.model is loaded_model_2.model  # should be cached copy