mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager
This commit is contained in:
committed by
psychedelicious
parent
b0835db47d
commit
35e8a33dfd
@@ -812,7 +812,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
)
|
||||
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.Tensor)
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
@@ -18,7 +18,7 @@ from .baseinvocation import (
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
@@ -110,22 +110,22 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=key,
|
||||
submodel=SubModelType.UNet,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=key,
|
||||
submodel=SubModelType.Scheduler,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=key,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=key,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
@@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel=SubModelType.Vae,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -188,7 +188,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -198,7 +198,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -271,7 +271,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -281,7 +281,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -291,7 +291,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -43,29 +43,29 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_records.exists(model_key):
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.UNet,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.Scheduler,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
@@ -73,11 +73,11 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
@@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.Vae,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -112,29 +112,29 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_records.exists(model_key):
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.UNet,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.Scheduler,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
@@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel=SubModelType.Vae,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user