mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 20:18:07 -05:00
42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
# Initially pulled from https://github.com/black-forest-labs/flux
|
|
|
|
from torch import Tensor, nn
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
|
|
from invokeai.backend.util.devices import TorchDevice
|
|
|
|
|
|
class HFEncoder(nn.Module):
    """Pair a HuggingFace encoder model with its tokenizer behind one module.

    The wrapped model is put into eval mode with gradients disabled when the
    wrapper is built. ``forward`` tokenizes a batch of prompts and returns a
    single field of the model output: ``pooler_output`` when ``is_clip`` is
    true, ``last_hidden_state`` otherwise.
    """

    def __init__(
        self,
        encoder: PreTrainedModel,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
        is_clip: bool,
        max_length: int,
    ):
        """Store the tokenizer and register the frozen encoder as a submodule.

        Args:
            encoder: HuggingFace model whose output ``forward`` returns.
            tokenizer: Tokenizer used to turn prompts into ``input_ids``.
            is_clip: Selects the output field ``forward`` returns
                (``pooler_output`` if true, ``last_hidden_state`` if false).
            max_length: Token count every prompt is truncated/padded to.
        """
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        # Name of the entry pulled out of the model output in forward().
        self.output_key = "pooler_output" if is_clip else "last_hidden_state"
        self.tokenizer = tokenizer
        # nn.Module.eval() and requires_grad_() both return the module itself,
        # so the chain yields `encoder`, now frozen, ready to register.
        self.hf_module = encoder.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        """Tokenize ``text`` and run it through the wrapped encoder.

        Args:
            text: Prompts to encode, one per batch entry.

        Returns:
            The ``self.output_key`` entry of the encoder output. Shape is
            presumably (batch, hidden) for ``pooler_output`` and
            (batch, max_length, hidden) for ``last_hidden_state``. TODO confirm.
        """
        tokenized = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            return_length=False,
            return_overflowing_tokens=False,
        )
        token_ids = tokenized["input_ids"].to(TorchDevice.choose_torch_device())

        # NOTE(review): no attention mask is supplied, so padding positions are
        # not masked out. This is presumably intentional, to match the upstream
        # FLUX reference implementation. Confirm before changing.
        encoded = self.hf_module(
            input_ids=token_ids,
            attention_mask=None,
            output_hidden_states=False,
        )
        return encoded[self.output_key]
|