Tidy imports in other_impls.py
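The diff below drops the "from torch import nn" alias import and rewrites the remaining nn.* call sites to the fully qualified torch.nn.* form already used for the module base classes elsewhere in the file. A minimal before/after sketch of the pattern (illustrative only; embed_dim is a placeholder value, not taken from the commit):

    # Before: torch is imported and torch.nn is also aliased as nn,
    # so nn.Linear(...) is mixed with torch.nn.Module in the same file.
    import torch
    from torch import nn

    embed_dim = 768  # hypothetical size for illustration
    q_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    # After: the alias import is removed and every reference is fully qualified.
    import torch

    embed_dim = 768
    q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True)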
@@ -7,7 +7,6 @@ import math
 from typing import Callable, Optional
 
 import torch
-from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
 
 #################################################################################################
@@ -26,7 +25,7 @@ def attention(
     return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
 
 
-class Mlp(nn.Module):
+class Mlp(torch.nn.Module):
     """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
 
     def __init__(
@@ -65,10 +64,10 @@ class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device):
         super().__init__()
         self.heads = heads
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.v_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
 
     def forward(self, x, mask=None):
         q = self.q_proj(x)
@@ -95,9 +94,9 @@ class CLIPLayer(torch.nn.Module):
         device,
     ):
         super().__init__()
-        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.layer_norm1 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
         self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
-        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.layer_norm2 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
         # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
         self.mlp = Mlp(
             embed_dim,
@@ -180,7 +179,7 @@ class CLIPTextModel_(torch.nn.Module):
             dtype,
             device,
         )
-        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.final_layer_norm = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
 
     def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
         x = self.embeddings(input_tokens)
@@ -202,7 +201,7 @@ class CLIPTextModel(torch.nn.Module):
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
-       self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+       self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
        self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype
@@ -558,9 +557,9 @@ class T5LayerNorm(torch.nn.Module):
 class T5DenseGatedActDense(torch.nn.Module):
     def __init__(self, model_dim, ff_dim, dtype, device):
         super().__init__()
-        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
-        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
-        self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
 
     def forward(self, x):
         hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
@@ -587,10 +586,10 @@ class T5Attention(torch.nn.Module):
     def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
         super().__init__()
         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.q = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.k = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.v = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.o = torch.nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
         self.num_heads = num_heads
         self.relative_attention_bias = None
         if relative_attention_bias: