diff --git a/examples/gpt2.py b/examples/gpt2.py index a65587f464..4783cd7e72 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -112,11 +112,9 @@ class Transformer: def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None): _bsz, seqlen = tokens.shape if seqlen == 1 and start_pos > 0 and getenv("JIT"): - start_pos_var = Variable("start_pos", 1, MAX_CONTEXT) + start_pos_var = Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos) pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var + seqlen))) - pos.lazydata.var_vals[start_pos_var] = start_pos for k,v in self.kv_caches.items(): - v.lazydata.var_vals[start_pos_var] = start_pos self.kv_caches[k] = v.reshape(v.shape[0], start_pos_var, v.shape[2], v.shape[3]) logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, **self.kv_caches) return logit_or_softmax diff --git a/examples/llama.py b/examples/llama.py index 55d31888f5..527156b7d0 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -18,7 +18,7 @@ from tinygrad.ops import GlobalCounters from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE from tinygrad.shape.symbolic import Variable, sym_infer -JIT = getenv("JIT", 0 if CI else Device.DEFAULT in JIT_SUPPORTED_DEVICE) +JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE)) # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): @@ -82,7 +82,7 @@ class Attention: keys, values = xk, xv else: assert cache_k is not None and cache_v is not None, "no cache" - assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals)})" + assert start_pos == (cache_k.shape[1].val if isinstance(cache_k.shape[1], Variable) else cache_k.shape[1]) == (cache_v.shape[1].val if isinstance(cache_v.shape[1], Variable) else cache_v.shape[1]), f"cache has wrong shape, {start_pos=}, {cache_k.shape[1]=}, {cache_v.shape[1]=}" assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?" keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1) @@ -117,12 +117,9 @@ class TransformerBlock: bsz, seqlen, _ = x.shape if JIT and mask is None: assert cache_k is not None and cache_v is not None, "no cache" - pos = Variable("pos", 1, 1024) + pos = Variable("pos", 1, 1024).bind(start_pos) cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3]) cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3]) - # need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache - cache_k.lazydata.var_vals[pos] = start_pos - cache_v.lazydata.var_vals[pos] = start_pos output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx) h = x + output @@ -150,14 +147,12 @@ class Transformer: def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None): _bsz, seqlen = tokens.shape if seqlen == 1 and start_pos > 0 and JIT: - pos = Variable("pos", 1, 1024) + pos = Variable("pos", 1, 1024).bind(start_pos) # get only the part of freqs_cis that we are using. 
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4]))) - freqs_cis.lazydata.var_vals[pos] = start_pos h = self.tok_embeddings_jitted(tokens) for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)): - h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos: start_pos}) - # TODO: move the kv cache into Attention, pre-allocate the cache and instead of cat, update the cache in-place + h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos.unbind()[0]: start_pos}) self.kv_caches[i] = (cache_k, cache_v) return self.postprocess_jitted(h, temperature) else: diff --git a/test/test_custom_function.py b/test/test_custom_function.py index a97be8ba5e..e263e22d08 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -43,7 +43,7 @@ class ATan2(Function): assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch" self.a, self.b = a, b ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device]) - return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {}) + return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype)) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b)) return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index 5f5ec425f2..f2f0fd87e0 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -11,8 +11,8 @@ class TestSymbolicJit(unittest.TestCase): def test_plus1(self): def f(a): return (a+1).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy() expected = f(a).numpy() @@ -20,12 +20,12 @@ class TestSymbolicJit(unittest.TestCase): assert len(jf.jit_cache) == 1 def test_reshape_inside_plus1(self): - vi = Variable("i", 1, 10) def f(a, jit=False, jit_ctx=None): - if jit: a = a.reshape(3, vi) + if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1])) return (a+1).realize() jf = TinyJit(f) for i in range(1, 5): + vi = Variable("i", 1, 10) a = Tensor.rand(3, i) symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy() expected = f(a).numpy() @@ -35,8 +35,8 @@ class TestSymbolicJit(unittest.TestCase): def test_add(self): def f(a, b): return (a+b).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) b = Tensor.rand(3, i) symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() @@ -47,8 +47,8 @@ class TestSymbolicJit(unittest.TestCase): def test_matmul(self): def f(a, b): return (a@b).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) b = Tensor.rand(i, 5) symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() @@ -63,7 +63,7 @@ class TestSymbolicJit(unittest.TestCase): return s jf = TinyJit(f) for i in range(1, 5): - vi = Variable("i", 1, 10) + 
vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) b = Tensor.rand(i, 5) symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() @@ -74,8 +74,8 @@ class TestSymbolicJit(unittest.TestCase): def test_attention(self): def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) q = Tensor.rand(2, 1, 4, 8) k = Tensor.rand(2, i, 4, 8) v = Tensor.rand(2, i, 4, 8) @@ -87,8 +87,8 @@ class TestSymbolicJit(unittest.TestCase): def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(i, 3) b = Tensor.rand(2, 3) symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy() @@ -99,8 +99,8 @@ class TestSymbolicJit(unittest.TestCase): def test_cat_dim1(self): def f(a, b): return a.cat(b, dim=1).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) b = Tensor.rand(3, 2) symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy() @@ -111,10 +111,10 @@ class TestSymbolicJit(unittest.TestCase): def test_cat_dim0_two_vars(self): def f(a, b): return a.cat(b, dim=0).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) - vj = Variable("j", 1, 10) for i in range(1, 5): for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) a = Tensor.rand(i, 3) b = Tensor.rand(j, 3) symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy() @@ -125,10 +125,10 @@ class TestSymbolicJit(unittest.TestCase): def test_cat_dim1_two_vars(self): def f(a, b): return a.cat(b, dim=1).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) - vj = Variable("j", 1, 10) for i in range(1, 5): for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) a = Tensor.rand(3, i) b = Tensor.rand(3, j) symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy() @@ -139,10 +139,10 @@ class TestSymbolicJit(unittest.TestCase): def test_two_vars_plus1(self): def f(a, b): return (a@b+1).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) - vj = Variable("j", 1, 10) for i in range(1, 5): for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) a = Tensor.rand(i, 3) b = Tensor.rand(3, j) symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() @@ -153,13 +153,14 @@ class TestSymbolicJit(unittest.TestCase): def test_jit_symbolic_shape_mismatch(self): @TinyJit def add(a, b): return (a+b).realize() - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i).reshape(3, vi) b = Tensor.rand(3, i).reshape(3, vi) c = add(a, b) - a = Tensor.rand(3, 7).reshape(3, vi) - bad = Tensor.rand(4, 7).reshape(4, vi) + vi2 = Variable("i", 1, 10).bind(7) + a = Tensor.rand(3, 7).reshape(3, vi2) + bad = Tensor.rand(4, 7).reshape(4, vi2) with self.assertRaises(AssertionError): add(a, bad) @@ -167,11 +168,10 @@ class TestSymbolicJit(unittest.TestCase): # shrink is a movement, so we pair it with a simple function to test the JIT interaction def f(a): return (a+1).realize() jf = TinyJit(f) - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(7, 11) symbolic = a.shrink(((3,5),(vi,vi+2))) - symbolic.lazydata.var_vals[vi] = i symbolic = jf(symbolic).numpy() expected = 
f(a.shrink(((3,5),(i,i+2)))).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 331fed289c..c446899e80 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -10,8 +10,8 @@ import numpy as np class TestSymbolicOps(unittest.TestCase): def test_plus1(self): def f(a): return (a+1).realize() - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy() expected = f(a).numpy() @@ -19,8 +19,8 @@ class TestSymbolicOps(unittest.TestCase): def test_add(self): def f(a, b): return (a+b).realize() - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) b = Tensor.rand(3, i) symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() @@ -29,26 +29,18 @@ class TestSymbolicOps(unittest.TestCase): def test_matmul(self): def f(a, b): return (a@b).realize() - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) b = Tensor.rand(i, 5) symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy() expected = f(a, b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_matmul_same_var_different_val(self): - def f(a, b): return (a@b).realize() - vi = Variable("i", 1, 10) - a = Tensor.rand(3, 4) - b = Tensor.rand(7, 5) - with self.assertRaises(AssertionError): - f(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - def test_attention(self, dropout_p=0.0): def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize() - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) q = Tensor.rand(2, 1, 4, 8) k = Tensor.rand(2, i, 4, 8) v = Tensor.rand(2, i, 4, 8) @@ -65,8 +57,8 @@ class TestSymbolicOps(unittest.TestCase): def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(i, 3) b = Tensor.rand(2, 3) symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).numpy() @@ -75,8 +67,8 @@ class TestSymbolicOps(unittest.TestCase): def test_cat_dim1(self): def f(a, b): return a.cat(b, dim=1).realize() - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(3, i) b = Tensor.rand(3, 2) symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).numpy() @@ -85,10 +77,10 @@ class TestSymbolicOps(unittest.TestCase): def test_cat_dim0_two_vars(self): def f(a, b): return a.cat(b, dim=0).realize() - vi = Variable("i", 1, 10) - vj = Variable("j", 1, 10) for i in range(1, 5): for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) a = Tensor.rand(i, 3) b = Tensor.rand(j, 3) symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy() @@ -97,10 +89,10 @@ class TestSymbolicOps(unittest.TestCase): def test_cat_dim1_two_vars(self): def f(a, b): return a.cat(b, dim=1).realize() - vi = Variable("i", 1, 10) - vj = Variable("j", 1, 10) for i in range(1, 5): for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) a = Tensor.rand(3, i) b = Tensor.rand(3, j) symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy() @@ -109,10 +101,10 @@ class TestSymbolicOps(unittest.TestCase): def test_two_vars_plus1(self): def f(a, b): return 
(a@b+1).realize() - vi = Variable("i", 1, 10) - vj = Variable("j", 1, 10) for i in range(1, 5): for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) a = Tensor.rand(i, 3) b = Tensor.rand(3, j) symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() @@ -120,11 +112,10 @@ class TestSymbolicOps(unittest.TestCase): np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_shrink(self): - vi = Variable("i", 1, 10) for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(7, 11) symbolic = a.shrink(((3,5),(vi,vi+2))) - symbolic.lazydata.var_vals[vi] = i symbolic = symbolic.numpy() expected = a.shrink(((3,5),(i,i+2))).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index c3a818080a..1e0b4feec7 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -22,107 +22,130 @@ class TestSymbolic(unittest.TestCase): assert e1.render() == "((y*3)+x)" assert e2.render() == "1" - def test_cat_strides(self): - i = Variable("i", 1, 5) - j = Variable("j", 1, 5) - k = Variable("k", 1, 5) + def test_cat_dim0_strides(self): + i = Variable("i", 1, 5).bind(3) + j = Variable("j", 1, 5).bind(3) + k = Variable("k", 1, 5).bind(3) t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0) st = t.lazydata.st assert st.shape == (i+j+k, 4) assert st.real_strides() == (4, 1) - t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1) - st = t.lazydata.st - assert st.shape == (3, i+j+k) - assert st.real_strides() == (i+j+k, 1) t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0) st = t.lazydata.st assert st.shape == (2*i+3, 3) assert st.real_strides() == (3, 1) + def test_cat_dim1_strides(self): + i = Variable("i", 1, 5).bind(4) + j = Variable("j", 1, 5).bind(4) + k = Variable("k", 1, 5).bind(4) + t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1) + st = t.lazydata.st + assert st.shape == (3, i+j+k) + assert st.real_strides() == (i+j+k, 1) + +class TestSymbolicVarVals(unittest.TestCase): + def test_var_vals_empty(self): + assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {} + + def test_var_vals_shape(self): + x = Variable("x", 1, 100).bind(3) + assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3} + + def test_var_vals_offset(self): + x = Variable("x", 1, 100).bind(3) + st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3))) + assert st.real_offset() == x * 3 + assert st.var_vals == {Variable("x", 1, 100): 3} + + def test_var_vals_mask(self): + x = Variable("x", 1, 100).bind(3) + view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4))) + st = ShapeTracker(views=(view,)) + assert st.var_vals == {Variable("x", 1, 100): 3} + + def test_var_vals_complex(self): + x = Variable("x", 1, 100).bind(3) + y = Variable("y", 1, 100).bind(4) + z = Variable("z", 1, 100).bind(5) + st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3))) + assert st.real_offset() == y * z + assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5} + + def test_shrink_reshape(self): + x = Variable("x", 1, 100).bind(3) + st = 
ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5))) + st = st.reshape((3*4*3,)) + assert st.var_vals == {Variable("x", 1, 100): 3} + +class TestShapeTrackerUnbind(unittest.TestCase): + def test_view_unbind(self): + v = Variable("v", 1, 100) + bv = Variable("v", 1, 100).bind(3) + assert View.create(shape=(bv, 4)).unbind() == View.create(shape=(v, 4)) + + def test_reshape_unbind(self): + v = Variable("v", 1, 100) + bv = Variable("v", 1, 100).bind(3) + t = Tensor.rand(3, 4).reshape(bv, 4) + assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(v, 4)),)) + + def test_shrink_unbind(self): + v = Variable("v", 1, 100) + bv = Variable("v", 1, 100).bind(2) + t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4))) + assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(1, 4), offset=4*v),)) + class TestSymbolicReshape(unittest.TestCase): def test_reshape_into_symbols_simple(self): - vi = Variable("i", 1, 5) for i in range(1, 6): + vi = Variable("i", 1, 5).bind(i) t = Tensor.rand(i, 4).reshape(vi, 4) assert t.shape == (vi, 4) - assert t.lazydata.var_vals[vi] == i t = Tensor.rand(i, 6).reshape(vi, 2, 3) assert t.shape == (vi, 2, 3) - assert t.lazydata.var_vals[vi] == i def test_reshape_symbols_reshape_ints(self): - vi = Variable("i", 1, 5) for i in range(1, 6): + vi = Variable("i", 1, 5).bind(i) t = Tensor.rand(i, 4).reshape(vi, 4) assert t.shape == (vi, 4) - assert t.lazydata.var_vals == {vi: i} t = t.reshape(i, 4) assert t.shape == (i, 4) - assert t.lazydata.var_vals == {vi: i} - - def test_reshape_reuse_var_same_value_ok(self): - vi = Variable("i", 1, 5) - for i in range(1, 6): - a = Tensor.rand(i, 4).reshape(vi, 4) - b = Tensor.rand(i, 3).reshape(vi, 3) - assert a.lazydata.var_vals[vi] == i - assert b.lazydata.var_vals[vi] == i - - def test_reshape_reuse_var_different_value_ok(self): - vi = Variable("i", 1, 10) - for i in range(1, 6): - a = Tensor.rand(i, 4).reshape(vi, 2) - b = Tensor.rand(i, 3).reshape(vi, 3) - # a and b have different values of vi - assert a.lazydata.var_vals[vi] == 2 * i - assert b.lazydata.var_vals[vi] == i def test_reshape_into_symbols_bad_shape(self): - vi = Variable("i", 1, 10) - vj = Variable("j", 1, 10) - with self.assertRaises(AssertionError): - t = Tensor.rand(3, 4).reshape(vi, vj) # reshape into two variables - with self.assertRaises(AssertionError): - t = Tensor.rand(4, 4).reshape(vi, vi) # reshape into same variable in 2 dimensions - with self.assertRaises(AssertionError): - t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4) # conflicted implied variable values + vi = Variable("i", 1, 10).bind(4) with self.assertRaises(AssertionError): t = Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape - with self.assertRaises(AssertionError): - t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4) - with self.assertRaises(AssertionError): - t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4) with self.assertRaises(AssertionError): t = Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node def test_two_symbol_reshape(self): - vi = Variable("i", 1, 5) - vj = Variable("j", 1, 5) for i in range(1, 6): for j in range(1, 6): - t1 = Tensor.rand(i, 5).reshape(vi, 5) - t2 = Tensor.rand(5, j).reshape(5, vj) - t = t1@t2 + vi = Variable("i", 1, 5).bind(i) + vj = Variable("j", 1, 5).bind(j) + t = Tensor.rand(i, j).reshape(vi, vj) assert t.shape == (vi, vj) - t = t.reshape(1, vi*vj) - assert t.shape == (1, vi*vj) + # NOTE: this is currently not allowed + 
# t = t.reshape(1, vi*vj) + # assert t.shape == (1, vi*vj) t = t.reshape(vj, vi) assert t.shape == (vj, vi) class TestSymbolicExpand(unittest.TestCase): def test_expand_into_symbols(self): + # TODO: enforce expand only into bound variables vi = Variable("i", 1, 5) vj = Variable("j", 1, 5) a = Tensor([[1], [2], [3]]).expand((3, vi)) assert a.shape == (3, vi) - assert a.lazydata.var_vals == {} a = a.reshape(3, vi, 1).expand((3, vi, vj)) assert a.shape == (3, vi, vj) - assert a.lazydata.var_vals == {} def test_plus_expands_constant(self): - vi = Variable("i", 1, 5) for i in range(1, 6): + vi = Variable("i", 1, 5).bind(i) a = Tensor.rand(3, i).reshape(3, vi) a = a + 1 assert a.shape == (3, vi) @@ -146,23 +169,5 @@ class TestSymbolicShapeExpr(unittest.TestCase): idx, valid = st.expr_idxs(idx) assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)" -class TestShapeTrackerVarVals(unittest.TestCase): - def test_reshape_reshape_updates_var_vals(self): - vi = Variable("i", 1, 5) - vj = Variable("j", 1, 5) - t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj) - assert t.lazydata.var_vals == {vi: 4, vj: 3} - - def test_lazy_check_var_vals(self): - vi = Variable("i", 1, 5) - a = Tensor.rand(3, 4).reshape(3, vi) - b = Tensor.rand(5, 6).reshape(vi, 6) - assert a.lazydata.var_vals == {vi: 4} - assert b.lazydata.var_vals == {vi: 5} - c = a@b - # shapetracker works with symbolic shape and doesn't check the underlying variable values - assert c.shape == (3, 6) - assert c.lazydata.var_vals == {vi: 4} - if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index d6670bbb84..cd62699e6a 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -11,7 +11,7 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename from tinygrad.codegen.optimizer import OptimizedKernel from tinygrad.codegen.kernel import LocalBuffer -from tinygrad.lazy import var_vals_from_ast +from tinygrad.lazy import vars_from_ast from tinygrad.features.image import to_image_idx # bottom ones are asm only @@ -164,7 +164,7 @@ class Linearizer(OptimizedKernel): if isinstance(buf, MemBuffer): self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype)) # add var vals - for var in sorted(var_vals_from_ast(self.ast), key=lambda k: k.key): + for var in sorted(vars_from_ast(self.ast), key=lambda k: k.key): assert var.expr is not None self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32)) # define local buffers diff --git a/tinygrad/features/kopt.py b/tinygrad/features/kopt.py index 33786b2ae1..e6eb4f3ebd 100644 --- a/tinygrad/features/kopt.py +++ b/tinygrad/features/kopt.py @@ -2,7 +2,7 @@ from typing import Callable import time from tinygrad.codegen.linearizer import Linearizer from tinygrad.helpers import DEBUG, prod, getenv -from tinygrad.lazy import var_vals_from_ast +from tinygrad.lazy import vars_from_ast def get_divisors(n, min_div = 1, max_div = 512): if min_div > 1: yield 1 @@ -66,7 +66,7 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, buf # don't optimize variable shapes choice = "BASELINE" else: - var_vals = {k:k.min for k in var_vals_from_ast(k.ast)} + var_vals = {k:k.min for k in vars_from_ast(k.ast)} # get
baseline def get_baseline(): k = create_k() diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 34b8d03cf6..d2f6ebfa6d 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -31,12 +31,12 @@ class TinyJit: assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT" if self.cnt >= 2: try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"] - except KeyError: var_vals = merge_dicts([arg.lazydata.var_vals for arg in args if arg.__class__ is Tensor]) + except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor]) if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key)) for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items(): assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}" # NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly - if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].views == expected_st.views, f"ShapeTracker.views mismatch in JIT, {input_rawbuffers[input_name][1].views} != {expected_st.views}" + if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}" self.jit_cache[j][1][i] = input_rawbuffers[input_name][0] for j in self.updatable_entries.keys(): for k in self.jit_cache[j][2].keys(): @@ -55,7 +55,7 @@ class TinyJit: for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]] for i,a in enumerate(cache[1]): if a in [v[0] for v in input_rawbuffers.values()]: - self.input_replace[(j_,i)] = [(k, v[1], v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0] + self.input_replace[(j_,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0] self.updatable_entries[j_].append(i) for i in range(len(cache[2])): self.updatable_entries[j_].append(len(cache[1])+i) #if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index e14428176d..c073f13ae9 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -4,7 +4,7 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp from weakref import ref, WeakSet, WeakValueDictionary import numpy as np -from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, partition, dedup, merge_dicts +from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps from tinygrad.shape.shapetracker import ShapeTracker, get_contraction from tinygrad.shape.symbolic import Variable, sint @@ -64,7 +64,7 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]: replacements:Dict[LazyBuffer, LazyOp] = {} base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()]) for x in op.buffers: - st = x.st.simplify() + st = x.st.simplify().unbind() if x.base in base_bufs: replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st)) elif not x.realized and x.base.op.op == LoadOps.CONST: @@ -79,29 +79,28 @@ def get_single_root(root:LazyBuffer) -> 
LazyBuffer: return get_single_root(cast( def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x) -def var_vals_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.var_vals() for x in ast.get_lazyops() if x.op in BufferOps], [])) +def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], [])) lazycache: WeakValueDictionary = WeakValueDictionary() -def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], base:Optional[LazyBuffer]=None): +def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None): # fromcpu aren't cached - if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, var_vals, base=base) + if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base) # wop is the deduping key. i feel this used to compare more deeply - wop = (device, dtype, optype, ref(op), tuple(sorted(var_vals.keys())), ref(base) if base else None) + wop = (device, dtype, optype, ref(op), ref(base) if base else None) if wop in lazycache: for x in op.buffers: x.children.add(lazycache[wop]) return lazycache[wop] - lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, var_vals, base=base) + lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base) return ret UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP} class LazyBuffer: __deletable__ = ('op',) - def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None): + def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None): self.st: ShapeTracker = st - self._var_vals: Dict[Variable, int] = var_vals self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype self._realized: Optional[RawBuffer] = src self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? 
or can we just use realized @@ -117,9 +116,6 @@ class LazyBuffer: if base: base.views.add(self) else: assert st.contiguous, "unbased LazyBuffers must be contiguous" - @property - def var_vals_key(self): return tuple(sorted(self.var_vals.keys())) - @property def base(self): return self._base if self._base is not None else self @@ -137,18 +133,12 @@ class LazyBuffer: def dtype(self, val): assert self._base is None, "no setting dtype of based LazyBuffers" self._dtype = val - @property - def var_vals(self): return self.base._var_vals - @var_vals.setter - def var_vals(self, val): - assert self._base is None, "no setting var_vals of based LazyBuffers" - self._var_vals = val def __repr__(self): return f"" @property def key(self): - if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key) - return (self.dtype, self.op.op, self.st, self.var_vals_key) + if self.realized: return (self.dtype, self.realized.key, self.st) + return (self.dtype, self.op.op, self.st) def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {} @@ -171,22 +161,21 @@ class LazyBuffer: if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape) elif self.optype is ReduceOps: op = _ast_reduceops(op) - # realize the past and exec the AST + # schedule the past ret = [] for x in op.buffers: ret += x.schedule(seen) - # TODO: this belongs in the schedule in some way - self.var_vals = dict(sorted(merge_dicts([self.var_vals] + [buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key)) + var_vals = dict(sorted(merge_dicts([self.st.var_vals] + [buf.st.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key)) # run the ast and log the op op, base_bufs = _replace_bufferops(op) - return ret + [ScheduleItem(op, self, tuple(base_bufs))] + return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})] # *** creation/special ops *** @staticmethod - def loadop(op, shape, dtype, device, arg=None, src=None, val_vals=None) -> LazyBuffer: - return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, val_vals if val_vals else {}) + def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer: + return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype) # create a constant with the shape and dtype of self def const(self, val:Union[float, int]) -> LazyBuffer: @@ -203,13 +192,13 @@ class LazyBuffer: if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const(): # this will turn into nothing, it's based and a copy # TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops - return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals, base=self.base) + return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, base=self.base) # real contiguous, this will turn into a UnaryOps.NOOP - return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals) + return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self) @staticmethod def fromCPU(x: np.ndarray) -> LazyBuffer: - return 
LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x)) + return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x)) # *** elementwise ops *** @@ -238,14 +227,15 @@ class LazyBuffer: # remove the buffers from any (childless) BinaryOps that feed into this srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore - return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals) + return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype) # *** reduce ops *** def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: if self.shape == tuple(new_shape): return self srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) - return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals) + unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape) + return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype) def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach. @@ -264,20 +254,10 @@ class LazyBuffer: root = get_movementroot(self) if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape): return root.reshape(st.shape) - return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals, base=self.base) + return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base) def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: if self.shape == arg: return self - new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int)) - if new_nodes and all(isinstance(s, int) for s in self.shape): - # reshape from all int shape into shape with a variable, update the variable value - assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape" - new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints) - # TODO: is it okay to set these var_vals on the base? - if new_var not in self.var_vals: - assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]" - self.var_vals[new_var] = new_val - else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}" if not self.realized and self.op.op == MovementOps.RESHAPE: assert isinstance(self.op.src[0], LazyBuffer) self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why?? diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 91a54e2863..5c17c7a3b3 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -45,7 +45,7 @@ class ScheduleItem: ast: LazyOp out: LazyBuffer inputs: Tuple[LazyBuffer, ...] 
- # TODO: move var_vals here + var_vals: Dict[Variable, int] class LazyOp: __slots__ = "op", "src", "arg", "buffers", "__weakref__" @@ -214,7 +214,7 @@ class ASTRunner: class Compiled: def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor): self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, runtime, synchronize, batch_exec - self.method_cache: Dict[Any, ASTRunner] = {} + self.method_cache: Dict[LazyOp, ASTRunner] = {} def to_program(self, k): k.linearize() @@ -243,6 +243,11 @@ class Compiled: # all the rawbuffers rawbuffers = [output.realized] + [x.realized for x in inputs] + # extract real vars used in ast + from tinygrad.lazy import vars_from_ast + ast_vars = vars_from_ast(ast) + assert all(v.val is None for v in ast_vars), f"ast contains bound Variable {ast_vars}" + # compilation time def get_program(): from tinygrad.codegen.linearizer import Linearizer @@ -262,8 +267,5 @@ class Compiled: if prg.name == getenv("PRINT_PRG", ''): print(prg.prg) - # extract real var vals - from tinygrad.lazy import var_vals_from_ast - real_var_vals = var_vals_from_ast(ast) - prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in real_var_vals}) + prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in ast_vars}) return output.realized diff --git a/tinygrad/realize.py b/tinygrad/realize.py index a28b332bd0..0e331e635c 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -24,7 +24,7 @@ def run_schedule(schedule:List[ScheduleItem]): for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}" LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs) else: - si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.out.var_vals, **si.out._device_extra_args()) + si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args()) del si.out.op for v in si.out.views: del v.op assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}" diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index dd2e108c53..d20a50d795 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -1,8 +1,8 @@ # ShapeTracker allows movement operations to a buffer that don't require a copy to be made. 
from __future__ import annotations -import functools +import functools, operator from dataclasses import dataclass -from typing import Tuple, List, Optional, cast +from typing import Tuple, List, Optional, Dict, cast from tinygrad.ops import MovementOps from tinygrad.helpers import prod, DEBUG, dedup from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint @@ -82,13 +82,18 @@ class ShapeTracker: # this is the real size (ish) def size(self): return self.views[-1].size() - def var_vals(self) -> List[Variable]: - ret = [] - for v in self.views: - for x in v.shape+v.strides+(v.offset,): - if isinstance(x, Node): - ret += x.vars() - return dedup(ret) + def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], [])) + + @property + def var_vals(self) -> Dict[Variable, int]: + ret:Dict[Variable, int] = {} + for v in self.vars(): + var, val = v.unbind() + assert var not in ret or ret[var] == val, f"{var} has conflicted values {val} and {ret[var]}" + ret[var] = val + return ret + + def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views)) def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]: to_apply:List[Tuple[MovementOps, Tuple]] = [] diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 34d6d78577..424214ed9a 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -33,6 +33,7 @@ class Node: yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]])) # substitute Variables with the values in var_vals def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__) + def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None @functools.cached_property def key(self) -> str: return self.render(ctx="DEBUG") @@ -149,6 +150,14 @@ class Variable(Node): def __init__(self, expr:Optional[str], nmin:int, nmax:int): self.expr, self.min, self.max = expr, nmin, nmax + self.val:Optional[int] = None + def bind(self, val): + assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}" + self.val = val + return self + def unbind(self) -> Tuple[Variable, int]: + assert self.val is not None, f"cannot unbind {self}" + return Variable(self.expr, self.min, self.max), self.val def vars(self): return [self] def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self @@ -157,6 +166,9 @@ class NumNode(Node): assert isinstance(num, int), f"{num} is not an int" self.b:int = num self.min, self.max = num, num + def bind(self, val): + assert self.b == val, f"cannot bind {val} to {self}" + return self def __eq__(self, other): return self.b == other def __hash__(self): return self.hash # needed with __eq__ override def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self @@ -324,7 +336,7 @@ sint = Union[Node, int] VariableOrNum = Union[Variable, NumNode] render_python: Dict[Type, Callable] = { - Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"), + Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"), 
NumNode: lambda self,ops,ctx: f"{self.b}", MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})", diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 77b30aa23d..960db7ca75 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -1,9 +1,9 @@ from __future__ import annotations -import functools +import functools, operator from dataclasses import dataclass -from typing import Tuple, List, Optional -from tinygrad.helpers import prod, all_int -from tinygrad.shape.symbolic import Node, NumNode, is_sym_int, sint +from typing import Tuple, List, Optional, Dict, cast +from tinygrad.helpers import prod, all_int, dedup +from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, is_sym_int, sint @functools.lru_cache(maxsize=None) def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]: @@ -33,6 +33,18 @@ class View: @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0]) + def vars(self) -> List[Variable]: + flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple() + return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], [])) + + def unbind(self) -> View: + unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None} + new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape]) + new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides]) + new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars) + new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None + return View.create(new_shape, new_strides, new_offset, new_mask) + # MovementOps live here now def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View: @@ -88,8 +100,13 @@ class View: if self.shape == new_shape: return self assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}" - # only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done - assert prod(self.shape) == prod(new_shape) if all_int(self.shape + new_shape) else True, f"can't reshape {self.shape=} -> {new_shape=}" + # check for the same size + if all_int(self.shape): + if all_int(new_shape): + assert prod(self.shape) == prod(new_shape), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}" + else: + assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" + assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}" # after the asserts, it's okay to check contiguous if self.contiguous: return View.create(new_shape)
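
Note on the pattern introduced above: this diff replaces the per-LazyBuffer var_vals dictionaries with values carried on the symbols themselves. Variable.bind(val) attaches a concrete runtime value to a symbol, Variable.unbind() recovers the pure symbol plus that value, the new ShapeTracker.var_vals property collects the bindings from its views (asserting they agree), and ScheduleItem.var_vals hands them to the runtime. Below is a minimal, self-contained sketch of that flow; Var and merge_var_vals are illustrative stand-ins written for this note, not the tinygrad API.

from typing import Dict, List, Optional, Tuple

class Var:
  def __init__(self, expr:str, nmin:int, nmax:int):
    self.expr, self.min, self.max = expr, nmin, nmax
    self.val: Optional[int] = None  # None means unbound, mirroring Variable in symbolic.py
  def bind(self, val:int) -> "Var":
    # same contract as Variable.bind in the diff: bind once, and only within [min, max]
    assert self.val is None and self.min <= val <= self.max, f"cannot bind {val}"
    self.val = val
    return self
  def unbind(self) -> Tuple["Var", int]:
    # same contract as Variable.unbind: return a fresh unbound symbol plus the value
    assert self.val is not None, "cannot unbind an unbound Var"
    return Var(self.expr, self.min, self.max), self.val

def merge_var_vals(syms:List[Var]) -> Dict[str, int]:
  # mirrors the new ShapeTracker.var_vals property: collect bindings, reject conflicts
  ret: Dict[str, int] = {}
  for s in syms:
    var, val = s.unbind()
    assert var.expr not in ret or ret[var.expr] == val, f"{var.expr} has conflicting values"
    ret[var.expr] = val
  return ret

# usage: kernels are cached against the unbound symbol, while the concrete value
# travels separately (cf. ScheduleItem.var_vals and prg.exec(..., var_vals=...))
vi = Var("i", 1, 10).bind(3)
sym, val = vi.unbind()
assert (sym.val, val) == (None, 3)
assert merge_var_vals([vi]) == {"i": 3}

This is also why jit.py can compare trackers directly: after unbind(), two ShapeTrackers produced at different steps (say i=3 and i=7) are equal whenever their symbolic shapes match, so the JIT checks input_rawbuffers[input_name][1].unbind() == expected_st instead of comparing views and patching var_vals by hand.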