Files
tinygrad/models/transformer.py
2022-05-14 21:28:39 -07:00

73 lines
3.1 KiB
Python

import numpy as np
from tinygrad.tensor import Tensor
class TransformerBlock:
def __init__(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu()):
self.num_heads = num_heads
self.head_size = embed_dim // num_heads
assert self.head_size * self.num_heads == embed_dim
self.prenorm, self.act = prenorm, act
self.query = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
self.key = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
self.value = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
self.out = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
self.ff1 = (Tensor.uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
self.ff2 = (Tensor.uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
def attn(self, x):
# x: (bs, time, embed_dim) -> (bs, time, embed_dim)
query, key, value = [x.linear(*y) \
.reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
for y in [self.query, self.key, self.value]]
query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, time)
value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
score = query.dot(key) * (1 / np.sqrt(self.head_size))
weights = score.softmax() # (bs, num_heads, time, time)
attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, time, num_heads, head_size)
return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
def __call__(self, x):
if self.prenorm:
x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(0.1)
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(*self.ff2).dropout(0.1)
else:
x = x + self.attn(x).dropout(0.1)
x = x.layernorm().linear(*self.ln1)
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(0.1)
x = x.layernorm().linear(*self.ln2)
return x
class Transformer:
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
self.maxlen, self.syms = maxlen, syms
self.embed = Tensor.uniform(maxlen+syms, embed_dim, requires_grad=False)
self.tbs = []
for i in range(layers):
self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim))
self.final = Tensor.uniform(embed_dim, syms)
def forward(self, x):
bs = x.shape[0]
xnp = x.cpu().data.astype(np.int32)
onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
for i in range(x.shape[1]):
onehot[range(bs), i, i] = 1
onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
onehot = onehot.reshape(bs*x.shape[1], self.maxlen+self.syms)
x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
x = x.sequential(self.tbs)
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).logsoftmax()
return x.reshape(shape=(bs, -1, x.shape[-1]))