fix broadcastdot

This commit is contained in:
George Hotz
2021-11-29 18:54:57 -05:00
parent 033b04494a
commit 58ed46963e
5 changed files with 31 additions and 28 deletions

View File

@@ -34,11 +34,11 @@ class ViTBlock:
self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
-  def attn(self, x, bs):
+  def attn(self, x):
embed_dim = self.num_heads * self.head_size
query, key, value = [x.linear(y) \
-      .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
+      .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
for y in [self.query_dense, self.key_dense, self.value_dense]]
query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
@@ -49,21 +49,12 @@ class ViTBlock:
weights = score.softmax() # (bs, num_heads, T, T)
attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
-    return attention.reshape(shape=(-1, embed_dim)).linear(self.final)
+    return attention.reshape(shape=(x.shape[0], -1, embed_dim)).linear(self.final)
def __call__(self, x):
# bs x T x embed_dim
-    bs = x.shape[0]
-    embed_dim = self.num_heads * self.head_size
-    inputs = x.reshape(shape=(-1, embed_dim))
-    # run multi head attention (bs, T, num_heads, head_size)
-    x = inputs.layernorm().linear(self.ln1)
-    x = inputs + self.attn(x, bs).dropout(0.1)
-    xin = x.layernorm().linear(self.ln2)
-    x = x + xin.linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
-    return x.reshape(shape=(bs, -1, embed_dim))
+    x = x + self.attn(x.layernorm().linear(self.ln1)).dropout(0.1)
+    x = x + x.layernorm().linear(self.ln2).linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
+    return x
class ViT:
def __init__(self, embed_dim=192):
@@ -133,8 +124,8 @@ import ast
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
lbls = ast.literal_eval(lbls.decode('utf-8'))
-url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
-#url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
+#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
+url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
# junk
from PIL import Image

View File

@@ -6,7 +6,7 @@ import numpy as np
def fetch(url):
import requests, os, hashlib, tempfile
fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
-  if os.path.isfile(fp) and os.stat(fp).st_size > 0:
+  if os.path.isfile(fp) and os.stat(fp).st_size > 0 and os.getenv("NOCACHE", None) is None:
with open(fp, "rb") as f:
dat = f.read()
else:

View File

@@ -30,15 +30,11 @@ class TransformerBlock:
self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
-  def __call__(self, x):
-    # bs x T x embed_dim
-    bs = x.shape[0]
+  def attn(self, x):
    embed_dim = self.num_heads * self.head_size
-    inputs = x.reshape(shape=(-1, embed_dim))
    # run multi head attention (bs, T, num_heads, head_size)
-    query, key, value = [inputs.linear(y) \
-      .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
+    query, key, value = [x.linear(y) \
+      .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
for y in [self.query_dense, self.key_dense, self.value_dense]]
query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
@@ -49,7 +45,17 @@ class TransformerBlock:
weights = score.softmax() # (bs, num_heads, T, T)
attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
-    x = inputs + attention.reshape(shape=(-1, embed_dim)).linear(self.final).dropout(0.1)
+    return attention.reshape(shape=(x.shape[0], -1, embed_dim)).linear(self.final)
+  def __call__(self, x):
+    # bs x T x embed_dim
+    bs = x.shape[0]
+    embed_dim = self.num_heads * self.head_size
+    #inputs = x.reshape(shape=(-1, embed_dim))
+    inputs = x
+    attention = self.attn(x)
+    x = inputs + attention.dropout(0.1)
x = layernorm(x, embed_dim).linear(self.ln1)
x = x + x.linear(self.ff1).relu().linear(self.ff2).dropout(0.1)
x = layernorm(x, embed_dim).linear(self.ln2)

View File

@@ -83,6 +83,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
+  def test_broadcastdot(self):
+    helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
def test_multidot(self):
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)

View File

@@ -203,8 +203,12 @@ class Tensor:
def pad2d(self, padding):
return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
-  def dot(self, w):
-    return self.matmul(w)
+  def matmul(self, w):
+    if len(self.shape) > 2 and len(w.shape) == 2:
+      return self.reshape(shape=(-1, self.shape[-1]))._matmul(w).reshape(shape=list(self.shape[0:-1]) + [-1])
+    else:
+      return self._matmul(w)
+  dot = matmul
def _canonicalize_reduce_axis(self, axis):
if axis is None: axis = range(len(self.shape))