Tidy imports in other_impls.py

This commit is contained in:
Ryan Dick
2024-10-23 15:24:21 +00:00
parent 9f7b5f7a85
commit 155bf13d2b

View File

@@ -7,7 +7,6 @@ import math
from typing import Callable, Optional
import torch
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
#################################################################################################
@@ -26,7 +25,7 @@ def attention(
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
class Mlp(nn.Module):
class Mlp(torch.nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
@@ -65,10 +64,10 @@ class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device):
super().__init__()
self.heads = heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None):
q = self.q_proj(x)
@@ -95,9 +94,9 @@ class CLIPLayer(torch.nn.Module):
device,
):
super().__init__()
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.layer_norm1 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
# self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
self.mlp = Mlp(
embed_dim,
@@ -180,7 +179,7 @@ class CLIPTextModel_(torch.nn.Module):
dtype,
device,
)
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.final_layer_norm = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
@@ -202,7 +201,7 @@ class CLIPTextModel(torch.nn.Module):
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device)
embed_dim = config_dict["hidden_size"]
self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
@@ -558,9 +557,9 @@ class T5LayerNorm(torch.nn.Module):
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device):
super().__init__()
self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
@@ -587,10 +586,10 @@ class T5Attention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
super().__init__()
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
self.q = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.k = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.v = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.o = torch.nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
self.num_heads = num_heads
self.relative_attention_bias = None
if relative_attention_bias: