mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-31 10:47:55 -05:00
56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
from invokeai.app.invocations.baseinvocation import (
|
|
BaseInvocation,
|
|
BaseInvocationOutput,
|
|
Classification,
|
|
invocation,
|
|
invocation_output,
|
|
)
|
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
|
from invokeai.app.invocations.model import (
|
|
GlmEncoderField,
|
|
ModelIdentifierField,
|
|
TransformerField,
|
|
VAEField,
|
|
)
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.backend.model_manager.config import SubModelType
|
|
|
|
|
|
@invocation_output("cogview4_model_loader_output")
class CogView4ModelLoaderOutput(BaseInvocationOutput):
    """CogView4 base model loader output."""

    # Three output fields, one per model component handed on by the loader.
    # Each reuses the shared description text from FieldDescriptions.

    # Transformer component (carried as a TransformerField).
    transformer: TransformerField = OutputField(
        description=FieldDescriptions.transformer,
        title="Transformer",
    )

    # GLM encoder component (carried as a GlmEncoderField).
    glm_encoder: GlmEncoderField = OutputField(
        description=FieldDescriptions.glm_encoder,
        title="GLM Encoder",
    )

    # VAE component (carried as a VAEField).
    vae: VAEField = OutputField(
        description=FieldDescriptions.vae,
        title="VAE",
    )
|
|
|
|
|
|
@invocation(
    "cogview4_model_loader",
    title="Main Model - CogView4",
    tags=["model", "cogview4"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class CogView4ModelLoaderInvocation(BaseInvocation):
    """Loads a CogView4 base model, outputting its submodels."""

    # Identifier of the CogView4 main model to split up; declared with Input.Direct.
    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.cogview4_model,
        ui_type=UIType.CogView4MainModel,
        input=Input.Direct,
    )

    def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:
        """Build the loader output from per-submodel copies of ``self.model``.

        Each copy is ``self.model`` with only ``submodel_type`` overridden, so
        the original identifier is left unmodified. ``context`` is not used.

        Returns:
            A CogView4ModelLoaderOutput whose transformer (with an empty LoRA
            list), GLM encoder (tokenizer + text encoder) and VAE fields each
            wrap the matching submodel identifier.
        """

        def pick(kind: SubModelType) -> ModelIdentifierField:
            # Copy of the selected model that points at one specific submodel.
            return self.model.model_copy(update={"submodel_type": kind})

        transformer_field = TransformerField(
            transformer=pick(SubModelType.Transformer),
            loras=[],
        )
        encoder_field = GlmEncoderField(
            tokenizer=pick(SubModelType.Tokenizer),
            text_encoder=pick(SubModelType.TextEncoder),
        )
        vae_field = VAEField(vae=pick(SubModelType.VAE))

        return CogView4ModelLoaderOutput(
            transformer=transformer_field,
            glm_encoder=encoder_field,
            vae=vae_field,
        )
|