From 58ed46963efab46bbe17439b74f82a388fa593c2 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Mon, 29 Nov 2021 18:54:57 -0500
Subject: [PATCH] fix broadcastdot

---
 examples/vit.py       | 25 ++++++++-----------------
 extra/utils.py        |  2 +-
 models/transformer.py | 22 ++++++++++++++--------
 test/test_ops.py      |  2 ++
 tinygrad/tensor.py    |  8 ++++++--
 5 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/examples/vit.py b/examples/vit.py
index ab54dcf7c3..c619645466 100644
--- a/examples/vit.py
+++ b/examples/vit.py
@@ -34,11 +34,11 @@ class ViTBlock:
     self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
     self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
 
-  def attn(self, x, bs):
+  def attn(self, x):
     embed_dim = self.num_heads * self.head_size
 
     query, key, value = [x.linear(y) \
-      .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
+      .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
 
     query = query.transpose(order=(0,2,1,3))  # (bs, num_heads, T, head_size)
@@ -49,21 +49,12 @@ class ViTBlock:
     weights = score.softmax()                                   # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))   # (bs, T, num_heads, head_size)
 
-    return attention.reshape(shape=(-1, embed_dim)).linear(self.final)
+    return attention.reshape(shape=(x.shape[0], -1, embed_dim)).linear(self.final)
 
   def __call__(self, x):
-    # bs x T x embed_dim
-    bs = x.shape[0]
-    embed_dim = self.num_heads * self.head_size
-    inputs = x.reshape(shape=(-1, embed_dim))
-
-    # run multi head attention (bs, T, num_heads, head_size)
-    x = inputs.layernorm().linear(self.ln1)
-    x = inputs + self.attn(x, bs).dropout(0.1)
-
-    xin = x.layernorm().linear(self.ln2)
-    x = x + xin.linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
-    return x.reshape(shape=(bs, -1, embed_dim))
+    x = x + self.attn(x.layernorm().linear(self.ln1)).dropout(0.1)
+    x = x + x.layernorm().linear(self.ln2).linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
+    return x
 
 class ViT:
   def __init__(self, embed_dim=192):
@@ -133,8 +124,8 @@ import ast
 lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
 lbls = ast.literal_eval(lbls.decode('utf-8'))
 
-url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
-#url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
+#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
+url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
 
 # junk
 from PIL import Image
diff --git a/extra/utils.py b/extra/utils.py
index 01ea32cf70..0161f4e899 100644
--- a/extra/utils.py
+++ b/extra/utils.py
@@ -6,7 +6,7 @@ import numpy as np
 def fetch(url):
   import requests, os, hashlib, tempfile
   fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
-  if os.path.isfile(fp) and os.stat(fp).st_size > 0:
+  if os.path.isfile(fp) and os.stat(fp).st_size > 0 and os.getenv("NOCACHE", None) is None:
     with open(fp, "rb") as f:
       dat = f.read()
   else:
diff --git a/models/transformer.py b/models/transformer.py
index f56a77d9e8..693321c2dd 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -30,15 +30,11 @@ class TransformerBlock:
     self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
     self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
 
-  def __call__(self, x):
-    # bs x T x embed_dim
-    bs = x.shape[0]
+  def attn(self, x):
     embed_dim = self.num_heads * self.head_size
-    inputs = x.reshape(shape=(-1, embed_dim))
 
-    # run multi head attention (bs, T, num_heads, head_size)
-    query, key, value = [inputs.linear(y) \
-      .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
+    query, key, value = [x.linear(y) \
+      .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
 
     query = query.transpose(order=(0,2,1,3))  # (bs, num_heads, T, head_size)
@@ -49,7 +45,17 @@ class TransformerBlock:
     weights = score.softmax()                                   # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))   # (bs, T, num_heads, head_size)
 
-    x = inputs + attention.reshape(shape=(-1, embed_dim)).linear(self.final).dropout(0.1)
+    return attention.reshape(shape=(x.shape[0], -1, embed_dim)).linear(self.final)
+
+  def __call__(self, x):
+    # bs x T x embed_dim
+    bs = x.shape[0]
+    embed_dim = self.num_heads * self.head_size
+    #inputs = x.reshape(shape=(-1, embed_dim))
+    inputs = x
+    attention = self.attn(x)
+
+    x = inputs + attention.dropout(0.1)
     x = layernorm(x, embed_dim).linear(self.ln1)
     x = x + x.linear(self.ff1).relu().linear(self.ff2).dropout(0.1)
     x = layernorm(x, embed_dim).linear(self.ln2)
diff --git a/test/test_ops.py b/test/test_ops.py
index 83c53e8da3..70f770481d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -83,6 +83,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
+  def test_broadcastdot(self):
+    helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
   def test_multidot(self):
     helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
     helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 530fde8041..5d4f8e4785 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -203,8 +203,12 @@ class Tensor:
   def pad2d(self, padding):
     return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
 
-  def dot(self, w):
-    return self.matmul(w)
+  def matmul(self, w):
+    if len(self.shape) > 2 and len(w.shape) == 2:
+      return self.reshape(shape=(-1, self.shape[-1]))._matmul(w).reshape(shape=list(self.shape[0:-1]) + [-1])
+    else:
+      return self._matmul(w)
+  dot = matmul
 
   def _canonicalize_reduce_axis(self, axis):
     if axis is None: axis = range(len(self.shape))
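
Note (not part of the patch): a minimal sketch of the broadcasting rule this change adds to Tensor.matmul, mirrored in NumPy as a reference. The shapes follow the new test_broadcastdot case; the variable names are illustrative only.

import numpy as np

x = np.random.randn(10, 45, 65).astype(np.float32)  # batched input, e.g. (bs, T, in_features)
w = np.random.randn(65, 45).astype(np.float32)       # plain 2D weight (in_features, out_features)

# The patched matmul handles ">2D tensor @ 2D tensor" by flattening the leading
# dimensions, doing an ordinary 2D matmul, then restoring the batch shape:
flat = x.reshape(-1, x.shape[-1])                     # (10*45, 65)
out = (flat @ w).reshape(list(x.shape[:-1]) + [-1])   # (10, 45, 45)

# NumPy already broadcasts this case, so it serves as the ground truth here.
assert np.allclose(out, x @ w, atol=1e-4)

This is also why ViTBlock and TransformerBlock no longer flatten activations to (bs*T, embed_dim) before their linear and attention calls: a 3D activation times a 2D weight now works directly through dot/matmul.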
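
Note (not part of the patch): the extra/utils.py change makes the download cache skippable. A rough usage sketch, assuming the repo root is on PYTHONPATH; the URL is the one from examples/vit.py.

import os
from extra.utils import fetch

# With the patched guard, the cached copy in the temp dir is only reused when
# NOCACHE is unset, so setting it (to any value) forces a fresh download.
os.environ["NOCACHE"] = "1"
dat = fetch("https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg")
print(len(dat))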