mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
fix broadcastdot
@@ -34,11 +34,11 @@ class ViTBlock:
     self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
     self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
 
-  def attn(self, x, bs):
+  def attn(self, x):
     embed_dim = self.num_heads * self.head_size
 
     query, key, value = [x.linear(y) \
-      .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
+      .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
 
     query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
@@ -49,21 +49,12 @@ class ViTBlock:
     weights = score.softmax() # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
 
-    return attention.reshape(shape=(-1, embed_dim)).linear(self.final)
+    return attention.reshape(shape=(x.shape[0], -1, embed_dim)).linear(self.final)
 
   def __call__(self, x):
     # bs x T x embed_dim
-    bs = x.shape[0]
-    embed_dim = self.num_heads * self.head_size
-    inputs = x.reshape(shape=(-1, embed_dim))
-
-    # run multi head attention (bs, T, num_heads, head_size)
-    x = inputs.layernorm().linear(self.ln1)
-    x = inputs + self.attn(x, bs).dropout(0.1)
-
-    xin = x.layernorm().linear(self.ln2)
-    x = x + xin.linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
-    return x.reshape(shape=(bs, -1, embed_dim))
+    x = x + self.attn(x.layernorm().linear(self.ln1)).dropout(0.1)
+    x = x + x.layernorm().linear(self.ln2).linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
+    return x
 
 class ViT:
   def __init__(self, embed_dim=192):
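With Tensor.dot now broadcasting a 3-D activation against a 2-D weight matrix (see the Tensor.matmul hunk below), linear works directly on (bs, T, embed_dim) tensors, so ViTBlock.attn no longer needs the explicit bs argument or the flatten to (-1, embed_dim), and __call__ reduces to two pre-norm residual updates.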
@@ -133,8 +124,8 @@ import ast
 lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
 lbls = ast.literal_eval(lbls.decode('utf-8'))
 
-url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
-#url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
+#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
+url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
 
 # junk
 from PIL import Image
@@ -6,7 +6,7 @@ import numpy as np
 def fetch(url):
   import requests, os, hashlib, tempfile
   fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
-  if os.path.isfile(fp) and os.stat(fp).st_size > 0:
+  if os.path.isfile(fp) and os.stat(fp).st_size > 0 and os.getenv("NOCACHE", None) is None:
     with open(fp, "rb") as f:
       dat = f.read()
   else:
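Side change: the cached file in fetch is only reused when the NOCACHE environment variable is unset, so setting NOCACHE forces a fresh download.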
@@ -30,15 +30,11 @@ class TransformerBlock:
     self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
     self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
 
-  def __call__(self, x):
-    # bs x T x embed_dim
-    bs = x.shape[0]
+  def attn(self, x):
     embed_dim = self.num_heads * self.head_size
-    inputs = x.reshape(shape=(-1, embed_dim))
 
-    # run multi head attention (bs, T, num_heads, head_size)
-    query, key, value = [inputs.linear(y) \
-      .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
+    query, key, value = [x.linear(y) \
+      .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
 
     query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
@@ -49,7 +45,17 @@ class TransformerBlock:
     weights = score.softmax() # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
 
-    x = inputs + attention.reshape(shape=(-1, embed_dim)).linear(self.final).dropout(0.1)
+    return attention.reshape(shape=(x.shape[0], -1, embed_dim)).linear(self.final)
+
+  def __call__(self, x):
+    # bs x T x embed_dim
+    bs = x.shape[0]
+    embed_dim = self.num_heads * self.head_size
+    #inputs = x.reshape(shape=(-1, embed_dim))
+    inputs = x
+    attention = self.attn(x)
+
+    x = inputs + attention.dropout(0.1)
     x = layernorm(x, embed_dim).linear(self.ln1)
     x = x + x.linear(self.ff1).relu().linear(self.ff2).dropout(0.1)
     x = layernorm(x, embed_dim).linear(self.ln2)
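TransformerBlock gets the same refactor: the attention math moves into an attn method that keeps the batch dimension, while __call__ preserves the original post-attention layernorm ordering and relu feed-forward, so only the batch handling changes.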
@@ -83,6 +83,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
+  def test_broadcastdot(self):
+    helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
   def test_multidot(self):
     helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
     helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
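test_broadcastdot covers the newly supported case of a batched 3-D tensor multiplied by a single 2-D matrix, which the matched-batch shapes in test_multidot do not exercise.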
@@ -203,8 +203,12 @@ class Tensor:
   def pad2d(self, padding):
     return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
 
-  def dot(self, w):
-    return self.matmul(w)
+  def matmul(self, w):
+    if len(self.shape) > 2 and len(w.shape) == 2:
+      return self.reshape(shape=(-1, self.shape[-1]))._matmul(w).reshape(shape=list(self.shape[0:-1]) + [-1])
+    else:
+      return self._matmul(w)
+  dot = matmul
 
   def _canonicalize_reduce_axis(self, axis):
     if axis is None: axis = range(len(self.shape))
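For intuition, the reshape trick in the new matmul can be checked against numpy's broadcasting; the snippet below is an illustrative sketch (not part of the commit) using the same shapes as test_broadcastdot.

import numpy as np

x = np.random.randn(10, 45, 65).astype(np.float32)  # batched activations
w = np.random.randn(65, 45).astype(np.float32)       # shared 2-D weight

# what Tensor.matmul now does when len(self.shape) > 2 and w is 2-D:
# fold the leading dims, do a plain 2-D matmul, then unfold them again
flat = x.reshape(-1, x.shape[-1])              # (450, 65)
out = (flat @ w).reshape(*x.shape[:-1], -1)    # (10, 45, 45)

assert np.allclose(out, x @ w, atol=1e-5)      # numpy broadcasts w over the batch dims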