import torch
from transformers import GlmModel, PreTrainedTokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
from invokeai.app.invocations.model import GlmEncoderField
from invokeai.app.invocations.primitives import CogView4ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    CogView4ConditioningInfo,
    ConditioningFieldData,
)
from invokeai.backend.util.devices import TorchDevice

# The CogView4 GLM text encoder max sequence length is set based on the default in diffusers.
COGVIEW4_GLM_MAX_SEQ_LEN = 1024


@invocation(
    "cogview4_text_encoder",
    title="Prompt - CogView4",
    tags=["prompt", "conditioning", "cogview4"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class CogView4TextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for a CogView4 image."""

    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
    glm_encoder: GlmEncoderField = InputField(
        title="GLM Encoder",
        description=FieldDescriptions.glm_encoder,
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> CogView4ConditioningOutput:
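        # Encode the prompt, wrap the embeddings as conditioning data, and save them so that
        # downstream CogView4 nodes can reference the conditioning by name.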
        glm_embeds = self._glm_encode(context, max_seq_len=COGVIEW4_GLM_MAX_SEQ_LEN)
        conditioning_data = ConditioningFieldData(conditionings=[CogView4ConditioningInfo(glm_embeds=glm_embeds)])
        conditioning_name = context.conditioning.save(conditioning_data)
        return CogView4ConditioningOutput.build(conditioning_name)

    def _glm_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
        prompt = [self.prompt]

        # TODO(ryand): Add model inputs to the invocation rather than hard-coding.
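        # Load the GLM text encoder and tokenizer; `model_on_device()` is expected to keep each
        # model on the execution device for the duration of the `with` block.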
        with (
            context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
            context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
        ):
            context.util.signal_progress("Running GLM text encoder")
            assert isinstance(glm_text_encoder, GlmModel)
            assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)

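            # Tokenize the prompt, truncating anything beyond max_seq_len tokens.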
            text_inputs = glm_tokenizer(
                prompt,
                padding="longest",
                max_length=max_seq_len,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
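            # Re-tokenize without truncation so that any text dropped by the max_seq_len limit can
            # be detected and reported to the user.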
            text_input_ids = text_inputs.input_ids
            untruncated_ids = glm_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            assert isinstance(text_input_ids, torch.Tensor)
            assert isinstance(untruncated_ids, torch.Tensor)
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = glm_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
                context.logger.warning(
                    "The following part of your input was truncated because `max_sequence_length` is set to "
                    f"{max_seq_len} tokens: {removed_text}"
                )

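            # Prepend pad tokens so that the sequence length is a multiple of 16 (assumed to match
            # the sequence-length alignment expected by the CogView4 transformer).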
            current_length = text_input_ids.shape[1]
            pad_length = (16 - (current_length % 16)) % 16
            if pad_length > 0:
                pad_ids = torch.full(
                    (text_input_ids.shape[0], pad_length),
                    fill_value=glm_tokenizer.pad_token_id,
                    dtype=text_input_ids.dtype,
                    device=text_input_ids.device,
                )
                text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
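            # Run the text encoder and use the second-to-last hidden state as the prompt embedding.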
            prompt_embeds = glm_text_encoder(
                text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
            ).hidden_states[-2]

        assert isinstance(prompt_embeds, torch.Tensor)
        return prompt_embeds