From 739f327d2d5988f04bbb7b73d12178c13f73821a Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sun, 20 Aug 2023 08:12:16 -0700
Subject: [PATCH] Shorter (#1582)

* deleting lines

* remove insert dims

* if statement is never hit

* bug fixes
---
 .gitignore                     |  3 +++
 test/models/test_real_world.py |  2 +-
 test/test_dtype.py             |  2 +-
 test/test_symbolic_jit.py      |  5 ++++-
 test/test_symbolic_ops.py      |  5 ++++-
 tinygrad/codegen/optimizer.py  |  8 ++++----
 tinygrad/tensor.py             | 34 +++++++++-------------------------
 7 files changed, 26 insertions(+), 33 deletions(-)

diff --git a/.gitignore b/.gitignore
index b834f91a73..d7c42b02c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -36,3 +36,6 @@ package.json
 package-lock.json
 temp
 *.csv
+.coverage
+coverage.xml
+htmlcov
diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index 4da6a20a82..d28bac16da 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -47,7 +47,7 @@ class TestRealWorld(unittest.TestCase):
     derandomize_model(model)
     @TinyJit
     def test(t, t2): return model(t, 801, t2).realize()
-    helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 14.5, 967)
+    helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)
 
   @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
   def test_llama(self):
diff --git a/test/test_dtype.py b/test/test_dtype.py
index 82d6e3ddbd..50d2d8d4f9 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -92,7 +92,7 @@ class TestHalfDtype(unittest.TestCase):
   def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
   def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16)
 
-@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"], "float64 is not supported by some backends")
+@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"] or OSX, "float64 is not supported by some backends")
 class TestDoubleDtype(unittest.TestCase):
   def test_float64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.double), np.double, [1,2,3,4])
   def test_casts_to_float64(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int32, dtypes.uint8], target_dtype=dtypes.float64)
diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index 38bdd525f5..94f844d024 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -149,4 +149,7 @@ class TestSymbolicJit(unittest.TestCase):
     a = Tensor.rand(3, 7).reshape(3, vi)
     bad = Tensor.rand(4, 7).reshape(4, vi)
     with self.assertRaises(AssertionError):
-      add(a, bad)
\ No newline at end of file
+      add(a, bad)
+
+if __name__ == '__main__':
+  unittest.main()
\ No newline at end of file
diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py
index 4203262d50..ca3ef71de6 100644
--- a/test/test_symbolic_ops.py
+++ b/test/test_symbolic_ops.py
@@ -110,4 +110,7 @@ class TestSymbolicOps(unittest.TestCase):
       b = Tensor.rand(3, j)
      symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy()
       expected = f(a, b).cpu().numpy()
-      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
\ No newline at end of file
+      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+
+if __name__ == '__main__':
+  unittest.main()
\ No newline at end of file
diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py
index b2715fab00..218ed256ce 100644
--- a/tinygrad/codegen/optimizer.py
+++ b/tinygrad/codegen/optimizer.py
@@ -116,8 +116,8 @@ def hand_coded_optimizations(k:Linearizer):
     buf1 = k.bufs.index(k.reduceop.src[0].src[0].src[1])
     buf0_strides = k.sts[buf0].real_strides()
     buf1_strides = k.sts[buf1].real_strides()
-    axis_buf0 = [(i,k.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and k.full_shape[i]%16 == 0]
-    axis_buf1 = [(i,k.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and k.full_shape[i]%16 == 0]
+    axis_buf0 = [(i,k.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and k.full_shape[i]%16 == 0 and i < k.first_reduce]
+    axis_buf1 = [(i,k.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and k.full_shape[i]%16 == 0 and i < k.first_reduce]
     if len(axis_buf0) and len(axis_buf1) and k.full_shape[k.first_reduce]%8 == 0 and (k.shape_len-k.first_reduce) == 1:
       if DEBUG >= 3: print("HIP TENSOR CORES", axis_buf0, axis_buf1)
       k.use_tensor_cores = getenv("TC", 1) == 1 # TC=2 will do the shape ops without the WMMA
@@ -175,8 +175,8 @@ def hand_coded_optimizations(k:Linearizer):
     buf1 = k.bufs.index(k.reduceop.src[0].src[1])
     buf0_strides = k.sts[buf0].real_strides()
     buf1_strides = k.sts[buf1].real_strides()
-    axis_buf0 = [(i,k.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and k.full_shape[i]%8 == 0]
-    axis_buf1 = [(i,k.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and k.full_shape[i]%8 == 0]
+    axis_buf0 = [(i,k.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and k.full_shape[i]%8 == 0 and i < k.first_reduce]
+    axis_buf1 = [(i,k.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and k.full_shape[i]%8 == 0 and i < k.first_reduce]
     if len(axis_buf0) and len(axis_buf1) and k.full_shape[k.first_reduce]%8 == 0 and (k.shape_len-k.first_reduce) == 1:
       if DEBUG >= 3: print("METAL TENSOR CORES", axis_buf0, axis_buf1)
       k.use_tensor_cores = getenv("TC", 1) == 1 # TC=2 will do the shape ops without the WMMA
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index b26c354606..e015f558d0 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -222,9 +222,6 @@ class Tensor:
     self.grad = Tensor(1, device=self.device, requires_grad=False)
 
     for t0 in reversed(self.deepwalk()):
-      if not t0.requires_grad:
-        del t0._ctx # TODO: does it help to delete this here ever?
-        continue
       assert (t0.grad is not None)
       grads = t0._ctx.backward(t0.grad.lazydata)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
@@ -251,12 +248,6 @@ class Tensor:
 
   # ***** movement hlops *****
 
-  # NOTE: using slice is discouraged and things should migrate to pad and shrink
-  def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
-    arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
-    padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
-    return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
-
   # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
   # - A slice i:j returns the elements with indices in [i, j)
   # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
@@ -357,6 +348,12 @@ class Tensor:
     ret = ret.permute(order=order)
     return ret
 
+  # NOTE: using slice is discouraged and things should migrate to pad and shrink
+  def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
+    arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
+    padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
+    return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
+
   def gather(self: Tensor, idx: Tensor, dim: int):
     assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
     assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
@@ -459,7 +456,7 @@ class Tensor:
 
   # ***** processing ops *****
 
-  def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1, _insert_dims=tuple()) -> Tensor:
+  def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
     assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
     s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
     assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -467,10 +464,7 @@ class Tensor:
     if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
       o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
       e_ = [ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)]  # expands such that we don't need padding
-      xup = self.reshape(*prefix, *([1]*len(_insert_dims)), *flatten((1,i) for i in i_)).expand(*prefix, *_insert_dims, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *_insert_dims, *[e*i for e,i in zip(e_, i_)])
-      # NOTE: _insert_dims is required because reduces can't be merged (yet)
-      prefix += _insert_dims
-      slc_prefix += [(0,x) for x in _insert_dims]
+      xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
       # slide by dilation
       xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
       xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
@@ -483,11 +477,7 @@ class Tensor:
     # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
     o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
     xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
-    xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
-    if len(_insert_dims):
-      xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
-      prefix += _insert_dims
-      slc_prefix += [(0,x) for x in _insert_dims]
+    xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
     xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
     return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
 
@@ -518,12 +508,6 @@ class Tensor:
     rcout, oyx = cout//groups, x.shape[2:-len(HW)]
     x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
 
-    # expand the channels with the pool
-    # TODO: this reduces the number of kernels, but it's slower!
-    #x = self.pad2d(padding_)._pool((H,W), stride, dilation, _insert_dims=(cout//groups,)) # (bs, groups*cin, rcout, oy, ox, H, W)
-    #rcout, oy, ox = x.shape[2:5]
-    #x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)
-
     # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
     ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
     return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
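
Reviewer note, not part of the patch: `Tensor.slice` is only moved within tinygrad/tensor.py (from above `__getitem__` to just before `gather`); its behavior is unchanged. It composes `pad` and `shrink`: bounds that fall outside the tensor are satisfied by padding with `value` first, then shrinking back to the requested window. A minimal sketch of those semantics, assuming the usual tinygrad `Tensor` constructor and `.numpy()`:

    from tinygrad.tensor import Tensor

    t = Tensor([[1, 2], [3, 4]])      # shape (2, 2)
    # rows -1..3 overshoot both ends by one, so a zero row is padded on each side;
    # cols 0..2 lie in range and pass through unchanged
    out = t.slice(((-1, 3), (0, 2)))  # shape (4, 2)
    print(out.numpy())
    # [[0. 0.]
    #  [1. 2.]
    #  [3. 4.]
    #  [0. 0.]]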