From 20059dc55b606a0dcff34ee45e83cb23b265bd3d Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sun, 24 Sep 2023 21:09:03 +0800
Subject: [PATCH] Make ShapeTracker Immutable (#1909)

* ugh

* ops test pass

* fix shapetracker tests

* sym shapetracker

* shapetracker is a tuple of views now

* from_shape

* fix has variable shape

* key isn't needed

* post init assert
---
 docs/abstractions.py               |  14 ++--
 test/test_custom_function.py       |   2 +-
 test/test_symbolic_shapetracker.py |   9 ++-
 test/unit/test_flopcounter.py      |   6 +-
 test/unit/test_shapetracker.py     | 101 ++++++++++++++++++-----------
 tinygrad/codegen/kernel.py         |  11 ++--
 tinygrad/codegen/linearizer.py     |   2 +-
 tinygrad/codegen/optimizer.py      |  11 ++--
 tinygrad/lazy.py                   |  37 +++++------
 tinygrad/ops.py                    |  17 +++--
 tinygrad/shape/shapetracker.py     |  61 ++++++++---------
 11 files changed, 143 insertions(+), 128 deletions(-)

diff --git a/docs/abstractions.py b/docs/abstractions.py
index 5d096a52f2..f3d52d9c35 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -326,22 +326,22 @@ void E_1(float* data0) {
 from tinygrad.shape.shapetracker import ShapeTracker
 
 # create a virtual (10, 10) Tensor. this is just a shape, there's no actual tensor
-a = ShapeTracker((10, 10))
+a = ShapeTracker.from_shape((10, 10))
 
 # you'll see it has one view. the (10, 1) are the strides
 print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
 
 # we can permute it, and the strides change
-a.permute((1,0))
+a = a.permute((1,0))
 print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
 
 # we can then reshape it, and the strides change again
 # note how the permute stays applied
-a.reshape((5,2,5,2))
+a = a.reshape((5,2,5,2))
 print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
 
 # now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
-a.reshape((100,))
+a = a.reshape((100,))
 print(a) # ShapeTracker(shape=(100,), views=[
          #   View((5, 2, 5, 2), (2, 1, 20, 10), 0),
          #   View((100,), (1,), 0)])
@@ -352,7 +352,7 @@ idx, _ = a.expr_idxs()
 print(idx.render()) # (((idx0%10)*10)+(idx0//10))
 
 # of course, if we reshape it back, the indexes get simple again
-a.reshape((10,10))
+a = a.reshape((10,10))
 idx, _ = a.expr_idxs()
 print(idx.render()) # ((idx1*10)+idx0)
 
@@ -362,11 +362,11 @@ print(a) # ShapeTracker(shape=(10, 10), views=[
          #   View((10, 10), (10, 1), 0)])
 
 # ...until we simplify it!
-a.simplify()
+a = a.simplify()
 print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
 
 # and now we permute it back
-a.permute((1,0))
+a = a.permute((1,0))
 print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
 
 # and it's even contiguous
diff --git a/test/test_custom_function.py b/test/test_custom_function.py
index d41aeb65c2..445a995e76 100644
--- a/test/test_custom_function.py
+++ b/test/test_custom_function.py
@@ -43,7 +43,7 @@ class ATan2(Function):
     assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
     self.a, self.b = a, b
     ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
-    return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
+    return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
     return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \
diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py
index 368120a7c7..c3a818080a 100644
--- a/test/test_symbolic_shapetracker.py
+++ b/test/test_symbolic_shapetracker.py
@@ -6,18 +6,18 @@ from tinygrad.tensor import Tensor
 class TestSymbolic(unittest.TestCase):
   def test_symbolic_st(self):
     x = Variable("x", 1, 100)
-    st = ShapeTracker((x, 3))
+    st = ShapeTracker.from_shape((x, 3))
     assert st.shape == (x, 3)
     assert st.real_strides() == (3, 1)
 
   def test_expr_idxs(self):
     x = Variable("x", 1, 100)
-    st = ShapeTracker((x, 3))
+    st = ShapeTracker.from_shape((x, 3))
     idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]
     e1, e2 = st.expr_idxs(idxs)
     assert e1.render() == "((x*3)+y)"
     assert e2.render() == "1"
-    st.permute((1, 0))
+    st = st.permute((1, 0))
     e1, e2 = st.expr_idxs(idxs)
     assert e1.render() == "((y*3)+x)"
     assert e2.render() == "1"
@@ -142,8 +142,7 @@ class TestSymbolicShapeExpr(unittest.TestCase):
     idx = (gidx0, lidx1, Variable.num(1))
     shape = (i+1, 8, 4)
     strides = (1, (i*4)+4, i+1)
-    view = View.create(shape, strides)
-    st = ShapeTracker(shape, [view])
+    st = ShapeTracker((View.create(shape, strides), ))
     idx, valid = st.expr_idxs(idx)
     assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
diff --git a/test/unit/test_flopcounter.py b/test/unit/test_flopcounter.py
index 69507d52b4..bb9ae400cf 100644
--- a/test/unit/test_flopcounter.py
+++ b/test/unit/test_flopcounter.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python
 import unittest
 from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer
-from tinygrad.shape.view import View
+from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.helpers import dtypes
 
 class TestFlopCounter(unittest.TestCase):
   def setUp(self):
-    self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, (View.create((4,)),)))
-    self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, (View.create((4,)),)))
+    self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
+    self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
 
   def test_flops_add(self):
     op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py
index 82e52da12e..3786725da2 100644
--- a/test/unit/test_shapetracker.py
+++ b/test/unit/test_shapetracker.py
@@ -14,42 +14,51 @@ def shapetracker_getitem(st, val):
 class CheckingShapeTracker:
   def __init__(self, shape):
-    self.st = ShapeTracker(shape)
+    self.st = ShapeTracker.from_shape(shape)
     self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
 
   @property
   def shape(self):
     return self.t.shape
 
-  def simplify(self): self.st.simplify()
+  def simplify(self):
+    self.st = self.st.simplify()
+    return self
 
   def reshape(self, new_shape):
-    self.st.reshape(new_shape)
+    self.st = self.st.reshape(new_shape)
     self.t = self.t.reshape(new_shape)
+    return self
 
   def permute(self, axis):
-    self.st.permute(axis)
+    self.st = self.st.permute(axis)
     self.t = np.transpose(self.t, axis)
+    return self
 
   def expand(self, new_shape):
-    self.st.expand(new_shape)
+    self.st = self.st.expand(new_shape)
     self.t = np.broadcast_to(self.t, new_shape)
+    return self
 
   def flip(self, axis):
-    self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
+    self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
     self.t = np.flip(self.t, axis)
+    return self
 
   def shrink(self, arg):
-    self.st.shrink(arg)
+    self.st = self.st.shrink(arg)
     self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
+    return self
 
   def pad(self, arg):
-    self.st.pad(arg)
+    self.st = self.st.pad(arg)
     self.t = np.pad(self.t, arg, constant_values=-1)
+    return self
 
   def stride(self, arg):
-    self.st.stride(arg)
+    self.st = self.st.stride(arg)
     self.t = self.t[tuple([slice(None, None, x) for x in arg])]
+    return self
 
   def __getitem__(self, val):
     return self.t.flatten()[val]
@@ -70,7 +79,7 @@ class CheckingShapeTracker:
 class TestRealIssues(unittest.TestCase):
   def test_reshape_doesnt_multiview(self):
-    self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
+    self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),))
     self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
     assert len(self.st.views) == 1
 
@@ -78,27 +87,27 @@ class TestRealDoesntSimplify(unittest.TestCase):
   def tearDown(self):
     st = self.st.real_strides()
     print(st)
-    self.st.simplify()
+    self.st = self.st.simplify()
    assert len(self.st.views) != 1
     assert None in st
 
   def test_1(self):
-    self.st = ShapeTracker((8, 6, 11), views=[
+    self.st = ShapeTracker((
       View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
-      View.create((8, 6, 11), (66, 11, 1), 0, None)])
+      View.create((8, 6, 11), (66, 11, 1), 0, None)))
     assert self.st.real_strides() == (33, None, 1)
 
   def test_2(self):
-    self.st = ShapeTracker((4, 4, 3, 3), views=[
+    self.st = ShapeTracker((
       View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
-      View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
+      View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
     assert self.st.real_strides() == (None, 18, -3, -1)
 
 class TestRealStrides(unittest.TestCase):
   def test_1(self):
-    self.st = ShapeTracker((16, 32, 4), views=[
+    self.st = ShapeTracker((
       View.create((2048,), (1,), 0, ((0, 512),)),
-      View.create((16, 32, 4), (128, 4, 1), 0, None)])
+      View.create((16, 32, 4), (128, 4, 1), 0, None)))
     st = self.st.real_strides()
     print(self.st, st)
     assert st == (None, 4, 1)
@@ -106,27 +115,27 @@ class TestRealStrides(unittest.TestCase):
 class TestRealSimplifies(unittest.TestCase):
   def tearDown(self):
     st = self.st.real_strides()
-    self.st.simplify()
+    self.st = self.st.simplify()
     assert len(self.st.views) == 1
     print(self.st.views[-1].strides, st)
     assert self.st.views[-1].strides == st
 
   def test_1(self):
-    self.st = ShapeTracker((1, 3, 2, 11, 26, 1, 1, 3), views=[
+    self.st = ShapeTracker((
       View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
-      View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)])
+      View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)))
 
   def test_2(self):
-    self.st = ShapeTracker((8, 1, 6, 10, 28, 3, 2, 1), views=[
+    self.st = ShapeTracker((
       View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
-      View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
+      View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)))
 
 class TestIndexExpressions2d(unittest.TestCase):
   def setUp(self):
     shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
     offsets = [0, 1, 15, 28, 10000]
-    self.sts = [ShapeTracker(base_shape, [View.create(base_shape, offset=offset)]) for base_shape in shapes for offset in offsets]
+    self.sts = [ShapeTracker((View.create(base_shape, offset=offset),)) for base_shape in shapes for offset in offsets]
     self.offset = [Variable.num(offset) for base_shape in shapes for offset in offsets]
     self.shapes = [shape for shape in shapes for offset in offsets]
     self.node_exprs = []
@@ -171,36 +180,52 @@ class TestIndexExpressions2d(unittest.TestCase):
     self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
 
   def test_permute(self):
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.permute((1, 0))
+      st = st.permute((1, 0))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
   def test_reshape(self):
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.reshape((base_shape[0], 1, base_shape[1]))
+      st = st.reshape((base_shape[0], 1, base_shape[1]))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
   def test_reshape_expand(self):
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.reshape((base_shape[0], 1, base_shape[1]))
-      st.expand((base_shape[0], base_shape[1], base_shape[1]))
+      st = st.reshape((base_shape[0], 1, base_shape[1]))
+      st = st.expand((base_shape[0], base_shape[1], base_shape[1]))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
   def test_permute_reshape_1(self): # This tests multiple views
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.permute((1, 0))
-      st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
+      st = st.permute((1, 0))
+      st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
   def test_permute_reshape_2(self):
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.permute((1, 0))
-      st.reshape((1, base_shape[0]//5, base_shape[1]*5))
+      st = st.permute((1, 0))
+      st = st.reshape((1, base_shape[0]//5, base_shape[1]*5))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
 class TestSimplifyingShapeTracker(unittest.TestCase):
   def setUp(self):
@@ -211,14 +236,14 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
 
   # multiview simplify
   def test_expand_contract_simple(self):
-    self.st.expand((10, 10))
-    self.st.reshape((100,))
+    self.st = self.st.expand((10, 10))
+    self.st = self.st.reshape((100,))
     print(self.st.views)
     assert(len(self.st.views) == 2)
-    self.st.reshape((10, 10))
+    self.st = self.st.reshape((10, 10))
     print(self.st.views)
-    self.st.simplify()
+    self.st = self.st.simplify()
     print(self.st.views)
     assert(len(self.st.views) == 1)
@@ -231,7 +256,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
     self.st.reshape((2, 5, 2, 5))
     print(self.st.views)
-    self.st.simplify()
+    self.st = self.st.simplify()
     print(self.st.views)
     assert(len(self.st.views) == 1)
@@ -243,7 +268,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
     assert(len(self.st.views) == 2)
     self.st.reshape((5, 20))
-    self.st.simplify()
+    self.st = self.st.simplify()
     print(self.st.views)
     assert(len(self.st.views) == 2)
@@ -387,7 +412,7 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
     self.st.reshape((1, 4))
     self.st.shrink(((0, 1), (1, 3)))
     print(self.st.st)
-    self.st.simplify()
+    self.st = self.st.simplify()
     print(self.st.st)
   def test_case_2(self):
     self.st.stride( (1, 1, -2) )
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 2e883212a6..ba63907f1a 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -4,7 +4,7 @@ from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOp
 from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
-from tinygrad.shape.view import strides_for_shape, View
+from tinygrad.shape.view import strides_for_shape
 
 class LocalBuffer(NamedTuple):
   name: str
@@ -43,11 +43,10 @@ class Kernel:
     self.reduceop = reduceops[0] if reduceops else None
 
     # create new shapetrackers inside this kernel, we will permute them
-    self.bufs = [MemBuffer(0, self.info.dtype, (View.create(self.info.shape),))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
-    self.sts: List[ShapeTracker] = [ShapeTracker(x.views[-1].shape, views=list(x.views)) for x in self.bufs]
-    for st in self.sts: st.simplify()
+    self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
+    self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
 
-    self.mem_estimate: int = sum(x.dtype.itemsize*x.views[-1].size() for x in self.bufs)
+    self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in self.bufs)
 
     # get earlybufs, before the one reduce op
     self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else []
@@ -67,7 +66,7 @@ class Kernel:
 
   def has_variable_shape(self) -> bool:
     for b in self.bufs:
-      if not all_int(b.views[-1].shape): return True
+      if not all_int(b.st.views[-1].shape): return True
     return False
 
   def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 9e3f1b0a37..42792edb21 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -260,7 +260,7 @@ class Linearizer(OptimizedKernel):
     # add a local buffer for multistage reduce. # TODO: use local alias
     if self.group_for_reduce:
       # TODO: the strides of this can be controlled
-      self.sts.append(ShapeTracker(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
+      self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
       self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
       self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py
index 2a04b64cd8..499ce25b66 100644
--- a/tinygrad/codegen/optimizer.py
+++ b/tinygrad/codegen/optimizer.py
@@ -24,9 +24,12 @@ class OptimizedKernel(Kernel):
 
   # apply reshape and permute to all shapetrackers
   def reshape_and_permute(self, new_shape_fxn, axis):
+    new_sts = []
     for st in self.sts:
-      if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
-      if axis is not None: st.permute(tuple(axis))
+      if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
+      if axis is not None: st = st.permute(tuple(axis))
+      new_sts.append(st)
+    self.sts = new_sts
 
   # drops the final dimension
   def upcast(self):
@@ -76,7 +79,7 @@ class OptimizedKernel(Kernel):
       else: rets[j].append((shapes[j][i], strides[j][i]))
 
     # do the reshapes
-    for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))
+    for i,x in enumerate(rets): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
 
   # ******************** GPU simplifiers ********************
   def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
@@ -120,7 +123,7 @@ class OptimizedKernel(Kernel):
           stride[j] = bst
           bst *= shp[j]
 
-    self.sts.append(ShapeTracker(tuple(shp), [View.create(tuple(shp), tuple(stride))]))
+    self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
     self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
     if DEBUG >= 4: print("aliasing buffer", self.sts[i])
     self.local_alias[i] = self.bufs[-1]
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 24d9687070..d305fde3cc 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -7,7 +7,7 @@ import numpy as np
 from tinygrad.graph import log_op
 from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
 from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer
-from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
+from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer, buf_is_kernel_arg
@@ -95,11 +95,10 @@ def _replace_loadops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
   realized_bufs = dedup([x.realized for x in op.buffers if buf_is_kernel_arg(x)])
   for x in op.buffers:
     assert x.realized, "buffer isn't realized"
-    x.st.simplify()
     if isinstance(x.realized, RawConst):
-      replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, tuple(x.st.views)))
+      replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify()))
     elif x.realized in realized_bufs:
-      replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, tuple(x.st.views)))
+      replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify()))
     else:
       raise NotImplementedError(f"not handled {x}")
   return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), realized_bufs
@@ -149,8 +148,8 @@ class LazyBuffer:
   def __repr__(self): return f""
   @property
   def key(self):
-    if self.realized: return (self.dtype, self.realized.key, self.st.key, self.var_vals_key)
-    return (self.dtype, self.op.op, self.st.key, self.var_vals_key)
+    if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key)
+    return (self.dtype, self.op.op, self.st, self.var_vals_key)
 
   def _device_extra_args(self) -> Dict[str, str]:
     return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
@@ -195,7 +194,7 @@ class LazyBuffer:
   @staticmethod
   def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
-    return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
+    return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
 
   # create a constant with the shape and dtype of self
   def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -204,11 +203,11 @@
   def contiguous(self:LazyBuffer) -> LazyBuffer:
     if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
-    return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
 
   @staticmethod
   def fromCPU(x: np.ndarray) -> LazyBuffer:
-    return LazyBuffer("CPU", ShapeTracker(x.shape, [View.create(x.shape)]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
+    return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
 
   def toCPU(self) -> np.ndarray:
     assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
@@ -242,7 +241,7 @@ class LazyBuffer:
     # remove the buffers from any (childless) BinaryOps that feed into this
     srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
 
-    return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
+    return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
 
   def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
@@ -257,7 +256,7 @@ class LazyBuffer:
   def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
-    return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
 
   def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
@@ -282,18 +281,18 @@ class LazyBuffer:
       self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
      self.op.src[0].var_vals = self.var_vals
       return self.op.src[0].reshape(arg)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.reshape(arg), MovementOps.RESHAPE, arg)
 
   def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     if all(b == 0 and e == 0 for b,e in arg): return self
     if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.pad(arg), MovementOps.PAD, arg)
 
   def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.expand(arg), MovementOps.EXPAND, arg)
 
   def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if arg == tuple(range(len(self.shape))): return self
@@ -315,19 +314,19 @@ class LazyBuffer:
       if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
         if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
           self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
-          return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
+          return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
+    return self.shuffle_and_prune_movement_ops(self.st.permute(arg), MovementOps.PERMUTE, arg)
 
   def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
     if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.shrink(arg), MovementOps.SHRINK, arg)
 
   def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
-    local_st = ShapeTracker(self.shape).stride(arg)
+    local_st = ShapeTracker.from_shape(self.shape).stride(arg)
     if self.shape == local_st.shape and local_st.contiguous: return self
     if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.stride(arg), MovementOps.STRIDE, arg)
 
   @property
   def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index a892a9732f..93d46fe881 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -3,7 +3,7 @@ import time, importlib, inspect, functools, pathlib
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
-from tinygrad.shape.view import View
+from tinygrad.shape.shapetracker import ShapeTracker
 from dataclasses import dataclass
 
 if TYPE_CHECKING:
   from tinygrad.lazy import LazyBuffer
@@ -25,13 +25,13 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
 class MemBuffer:
   idx: int
   dtype: DType
-  views: Tuple[View, ...]
+  st: ShapeTracker
 
 @dataclass(frozen=True)
 class ConstBuffer:
   val: Any
   dtype: DType
-  views: Tuple[View, ...]
+  st: ShapeTracker
 
 class LazyOp:
   __slots__ = "op", "src", "arg", "buffers", "__weakref__"
@@ -101,8 +101,8 @@ Device = _Device()
 
 # **************** for Interpreted Buffers ****************
 
-def apply_shapetracker(fxn_for_op, ret, views):
-  for v in views:
+def apply_shapetracker(fxn_for_op, ret, st):
+  for v in st.views:
     real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
     real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
     # first, we apply the offset
@@ -133,7 +133,7 @@ class Interpreted:
   def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
     if ast.op == LoadOps.BUFFER and LoadOps.BUFFER not in self.fxn_for_op:
       assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
-      return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.views))
+      return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.st))
     if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
       ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
     created_context = context is None
@@ -162,7 +162,7 @@ class FlopCounter:
     self.flops, ret = 0, self.flops
     return ret
 shape_fxn_for_op: Dict[Op, Callable] = {
-  LoadOps.BUFFER: lambda arg: (arg.views[-1].shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.views[-1].shape, arg.dtype, 0),
+  LoadOps.BUFFER: lambda arg: (arg.st.shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0),
   UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
   **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
   **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
@@ -241,8 +241,7 @@ class Compiled:
     for i,a in enumerate(inputs):
       # TODO: if this is contiguous it's fine
       if a == output.realized:
-        views = [x.arg.views for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1]
-        if any(len(v) > 1 or not v[0].contiguous for v in views):
+        if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1):
           output.realized = None
           break
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 6305a8268f..d58bf95ccb 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -1,7 +1,8 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
 import functools
-from typing import Tuple, Union, List, Optional, cast
+from dataclasses import dataclass
+from typing import Tuple, List, Optional, cast
 from tinygrad.helpers import prod, DEBUG
 from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
 from tinygrad.shape.view import View
@@ -48,7 +49,7 @@ def expr_idxs(view:View, idxs) -> Node:
 @functools.lru_cache(maxsize=None)
 def merge_views(vm2:View, vm1:View) -> Optional[View]:
   if vm2.mask: return None # this isn't supported yet
-  mst = ShapeTracker(vm1.shape, [vm2, vm1])
+  mst = ShapeTracker((vm2, vm1))
   strides = mst.real_strides()
   if None in strides: return None
   return View.create(vm1.shape, cast(Tuple[sint, ...], strides), mst.real_offset(), vm1.mask)
@@ -63,12 +64,13 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
     acc *= d
   return Variable.sum(ret)
 
+@dataclass(frozen=True)
 class ShapeTracker:
-  __slots__ = "views"
-  def __init__(self, shape:Union[ShapeTracker, Tuple[sint, ...]], views:Optional[List[View]]=None):
-    self.views: List[View] = views if views is not None else [*shape.views] if isinstance(shape, ShapeTracker) else [View.create(shape)]
-  def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
-  def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
+  views: Tuple[View, ...]
+  def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
+
+  @staticmethod
+  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
 
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
@@ -76,9 +78,6 @@ class ShapeTracker:
   @property
   def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
 
-  @property
-  def key(self) -> Tuple[View, ...]: return tuple(self.views)
-
   # this is the real size (ish)
   def size(self): return self.views[-1].size()
 
@@ -113,13 +112,13 @@ class ShapeTracker:
       idx = expr_node(v, idx)
     return idx, valid
 
-  def simplify(self):
+  def simplify(self) -> ShapeTracker:
     if len(self.views) >= 2:
       new_view = merge_views(self.views[-2], self.views[-1])
       if new_view:
         if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
-        self.views = self.views[:-2] + [new_view]
-        self.simplify()
+        return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
+    return self
 
   def expr_idxs(self, idxs=None):
     if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
@@ -137,38 +136,30 @@ class ShapeTracker:
 
   # *** under this line are the movement ops ***
 
-  def pad(self, arg: Tuple[Tuple[int, int], ...]):
-    self.views[-1] = self.views[-1].pad(arg)
-    return self
+  def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
 
-  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]):
-    self.views[-1] = self.views[-1].shrink(arg)
-    return self
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
 
-  def expand(self, new_shape: Tuple[sint, ...]):
-    self.views[-1] = self.views[-1].expand(new_shape)
-    return self
+  def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
 
-  def permute(self, axis: Tuple[int, ...]):
-    self.views[-1] = self.views[-1].permute(axis)
-    return self
+  def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
 
-  def stride(self, mul: Tuple[int, ...]):
-    self.views[-1] = self.views[-1].stride(mul)
-    return self
+  def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
 
-  def reshape(self, new_shape: Tuple[sint, ...]):
+  def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
     new_view = self.views[-1].reshape(new_shape)
     if new_view is None:
       extra_view = View.create(new_shape)
       # last chance to merge. TODO: move into View
       if (merged_view := merge_views(self.views[-1], extra_view)) is not None:
-        self.views[-1] = merged_view
-      else:
-        self.views.append(extra_view)
-    else:
-      self.views[-1] = new_view
-    return self
+        return ShapeTracker(self.views[0:-1] + (merged_view,))
+      return ShapeTracker(self.views + (extra_view, ))
+    return ShapeTracker(self.views[0:-1] + (new_view,))
 
 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 # TODO: if we remove movementops from lazy.py we can delete this
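
Usage note: a minimal sketch of the immutable API this patch introduces, assuming tinygrad at this commit. ShapeTracker is now a frozen dataclass over a tuple of Views, so every movement op returns a new tracker instead of mutating in place, and trackers compare and hash by value (which is why the key property could be dropped). The asserts below are illustrative, not part of the patch:

  from tinygrad.shape.shapetracker import ShapeTracker

  a = ShapeTracker.from_shape((10, 10))   # replaces the old ShapeTracker((10, 10)) constructor
  b = a.permute((1, 0))                   # movement ops return a new ShapeTracker...
  assert a.views != b.views               # ...and leave the original untouched
  assert a == ShapeTracker.from_shape((10, 10))              # frozen dataclass: value equality
  assert hash(a) == hash(ShapeTracker.from_shape((10, 10)))  # ...and hashability, so no .key property
  c = b.reshape((100,)).reshape((10, 10)).simplify()         # simplify() also returns a new tracker
  assert len(c.views) == 1                # the two views merge back into one, as in docs/abstractions.py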
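In the same spirit, LazyOp args now carry the whole tracker (the st field on MemBuffer and ConstBuffer) rather than a bare tuple of Views, and simplification happens once when the op graph is built (x.st.simplify() in _replace_loadops) rather than by mutating a buffer's tracker. A small sketch mirroring the updated test_flopcounter setup; the final assert is illustrative:

  from tinygrad.helpers import dtypes
  from tinygrad.ops import LazyOp, LoadOps, MemBuffer
  from tinygrad.shape.shapetracker import ShapeTracker

  buf = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
  # consumers read shapes through the tracker now: arg.st.shape, not arg.views[-1].shape
  assert buf.arg.st.shape == (4,)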