Variable.bind newer (#2017)

* Variable.bind attempt 2

* ShapeTracker.unbind

* fix llama

* fix types

* test case

* View.vars cleanup

* include mask in symbolic source

* mask can be sint

* st.unbind in bufferops

* assert ast contains free Variables only

* cleanup

* conservative unbinding reduce op arg

* move reduceop unbind

* fix llama JIT arg behavior
chenyu
2023-10-10 13:03:01 -04:00
committed by GitHub
parent 71d93ffd79
commit e2b83f1b42
15 changed files with 204 additions and 199 deletions
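
The core change is a bind/unbind API on symbolic Variables: instead of creating a free Variable and then writing its runtime value into lazydata.var_vals by hand, the value is attached to the Variable itself and stripped off again wherever a free Variable is needed (kernel ASTs, the JIT cache). A minimal sketch of the new pattern, using placeholder bounds and values rather than anything specific from the diffs below:

from tinygrad.shape.symbolic import Variable

start_pos_var = Variable("start_pos", 1, 1024).bind(42)  # hypothetical bounds and value, for illustration only
free_var, value = start_pos_var.unbind()                 # -> (Variable("start_pos", 1, 1024), 42)
assert free_var.val is None and value == 42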

View File

@@ -112,11 +112,9 @@ class Transformer:
def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
_bsz, seqlen = tokens.shape
if seqlen == 1 and start_pos > 0 and getenv("JIT"):
start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
start_pos_var = Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos)
pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var + seqlen)))
pos.lazydata.var_vals[start_pos_var] = start_pos
for k,v in self.kv_caches.items():
v.lazydata.var_vals[start_pos_var] = start_pos
self.kv_caches[k] = v.reshape(v.shape[0], start_pos_var, v.shape[2], v.shape[3])
logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, **self.kv_caches)
return logit_or_softmax

View File

@@ -18,7 +18,7 @@ from tinygrad.ops import GlobalCounters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.shape.symbolic import Variable, sym_infer
JIT = getenv("JIT", 0 if CI else Device.DEFAULT in JIT_SUPPORTED_DEVICE)
JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
@@ -82,7 +82,7 @@ class Attention:
keys, values = xk, xv
else:
assert cache_k is not None and cache_v is not None, "no cache"
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals)})"
assert start_pos == (cache_k.shape[1].val if isinstance(cache_k.shape[1], Variable) else cache_k.shape[1]) == (cache_v.shape[1].val if isinstance(cache_v.shape[1], Variable) else cache_v.shape[1]), f"cache has wrong shape, {start_pos=}, {cache_k.shape[1]=}, {cache_v.shape[1]=}"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
@@ -117,12 +117,9 @@ class TransformerBlock:
bsz, seqlen, _ = x.shape
if JIT and mask is None:
assert cache_k is not None and cache_v is not None, "no cache"
pos = Variable("pos", 1, 1024)
pos = Variable("pos", 1, 1024).bind(start_pos)
cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])
# need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache
cache_k.lazydata.var_vals[pos] = start_pos
cache_v.lazydata.var_vals[pos] = start_pos
output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
h = x + output
@@ -150,14 +147,12 @@ class Transformer:
def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
_bsz, seqlen = tokens.shape
if seqlen == 1 and start_pos > 0 and JIT:
pos = Variable("pos", 1, 1024)
pos = Variable("pos", 1, 1024).bind(start_pos)
# get only the part of freqs_cis that we are using.
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
freqs_cis.lazydata.var_vals[pos] = start_pos
h = self.tok_embeddings_jitted(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos: start_pos})
# TODO: move the kv cache into Attention, pre-allocate the cache and instead of cat, update the cache in-place
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos.unbind()[0]: start_pos})
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess_jitted(h, temperature)
else:

View File

@@ -43,7 +43,7 @@ class ATan2(Function):
assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
self.a, self.b = a, b
ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype))
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \

View File

@@ -11,8 +11,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
expected = f(a).numpy()
@@ -20,12 +20,12 @@ class TestSymbolicJit(unittest.TestCase):
assert len(jf.jit_cache) == 1
def test_reshape_inside_plus1(self):
vi = Variable("i", 1, 10)
def f(a, jit=False, jit_ctx=None):
if jit: a = a.reshape(3, vi)
if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
return (a+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10)
a = Tensor.rand(3, i)
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
expected = f(a).numpy()
@@ -35,8 +35,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_add(self):
def f(a, b): return (a+b).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(3, i)
symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
@@ -47,8 +47,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_matmul(self):
def f(a, b): return (a@b).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
@@ -63,7 +63,7 @@ class TestSymbolicJit(unittest.TestCase):
return s
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10)
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
@@ -74,8 +74,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
q = Tensor.rand(2, 1, 4, 8)
k = Tensor.rand(2, i, 4, 8)
v = Tensor.rand(2, i, 4, 8)
@@ -87,8 +87,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(i, 3)
b = Tensor.rand(2, 3)
symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
@@ -99,8 +99,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_cat_dim1(self):
def f(a, b): return a.cat(b, dim=1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(3, 2)
symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()
@@ -111,10 +111,10 @@ class TestSymbolicJit(unittest.TestCase):
def test_cat_dim0_two_vars(self):
def f(a, b): return a.cat(b, dim=0).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(i, 3)
b = Tensor.rand(j, 3)
symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
@@ -125,10 +125,10 @@ class TestSymbolicJit(unittest.TestCase):
def test_cat_dim1_two_vars(self):
def f(a, b): return a.cat(b, dim=1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(3, i)
b = Tensor.rand(3, j)
symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
@@ -139,10 +139,10 @@ class TestSymbolicJit(unittest.TestCase):
def test_two_vars_plus1(self):
def f(a, b): return (a@b+1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(i, 3)
b = Tensor.rand(3, j)
symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
@@ -153,13 +153,14 @@ class TestSymbolicJit(unittest.TestCase):
def test_jit_symbolic_shape_mismatch(self):
@TinyJit
def add(a, b): return (a+b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i).reshape(3, vi)
b = Tensor.rand(3, i).reshape(3, vi)
c = add(a, b)
a = Tensor.rand(3, 7).reshape(3, vi)
bad = Tensor.rand(4, 7).reshape(4, vi)
vi2 = Variable("i", 1, 10).bind(7)
a = Tensor.rand(3, 7).reshape(3, vi2)
bad = Tensor.rand(4, 7).reshape(4, vi2)
with self.assertRaises(AssertionError):
add(a, bad)
@@ -167,11 +168,10 @@ class TestSymbolicJit(unittest.TestCase):
# shrink is a movement, so we pair it with a simple function to test the JIT interaction
def f(a): return (a+1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(7, 11)
symbolic = a.shrink(((3,5),(vi,vi+2)))
symbolic.lazydata.var_vals[vi] = i
symbolic = jf(symbolic).numpy()
expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
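
One detail the loops above rely on: bind() asserts the Variable is not already bound, which is why a fresh Variable("i", 1, 10) is created and bound on every iteration. A small sketch of that behavior, assuming only the Variable.bind added in this PR:

from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 10).bind(3)
try:
    vi.bind(4)             # a second bind trips the "self.val is None" assertion
except AssertionError:
    pass                   # rebinding is not allowed; create a fresh Variable instead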

View File

@@ -10,8 +10,8 @@ import numpy as np
class TestSymbolicOps(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy()
expected = f(a).numpy()
@@ -19,8 +19,8 @@ class TestSymbolicOps(unittest.TestCase):
def test_add(self):
def f(a, b): return (a+b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
@@ -29,26 +29,18 @@ class TestSymbolicOps(unittest.TestCase):
def test_matmul(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_matmul_same_var_different_val(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
a = Tensor.rand(3, 4)
b = Tensor.rand(7, 5)
with self.assertRaises(AssertionError):
f(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
def test_attention(self, dropout_p=0.0):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
q = Tensor.rand(2, 1, 4, 8)
k = Tensor.rand(2, i, 4, 8)
v = Tensor.rand(2, i, 4, 8)
@@ -65,8 +57,8 @@ class TestSymbolicOps(unittest.TestCase):
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(i, 3)
b = Tensor.rand(2, 3)
symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
@@ -75,8 +67,8 @@ class TestSymbolicOps(unittest.TestCase):
def test_cat_dim1(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(3, 2)
symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).numpy()
@@ -85,10 +77,10 @@ class TestSymbolicOps(unittest.TestCase):
def test_cat_dim0_two_vars(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(i, 3)
b = Tensor.rand(j, 3)
symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
@@ -97,10 +89,10 @@ class TestSymbolicOps(unittest.TestCase):
def test_cat_dim1_two_vars(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(3, i)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
@@ -109,10 +101,10 @@ class TestSymbolicOps(unittest.TestCase):
def test_two_vars_plus1(self):
def f(a, b): return (a@b+1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(i, 3)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
@@ -120,11 +112,10 @@ class TestSymbolicOps(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_shrink(self):
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(7, 11)
symbolic = a.shrink(((3,5),(vi,vi+2)))
symbolic.lazydata.var_vals[vi] = i
symbolic = symbolic.numpy()
expected = a.shrink(((3,5),(i,i+2))).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

View File

@@ -22,107 +22,130 @@ class TestSymbolic(unittest.TestCase):
assert e1.render() == "((y*3)+x)"
assert e2.render() == "1"
def test_cat_strides(self):
i = Variable("i", 1, 5)
j = Variable("j", 1, 5)
k = Variable("k", 1, 5)
def test_cat_dim0_strides(self):
i = Variable("i", 1, 5).bind(3)
j = Variable("j", 1, 5).bind(3)
k = Variable("k", 1, 5).bind(3)
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t.lazydata.st
assert st.shape == (i+j+k, 4)
assert st.real_strides() == (4, 1)
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
assert st.shape == (2*i+3, 3)
assert st.real_strides() == (3, 1)
def test_cat_dim1_strides(self):
i = Variable("i", 1, 5).bind(4)
j = Variable("j", 1, 5).bind(4)
k = Variable("k", 1, 5).bind(4)
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)
class TestSymbolicVarVals(unittest.TestCase):
def test_var_vals_empty(self):
assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {}
def test_var_vals_shape(self):
x = Variable("x", 1, 100).bind(3)
assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3}
def test_var_vals_offset(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3)))
assert st.real_offset() == x * 3
assert st.var_vals == {Variable("x", 1, 100): 3}
def test_var_vals_mask(self):
x = Variable("x", 1, 100).bind(3)
view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4)))
st = ShapeTracker(views=(view,))
assert st.var_vals == {Variable("x", 1, 100): 3}
def test_var_vals_complex(self):
x = Variable("x", 1, 100).bind(3)
y = Variable("y", 1, 100).bind(4)
z = Variable("z", 1, 100).bind(5)
st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3)))
assert st.real_offset() == y * z
assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}
def test_shrink_reshape(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5)))
st = st.reshape((3*4*3,))
assert st.var_vals == {Variable("x", 1, 100): 3}
class TestShapeTrackerUnbind(unittest.TestCase):
def test_view_unbind(self):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(3)
assert View.create(shape=(bv, 4)).unbind() == View.create(shape=(v, 4))
def test_reshape_unbind(self):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(3)
t = Tensor.rand(3, 4).reshape(bv, 4)
assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(v, 4)),))
def test_shrink_unbind(self):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(2)
t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4)))
assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
class TestSymbolicReshape(unittest.TestCase):
def test_reshape_into_symbols_simple(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i)
t = Tensor.rand(i, 4).reshape(vi, 4)
assert t.shape == (vi, 4)
assert t.lazydata.var_vals[vi] == i
t = Tensor.rand(i, 6).reshape(vi, 2, 3)
assert t.shape == (vi, 2, 3)
assert t.lazydata.var_vals[vi] == i
def test_reshape_symbols_reshape_ints(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i)
t = Tensor.rand(i, 4).reshape(vi, 4)
assert t.shape == (vi, 4)
assert t.lazydata.var_vals == {vi: i}
t = t.reshape(i, 4)
assert t.shape == (i, 4)
assert t.lazydata.var_vals == {vi: i}
def test_reshape_reuse_var_same_value_ok(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
a = Tensor.rand(i, 4).reshape(vi, 4)
b = Tensor.rand(i, 3).reshape(vi, 3)
assert a.lazydata.var_vals[vi] == i
assert b.lazydata.var_vals[vi] == i
def test_reshape_reuse_var_different_value_ok(self):
vi = Variable("i", 1, 10)
for i in range(1, 6):
a = Tensor.rand(i, 4).reshape(vi, 2)
b = Tensor.rand(i, 3).reshape(vi, 3)
# a and b have different values of vi
assert a.lazydata.var_vals[vi] == 2 * i
assert b.lazydata.var_vals[vi] == i
def test_reshape_into_symbols_bad_shape(self):
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(vi, vj) # reshape into two variables
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 4).reshape(vi, vi) # reshape into same variable in 2 dimensions
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4) # conflicted implied variable values
vi = Variable("i", 1, 10).bind(4)
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape
with self.assertRaises(AssertionError):
t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node
def test_two_symbol_reshape(self):
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
for i in range(1, 6):
for j in range(1, 6):
t1 = Tensor.rand(i, 5).reshape(vi, 5)
t2 = Tensor.rand(5, j).reshape(5, vj)
t = t1@t2
vi = Variable("i", 1, 5).bind(i)
vj = Variable("j", 1, 5).bind(j)
t = Tensor.rand(i, j).reshape(vi, vj)
assert t.shape == (vi, vj)
t = t.reshape(1, vi*vj)
assert t.shape == (1, vi*vj)
# NOTE: this is currently not allowed
# t = t.reshape(1, vi*vj)
# assert t.shape == (1, vi*vj)
t = t.reshape(vj, vi)
assert t.shape == (vj, vi)
class TestSymbolicExpand(unittest.TestCase):
def test_expand_into_symbols(self):
# TODO: enforce expand only into bound variables
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
a = Tensor([[1], [2], [3]]).expand((3, vi))
assert a.shape == (3, vi)
assert a.lazydata.var_vals == {}
a = a.reshape(3, vi, 1).expand((3, vi, vj))
assert a.shape == (3, vi, vj)
assert a.lazydata.var_vals == {}
def test_plus_expands_constant(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i)
a = Tensor.rand(3, i).reshape(3, vi)
a = a + 1
assert a.shape == (3, vi)
@@ -146,23 +169,5 @@ class TestSymbolicShapeExpr(unittest.TestCase):
idx, valid = st.expr_idxs(idx)
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
class TestShapeTrackerVarVals(unittest.TestCase):
def test_reshape_reshape_updates_var_vals(self):
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj)
assert t.lazydata.var_vals == {vi: 4, vj: 3}
def test_lazy_check_var_vals(self):
vi = Variable("i", 1, 5)
a = Tensor.rand(3, 4).reshape(3, vi)
b = Tensor.rand(5, 6).reshape(vi, 6)
assert a.lazydata.var_vals == {vi: 4}
assert b.lazydata.var_vals == {vi: 5}
c = a@b
# shapetracker works with symbolic shape and doesn't check the underlying variable values
assert c.shape == (3, 6)
assert c.lazydata.var_vals == {vi: 4}
if __name__ == '__main__':
unittest.main()
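
As TestSymbolicReshape above shows, reshaping an all-int shape into a symbolic one now requires a bound Variable whose value matches the concrete dimension; the old path that silently recorded the implied value in lazydata.var_vals is gone. A hedged sketch of the supported pattern:

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 5).bind(4)      # the bound value must match the real dimension (4 here)
t = Tensor.rand(4, 4).reshape(vi, 4)
assert t.shape == (vi, 4)             # the symbolic shape keeps the bound Variable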

View File

@@ -11,7 +11,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
from tinygrad.codegen.optimizer import OptimizedKernel
from tinygrad.codegen.kernel import LocalBuffer
from tinygrad.lazy import var_vals_from_ast
from tinygrad.lazy import vars_from_ast
from tinygrad.features.image import to_image_idx
# bottom ones are asm only
@@ -164,7 +164,7 @@ class Linearizer(OptimizedKernel):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
# add var vals
for var in sorted(var_vals_from_ast(self.ast), key=lambda k: k.key):
for var in sorted(vars_from_ast(self.ast), key=lambda k: k.key):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers

View File

@@ -2,7 +2,7 @@ from typing import Callable
import time
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.helpers import DEBUG, prod, getenv
from tinygrad.lazy import var_vals_from_ast
from tinygrad.lazy import vars_from_ast
def get_divisors(n, min_div = 1, max_div = 512):
if min_div > 1: yield 1
@@ -66,7 +66,7 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, buf
# don't optimize variable shapes
choice = "BASELINE"
else:
var_vals = {k:k.min for k in var_vals_from_ast(k.ast)}
var_vals = {k:k.min for k in vars_from_ast(k.ast)}
# get baseline
def get_baseline():
k = create_k()

View File

@@ -31,12 +31,12 @@ class TinyJit:
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
if self.cnt >= 2:
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
except KeyError: var_vals = merge_dicts([arg.lazydata.var_vals for arg in args if arg.__class__ is Tensor])
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
# NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].views == expected_st.views, f"ShapeTracker.views mismatch in JIT, {input_rawbuffers[input_name][1].views} != {expected_st.views}"
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
for j in self.updatable_entries.keys():
for k in self.jit_cache[j][2].keys():
@@ -55,7 +55,7 @@ class TinyJit:
for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
for i,a in enumerate(cache[1]):
if a in [v[0] for v in input_rawbuffers.values()]:
self.input_replace[(j_,i)] = [(k, v[1], v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
self.input_replace[(j_,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
self.updatable_entries[j_].append(i)
for i in range(len(cache[2])): self.updatable_entries[j_].append(len(cache[1])+i)
#if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local
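
The JIT change above is the reason ShapeTracker.unbind exists: input ShapeTrackers are stored unbound at capture time, so a later call with a different bound value still compares equal structurally. A hedged sketch mirroring TestShapeTrackerUnbind from the tests earlier in this diff:

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

a = Tensor.rand(3, 4).reshape(Variable("i", 1, 10).bind(3), 4)
b = Tensor.rand(7, 4).reshape(Variable("i", 1, 10).bind(7), 4)
# the bound trackers carry different values, but their unbound forms are identical
assert a.lazydata.st.unbind() == b.lazydata.st.unbind()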

View File

@@ -4,7 +4,7 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, partition, dedup, merge_dicts
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
@@ -64,7 +64,7 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
replacements:Dict[LazyBuffer, LazyOp] = {}
base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
for x in op.buffers:
st = x.st.simplify()
st = x.st.simplify().unbind()
if x.base in base_bufs:
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
elif not x.realized and x.base.op.op == LoadOps.CONST:
@@ -79,29 +79,28 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
def var_vals_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.var_vals() for x in ast.get_lazyops() if x.op in BufferOps], []))
def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], []))
lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], base:Optional[LazyBuffer]=None):
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
# fromcpu aren't cached
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, var_vals, base=base)
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
# wop is the deduping key. i feel this used to compare more deeply
wop = (device, dtype, optype, ref(op), tuple(sorted(var_vals.keys())), ref(base) if base else None)
wop = (device, dtype, optype, ref(op), ref(base) if base else None)
if wop in lazycache:
for x in op.buffers: x.children.add(lazycache[wop])
return lazycache[wop]
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, var_vals, base=base)
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
return ret
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st
self._var_vals: Dict[Variable, int] = var_vals
self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
self._realized: Optional[RawBuffer] = src
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
@@ -117,9 +116,6 @@ class LazyBuffer:
if base: base.views.add(self)
else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
@property
def var_vals_key(self): return tuple(sorted(self.var_vals.keys()))
@property
def base(self): return self._base if self._base is not None else self
@@ -137,18 +133,12 @@ class LazyBuffer:
def dtype(self, val):
assert self._base is None, "no setting dtype of based LazyBuffers"
self._dtype = val
@property
def var_vals(self): return self.base._var_vals
@var_vals.setter
def var_vals(self, val):
assert self._base is None, "no setting var_vals of based LazyBuffers"
self._var_vals = val
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key)
return (self.dtype, self.op.op, self.st, self.var_vals_key)
if self.realized: return (self.dtype, self.realized.key, self.st)
return (self.dtype, self.op.op, self.st)
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
@@ -171,22 +161,21 @@ class LazyBuffer:
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps: op = _ast_reduceops(op)
# realize the past and exec the AST
# schedule the past
ret = []
for x in op.buffers: ret += x.schedule(seen)
# TODO: this belongs in the schedule in some way
self.var_vals = dict(sorted(merge_dicts([self.var_vals] + [buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
var_vals = dict(sorted(merge_dicts([self.st.var_vals] + [buf.st.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
# run the ast and log the op
op, base_bufs = _replace_bufferops(op)
return ret + [ScheduleItem(op, self, tuple(base_bufs))]
return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]
# *** creation/special ops ***
@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None, val_vals=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, val_vals if val_vals else {})
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
# create a constant with the shape and dtype of self
def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -203,13 +192,13 @@ class LazyBuffer:
if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const():
# this will turn into nothing, it's based and a copy
# TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals, base=self.base)
return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, base=self.base)
# real contiguous, this will turn into a UnaryOps.NOOP
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals)
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self)
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
# *** elementwise ops ***
@@ -238,14 +227,15 @@ class LazyBuffer:
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
# *** reduce ops ***
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
@@ -264,20 +254,10 @@ class LazyBuffer:
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals, base=self.base)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))
if new_nodes and all(isinstance(s, int) for s in self.shape):
# reshape from all int shape into shape with a variable, update the variable value
assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
# TODO: is it okay to set these var_vals on the base?
if new_var not in self.var_vals:
assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
self.var_vals[new_var] = new_val
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
if not self.realized and self.op.op == MovementOps.RESHAPE:
assert isinstance(self.op.src[0], LazyBuffer)
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??

View File

@@ -45,7 +45,7 @@ class ScheduleItem:
ast: LazyOp
out: LazyBuffer
inputs: Tuple[LazyBuffer, ...]
# TODO: move var_vals here
var_vals: Dict[Variable, int]
class LazyOp:
__slots__ = "op", "src", "arg", "buffers", "__weakref__"
@@ -214,7 +214,7 @@ class ASTRunner:
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, runtime, synchronize, batch_exec
self.method_cache: Dict[Any, ASTRunner] = {}
self.method_cache: Dict[LazyOp, ASTRunner] = {}
def to_program(self, k):
k.linearize()
@@ -243,6 +243,11 @@ class Compiled:
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]
# extract real vars used in ast
from tinygrad.lazy import vars_from_ast
ast_vars = vars_from_ast(ast)
assert all(v.val is None for v in ast_vars), f"ast contains bound Variable {ast_vars}"
# compilation time
def get_program():
from tinygrad.codegen.linearizer import Linearizer
@@ -262,8 +267,5 @@ class Compiled:
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
# extract real var vals
from tinygrad.lazy import var_vals_from_ast
real_var_vals = var_vals_from_ast(ast)
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in real_var_vals})
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in ast_vars})
return output.realized

View File

@@ -24,7 +24,7 @@ def run_schedule(schedule:List[ScheduleItem]):
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
else:
si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.out.var_vals, **si.out._device_extra_args())
si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
del si.out.op
for v in si.out.views: del v.op
assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
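
Putting the scheduling pieces together: BufferOps ShapeTrackers are unbound before they enter the AST, the bound values are collected into ScheduleItem.var_vals, and exec_ast receives them separately. A hedged sketch of that invariant (the helper name is made up; si stands for any ScheduleItem):

from tinygrad.lazy import vars_from_ast

def kernel_var_vals(si):                            # si: any ScheduleItem (assumed in scope)
    ast_vars = vars_from_ast(si.ast)                # the free Variables the kernel actually uses
    assert all(v.val is None for v in ast_vars)     # the AST itself never carries bound values
    return {v: si.var_vals[v] for v in ast_vars}    # the values travel on the ScheduleItem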

View File

@@ -1,8 +1,8 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, cast
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.ops import MovementOps
from tinygrad.helpers import prod, DEBUG, dedup
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
@@ -82,13 +82,18 @@ class ShapeTracker:
# this is the real size (ish)
def size(self): return self.views[-1].size()
def var_vals(self) -> List[Variable]:
ret = []
for v in self.views:
for x in v.shape+v.strides+(v.offset,):
if isinstance(x, Node):
ret += x.vars()
return dedup(ret)
def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))
@property
def var_vals(self) -> Dict[Variable, int]:
ret:Dict[Variable, int] = {}
for v in self.vars():
var, val = v.unbind()
assert var not in ret or ret[var] == val, f"{var} has conflicted values {val} and {ret[var]}"
ret[var] = val
return ret
def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))
def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
to_apply:List[Tuple[MovementOps, Tuple]] = []
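
With var_vals now a property derived from the views, reading bound Variable values off a ShapeTracker and stripping them with unbind looks roughly like this; the sketch mirrors TestSymbolicVarVals above:

from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable

x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((x, 4))
assert st.var_vals == {Variable("x", 1, 100): 3}    # keyed by the free Variable
assert st.unbind() == ShapeTracker.from_shape((Variable("x", 1, 100), 4))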

View File

@@ -33,6 +33,7 @@ class Node:
yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
# substitute Variables with the values in var_vals
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@@ -149,6 +150,14 @@ class Variable(Node):
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
self.val:Optional[int] = None
def bind(self, val):
assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self.val = val
return self
def unbind(self) -> Tuple[Variable, int]:
assert self.val is not None, f"cannot unbind {self}"
return Variable(self.expr, self.min, self.max), self.val
def vars(self): return [self]
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self
@@ -157,6 +166,9 @@ class NumNode(Node):
assert isinstance(num, int), f"{num} is not an int"
self.b:int = num
self.min, self.max = num, num
def bind(self, val):
assert self.b == val, f"cannot bind {val} to {self}"
return self
def __eq__(self, other): return self.b == other
def __hash__(self): return self.hash # needed with __eq__ override
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
@@ -324,7 +336,7 @@ sint = Union[Node, int]
VariableOrNum = Union[Variable, NumNode]
render_python: Dict[Type, Callable] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
NumNode: lambda self,ops,ctx: f"{self.b}",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
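
A sketch of the contract these methods establish: Variable.unbind returns the free Variable together with its value, a composite Node substitutes any bound Variables away and returns None for the value, and NumNode.bind only accepts its own constant. This is an illustration of the diff above, not additional API:

from tinygrad.shape.symbolic import Variable, NumNode

var, val = Variable("i", 1, 10).bind(3).unbind()
assert var.val is None and val == 3
node, val = (Variable("i", 1, 10).bind(3) * 4).unbind()  # a composite Node built from a bound Variable
assert val is None                                       # Node.unbind returns no single value
NumNode(5).bind(5)                                       # a constant can only "bind" to itself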

View File

@@ -1,9 +1,9 @@
from __future__ import annotations
import functools
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional
from tinygrad.helpers import prod, all_int
from tinygrad.shape.symbolic import Node, NumNode, is_sym_int, sint
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.helpers import prod, all_int, dedup
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, is_sym_int, sint
@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -33,6 +33,18 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0])
def vars(self) -> List[Variable]:
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], []))
def unbind(self) -> View:
unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
return View.create(new_shape, new_strides, new_offset, new_mask)
# MovementOps live here now
def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
@@ -88,8 +100,13 @@ class View:
if self.shape == new_shape: return self
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
assert prod(self.shape) == prod(new_shape) if all_int(self.shape + new_shape) else True, f"can't reshape {self.shape=} -> {new_shape=}"
# check for the same size
if all_int(self.shape):
if all_int(new_shape):
assert prod(self.shape) == prod(new_shape), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
else:
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
# after the asserts, it's okay to check contiguous
if self.contiguous: return View.create(new_shape)
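
Finally, View.vars now walks the mask as well (the "include mask in symbolic source" bullet), so a Variable that only appears in a mask still shows up in var_vals and gets unbound with everything else. The sketch below mirrors test_var_vals_mask and test_view_unbind from the tests earlier in this diff; the module path for View is assumed:

from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View                 # import path assumed for this sketch
from tinygrad.shape.symbolic import Variable

x = Variable("x", 1, 100).bind(3)
view = View.create(shape=(3, 4), strides=(4, 1), offset=0, mask=((0, x), (0, 4)))
assert ShapeTracker(views=(view,)).var_vals == {Variable("x", 1, 100): 3}
assert View.create(shape=(Variable("v", 1, 100).bind(3), 4)).unbind() == View.create(shape=(Variable("v", 1, 100), 4))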