Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
Variable.bind newer (#2017)
* Variable.bind attempt 2
* ShapeTracker.unbind
* fix llama
* fix types
* test case
* View.vars cleanup
* include mask in symbolic source
* mask can be sint
* st.unbind in bufferops
* assert ast contains free Variables only
* cleanup
* conservative unbinding of reduce op arg
* move reduceop unbind
* fix llama JIT arg behavior
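
A minimal sketch of the usage pattern this commit moves to, mirroring the test cases in the diff below and assuming tinygrad at this commit: a Variable is bound to its concrete value up front with bind(), the bound Variable is used directly in reshape, and the JIT recovers the value through unbind() instead of reading lazydata.var_vals.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable

def f(a): return (a+1).realize()
jf = TinyJit(f)

for i in range(1, 5):
  # bind() attaches the concrete value to the symbolic Variable up front
  vi = Variable("i", 1, 10).bind(i)
  a = Tensor.rand(3, i)
  # reshaping to the bound Variable carries the value through the ShapeTracker,
  # so no explicit jit_ctx or lazydata.var_vals assignment is needed
  out = jf(a.reshape(3, vi)).reshape(3, i).numpy()
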
@@ -112,11 +112,9 @@ class Transformer:
def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
_bsz, seqlen = tokens.shape
if seqlen == 1 and start_pos > 0 and getenv("JIT"):
start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
start_pos_var = Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos)
pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var + seqlen)))
pos.lazydata.var_vals[start_pos_var] = start_pos
for k,v in self.kv_caches.items():
v.lazydata.var_vals[start_pos_var] = start_pos
self.kv_caches[k] = v.reshape(v.shape[0], start_pos_var, v.shape[2], v.shape[3])
logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, **self.kv_caches)
return logit_or_softmax

@@ -18,7 +18,7 @@ from tinygrad.ops import GlobalCounters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.shape.symbolic import Variable, sym_infer

JIT = getenv("JIT", 0 if CI else Device.DEFAULT in JIT_SUPPORTED_DEVICE)
JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):

@@ -82,7 +82,7 @@ class Attention:
keys, values = xk, xv
else:
assert cache_k is not None and cache_v is not None, "no cache"
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals)})"
assert start_pos == (cache_k.shape[1].val if isinstance(cache_k.shape[1], Variable) else cache_k.shape[1]) == (cache_v.shape[1].val if isinstance(cache_v.shape[1], Variable) else cache_v.shape[1]), f"cache has wrong shape, {start_pos=}, {cache_k.shape[1]=}, {cache_v.shape[1]=}"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)

@@ -117,12 +117,9 @@ class TransformerBlock:
bsz, seqlen, _ = x.shape
if JIT and mask is None:
assert cache_k is not None and cache_v is not None, "no cache"
pos = Variable("pos", 1, 1024)
pos = Variable("pos", 1, 1024).bind(start_pos)
cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])
# need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache
cache_k.lazydata.var_vals[pos] = start_pos
cache_v.lazydata.var_vals[pos] = start_pos

output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
h = x + output

@@ -150,14 +147,12 @@ class Transformer:
def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
_bsz, seqlen = tokens.shape
if seqlen == 1 and start_pos > 0 and JIT:
pos = Variable("pos", 1, 1024)
pos = Variable("pos", 1, 1024).bind(start_pos)
# get only the part of freqs_cis that we are using.
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
freqs_cis.lazydata.var_vals[pos] = start_pos
h = self.tok_embeddings_jitted(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos: start_pos})
# TODO: move the kv cache into Attention, pre-allocate the cache and instead of cat, update the cache in-place
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos.unbind()[0]: start_pos})
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess_jitted(h, temperature)
else:

@@ -43,7 +43,7 @@ class ATan2(Function):
assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
self.a, self.b = a, b
ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype))
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \

@@ -11,8 +11,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
expected = f(a).numpy()

@@ -20,12 +20,12 @@ class TestSymbolicJit(unittest.TestCase):
assert len(jf.jit_cache) == 1

def test_reshape_inside_plus1(self):
vi = Variable("i", 1, 10)
def f(a, jit=False, jit_ctx=None):
if jit: a = a.reshape(3, vi)
if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
return (a+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10)
a = Tensor.rand(3, i)
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
expected = f(a).numpy()

@@ -35,8 +35,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_add(self):
def f(a, b): return (a+b).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(3, i)
symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()

@@ -47,8 +47,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_matmul(self):
def f(a, b): return (a@b).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()

@@ -63,7 +63,7 @@ class TestSymbolicJit(unittest.TestCase):
return s
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10)
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()

@@ -74,8 +74,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
q = Tensor.rand(2, 1, 4, 8)
k = Tensor.rand(2, i, 4, 8)
v = Tensor.rand(2, i, 4, 8)

@@ -87,8 +87,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(i, 3)
b = Tensor.rand(2, 3)
symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()

@@ -99,8 +99,8 @@ class TestSymbolicJit(unittest.TestCase):
def test_cat_dim1(self):
def f(a, b): return a.cat(b, dim=1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(3, 2)
symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()

@@ -111,10 +111,10 @@ class TestSymbolicJit(unittest.TestCase):
def test_cat_dim0_two_vars(self):
def f(a, b): return a.cat(b, dim=0).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(i, 3)
b = Tensor.rand(j, 3)
symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()

@@ -125,10 +125,10 @@ class TestSymbolicJit(unittest.TestCase):
def test_cat_dim1_two_vars(self):
def f(a, b): return a.cat(b, dim=1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(3, i)
b = Tensor.rand(3, j)
symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()

@@ -139,10 +139,10 @@ class TestSymbolicJit(unittest.TestCase):
def test_two_vars_plus1(self):
def f(a, b): return (a@b+1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(i, 3)
b = Tensor.rand(3, j)
symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()

@@ -153,13 +153,14 @@ class TestSymbolicJit(unittest.TestCase):
def test_jit_symbolic_shape_mismatch(self):
@TinyJit
def add(a, b): return (a+b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i).reshape(3, vi)
b = Tensor.rand(3, i).reshape(3, vi)
c = add(a, b)
a = Tensor.rand(3, 7).reshape(3, vi)
bad = Tensor.rand(4, 7).reshape(4, vi)
vi2 = Variable("i", 1, 10).bind(7)
a = Tensor.rand(3, 7).reshape(3, vi2)
bad = Tensor.rand(4, 7).reshape(4, vi2)
with self.assertRaises(AssertionError):
add(a, bad)

@@ -167,11 +168,10 @@ class TestSymbolicJit(unittest.TestCase):
# shrink is a movement, so we pair it with a simple function to test the JIT interaction
def f(a): return (a+1).realize()
jf = TinyJit(f)
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(7, 11)
symbolic = a.shrink(((3,5),(vi,vi+2)))
symbolic.lazydata.var_vals[vi] = i
symbolic = jf(symbolic).numpy()
expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -10,8 +10,8 @@ import numpy as np
class TestSymbolicOps(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy()
expected = f(a).numpy()

@@ -19,8 +19,8 @@ class TestSymbolicOps(unittest.TestCase):

def test_add(self):
def f(a, b): return (a+b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()

@@ -29,26 +29,18 @@ class TestSymbolicOps(unittest.TestCase):

def test_matmul(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_matmul_same_var_different_val(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
a = Tensor.rand(3, 4)
b = Tensor.rand(7, 5)
with self.assertRaises(AssertionError):
f(a.reshape(3, vi), b.reshape(vi, 5)).numpy()

def test_attention(self, dropout_p=0.0):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
q = Tensor.rand(2, 1, 4, 8)
k = Tensor.rand(2, i, 4, 8)
v = Tensor.rand(2, i, 4, 8)

@@ -65,8 +57,8 @@ class TestSymbolicOps(unittest.TestCase):

def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(i, 3)
b = Tensor.rand(2, 3)
symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()

@@ -75,8 +67,8 @@ class TestSymbolicOps(unittest.TestCase):

def test_cat_dim1(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(3, i)
b = Tensor.rand(3, 2)
symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).numpy()

@@ -85,10 +77,10 @@ class TestSymbolicOps(unittest.TestCase):

def test_cat_dim0_two_vars(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(i, 3)
b = Tensor.rand(j, 3)
symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()

@@ -97,10 +89,10 @@ class TestSymbolicOps(unittest.TestCase):

def test_cat_dim1_two_vars(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(3, i)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()

@@ -109,10 +101,10 @@ class TestSymbolicOps(unittest.TestCase):

def test_two_vars_plus1(self):
def f(a, b): return (a@b+1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
a = Tensor.rand(i, 3)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()

@@ -120,11 +112,10 @@ class TestSymbolicOps(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_shrink(self):
vi = Variable("i", 1, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(7, 11)
symbolic = a.shrink(((3,5),(vi,vi+2)))
symbolic.lazydata.var_vals[vi] = i
symbolic = symbolic.numpy()
expected = a.shrink(((3,5),(i,i+2))).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -22,107 +22,130 @@ class TestSymbolic(unittest.TestCase):
assert e1.render() == "((y*3)+x)"
assert e2.render() == "1"

def test_cat_strides(self):
i = Variable("i", 1, 5)
j = Variable("j", 1, 5)
k = Variable("k", 1, 5)
def test_cat_dim0_strides(self):
i = Variable("i", 1, 5).bind(3)
j = Variable("j", 1, 5).bind(3)
k = Variable("k", 1, 5).bind(3)
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t.lazydata.st
assert st.shape == (i+j+k, 4)
assert st.real_strides() == (4, 1)
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
assert st.shape == (2*i+3, 3)
assert st.real_strides() == (3, 1)

def test_cat_dim1_strides(self):
i = Variable("i", 1, 5).bind(4)
j = Variable("j", 1, 5).bind(4)
k = Variable("k", 1, 5).bind(4)
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)

class TestSymbolicVarVals(unittest.TestCase):
def test_var_vals_empty(self):
assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {}

def test_var_vals_shape(self):
x = Variable("x", 1, 100).bind(3)
assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3}

def test_var_vals_offset(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3)))
assert st.real_offset() == x * 3
assert st.var_vals == {Variable("x", 1, 100): 3}

def test_var_vals_mask(self):
x = Variable("x", 1, 100).bind(3)
view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4)))
st = ShapeTracker(views=(view,))
assert st.var_vals == {Variable("x", 1, 100): 3}

def test_var_vals_complex(self):
x = Variable("x", 1, 100).bind(3)
y = Variable("y", 1, 100).bind(4)
z = Variable("z", 1, 100).bind(5)
st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3)))
assert st.real_offset() == y * z
assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}

def test_shrink_reshape(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5)))
st = st.reshape((3*4*3,))
assert st.var_vals == {Variable("x", 1, 100): 3}

class TestShapeTrackerUnbind(unittest.TestCase):
def test_view_unbind(self):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(3)
assert View.create(shape=(bv, 4)).unbind() == View.create(shape=(v, 4))

def test_reshape_unbind(self):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(3)
t = Tensor.rand(3, 4).reshape(bv, 4)
assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(v, 4)),))

def test_shrink_unbind(self):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(2)
t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4)))
assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))

class TestSymbolicReshape(unittest.TestCase):
def test_reshape_into_symbols_simple(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i)
t = Tensor.rand(i, 4).reshape(vi, 4)
assert t.shape == (vi, 4)
assert t.lazydata.var_vals[vi] == i
t = Tensor.rand(i, 6).reshape(vi, 2, 3)
assert t.shape == (vi, 2, 3)
assert t.lazydata.var_vals[vi] == i

def test_reshape_symbols_reshape_ints(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i)
t = Tensor.rand(i, 4).reshape(vi, 4)
assert t.shape == (vi, 4)
assert t.lazydata.var_vals == {vi: i}
t = t.reshape(i, 4)
assert t.shape == (i, 4)
assert t.lazydata.var_vals == {vi: i}

def test_reshape_reuse_var_same_value_ok(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
a = Tensor.rand(i, 4).reshape(vi, 4)
b = Tensor.rand(i, 3).reshape(vi, 3)
assert a.lazydata.var_vals[vi] == i
assert b.lazydata.var_vals[vi] == i

def test_reshape_reuse_var_different_value_ok(self):
vi = Variable("i", 1, 10)
for i in range(1, 6):
a = Tensor.rand(i, 4).reshape(vi, 2)
b = Tensor.rand(i, 3).reshape(vi, 3)
# a and b have different values of vi
assert a.lazydata.var_vals[vi] == 2 * i
assert b.lazydata.var_vals[vi] == i

def test_reshape_into_symbols_bad_shape(self):
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(vi, vj) # reshape into two variables
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 4).reshape(vi, vi) # reshape into same variable in 2 dimensions
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4) # conflicted implied variable values
vi = Variable("i", 1, 10).bind(4)
with self.assertRaises(AssertionError):
t = Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape
with self.assertRaises(AssertionError):
t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node

def test_two_symbol_reshape(self):
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
for i in range(1, 6):
for j in range(1, 6):
t1 = Tensor.rand(i, 5).reshape(vi, 5)
t2 = Tensor.rand(5, j).reshape(5, vj)
t = t1@t2
vi = Variable("i", 1, 5).bind(i)
vj = Variable("j", 1, 5).bind(j)
t = Tensor.rand(i, j).reshape(vi, vj)
assert t.shape == (vi, vj)
t = t.reshape(1, vi*vj)
assert t.shape == (1, vi*vj)
# NOTE: this is currently not allowed
# t = t.reshape(1, vi*vj)
# assert t.shape == (1, vi*vj)
t = t.reshape(vj, vi)
assert t.shape == (vj, vi)

class TestSymbolicExpand(unittest.TestCase):
def test_expand_into_symbols(self):
# TODO: enfore expand only into bound variables
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
a = Tensor([[1], [2], [3]]).expand((3, vi))
assert a.shape == (3, vi)
assert a.lazydata.var_vals == {}
a = a.reshape(3, vi, 1).expand((3, vi, vj))
assert a.shape == (3, vi, vj)
assert a.lazydata.var_vals == {}

def test_plus_expands_constant(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i)
a = Tensor.rand(3, i).reshape(3, vi)
a = a + 1
assert a.shape == (3, vi)

@@ -146,23 +169,5 @@ class TestSymbolicShapeExpr(unittest.TestCase):
idx, valid = st.expr_idxs(idx)
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"

class TestShapeTrackerVarVals(unittest.TestCase):
def test_reshape_reshape_updates_var_vals(self):
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj)
assert t.lazydata.var_vals == {vi: 4, vj: 3}

def test_lazy_check_var_vals(self):
vi = Variable("i", 1, 5)
a = Tensor.rand(3, 4).reshape(3, vi)
b = Tensor.rand(5, 6).reshape(vi, 6)
assert a.lazydata.var_vals == {vi: 4}
assert b.lazydata.var_vals == {vi: 5}
c = a@b
# shapetracker works with symbolic shape and doesn't check the underlying variable values
assert c.shape == (3, 6)
assert c.lazydata.var_vals == {vi: 4}

if __name__ == '__main__':
unittest.main()

@@ -11,7 +11,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
from tinygrad.codegen.optimizer import OptimizedKernel
from tinygrad.codegen.kernel import LocalBuffer
from tinygrad.lazy import var_vals_from_ast
from tinygrad.lazy import vars_from_ast
from tinygrad.features.image import to_image_idx

# bottom ones are asm only

@@ -164,7 +164,7 @@ class Linearizer(OptimizedKernel):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
# add var vals
for var in sorted(var_vals_from_ast(self.ast), key=lambda k: k.key):
for var in sorted(vars_from_ast(self.ast), key=lambda k: k.key):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers

@@ -2,7 +2,7 @@ from typing import Callable
import time
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.helpers import DEBUG, prod, getenv
from tinygrad.lazy import var_vals_from_ast
from tinygrad.lazy import vars_from_ast

def get_divisors(n, min_div = 1, max_div = 512):
if min_div > 1: yield 1

@@ -66,7 +66,7 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, buf
# don't optimize variable shapes
choice = "BASELINE"
else:
var_vals = {k:k.min for k in var_vals_from_ast(k.ast)}
var_vals = {k:k.min for k in vars_from_ast(k.ast)}
# get baseline
def get_baseline():
k = create_k()

@@ -31,12 +31,12 @@ class TinyJit:
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
if self.cnt >= 2:
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
except KeyError: var_vals = merge_dicts([arg.lazydata.var_vals for arg in args if arg.__class__ is Tensor])
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
# NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].views == expected_st.views, f"ShapeTracker.views mismatch in JIT, {input_rawbuffers[input_name][1].views} != {expected_st.views}"
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
for j in self.updatable_entries.keys():
for k in self.jit_cache[j][2].keys():

@@ -55,7 +55,7 @@ class TinyJit:
for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
for i,a in enumerate(cache[1]):
if a in [v[0] for v in input_rawbuffers.values()]:
self.input_replace[(j_,i)] = [(k, v[1], v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
self.input_replace[(j_,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
self.updatable_entries[j_].append(i)
for i in range(len(cache[2])): self.updatable_entries[j_].append(len(cache[1])+i)
#if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local

@@ -4,7 +4,7 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
from weakref import ref, WeakSet, WeakValueDictionary

import numpy as np
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, partition, dedup, merge_dicts
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint

@@ -64,7 +64,7 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
replacements:Dict[LazyBuffer, LazyOp] = {}
base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
for x in op.buffers:
st = x.st.simplify()
st = x.st.simplify().unbind()
if x.base in base_bufs:
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
elif not x.realized and x.base.op.op == LoadOps.CONST:

@@ -79,29 +79,28 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)

def var_vals_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.var_vals() for x in ast.get_lazyops() if x.op in BufferOps], []))
def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], []))

lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], base:Optional[LazyBuffer]=None):
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
# fromcpu aren't cached
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, var_vals, base=base)
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)

# wop is the deduping key. i feel this used to compare more deeply
wop = (device, dtype, optype, ref(op), tuple(sorted(var_vals.keys())), ref(base) if base else None)
wop = (device, dtype, optype, ref(op), ref(base) if base else None)
if wop in lazycache:
for x in op.buffers: x.children.add(lazycache[wop])
return lazycache[wop]

lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, var_vals, base=base)
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
return ret

UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}

class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st
self._var_vals: Dict[Variable, int] = var_vals
self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
self._realized: Optional[RawBuffer] = src
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized

@@ -117,9 +116,6 @@ class LazyBuffer:
if base: base.views.add(self)
else: assert st.contiguous, "unbased LazyBuffers must be contiguous"

@property
def var_vals_key(self): return tuple(sorted(self.var_vals.keys()))

@property
def base(self): return self._base if self._base is not None else self

@@ -137,18 +133,12 @@ class LazyBuffer:
def dtype(self, val):
assert self._base is None, "no setting dtype of based LazyBuffers"
self._dtype = val
@property
def var_vals(self): return self.base._var_vals
@var_vals.setter
def var_vals(self, val):
assert self._base is None, "no setting var_vals of based LazyBuffers"
self._var_vals = val

def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key)
return (self.dtype, self.op.op, self.st, self.var_vals_key)
if self.realized: return (self.dtype, self.realized.key, self.st)
return (self.dtype, self.op.op, self.st)

def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}

@@ -171,22 +161,21 @@ class LazyBuffer:
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps: op = _ast_reduceops(op)

# realize the past and exec the AST
# schedule the past
ret = []
for x in op.buffers: ret += x.schedule(seen)

# TODO: this belongs in the schedule in some way
self.var_vals = dict(sorted(merge_dicts([self.var_vals] + [buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
var_vals = dict(sorted(merge_dicts([self.st.var_vals] + [buf.st.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))

# run the ast and log the op
op, base_bufs = _replace_bufferops(op)
return ret + [ScheduleItem(op, self, tuple(base_bufs))]
return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]

# *** creation/special ops ***

@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None, val_vals=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, val_vals if val_vals else {})
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)

# create a constant with the shape and dtype of self
def const(self, val:Union[float, int]) -> LazyBuffer:

@@ -203,13 +192,13 @@ class LazyBuffer:
if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const():
# this will turn into nothing, it's based and a copy
# TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals, base=self.base)
return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, base=self.base)
# real contiguous, this will turn into a UnaryOps.NOOP
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals)
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self)

@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))

# *** elementwise ops ***

@@ -238,14 +227,15 @@ class LazyBuffer:
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore

return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)

# *** reduce ops ***

def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)

def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.

@@ -264,20 +254,10 @@ class LazyBuffer:
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals, base=self.base)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)

def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))
if new_nodes and all(isinstance(s, int) for s in self.shape):
# reshape from all int shape into shape with a variable, update the variable value
assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
# TODO: is it okay to set these var_vals on the base?
if new_var not in self.var_vals:
assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
self.var_vals[new_var] = new_val
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
if not self.realized and self.op.op == MovementOps.RESHAPE:
assert isinstance(self.op.src[0], LazyBuffer)
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??

@@ -45,7 +45,7 @@ class ScheduleItem:
ast: LazyOp
out: LazyBuffer
inputs: Tuple[LazyBuffer, ...]
# TODO: move var_vals here
var_vals: Dict[Variable, int]

class LazyOp:
__slots__ = "op", "src", "arg", "buffers", "__weakref__"

@@ -214,7 +214,7 @@ class ASTRunner:
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, runtime, synchronize, batch_exec
self.method_cache: Dict[Any, ASTRunner] = {}
self.method_cache: Dict[LazyOp, ASTRunner] = {}

def to_program(self, k):
k.linearize()

@@ -243,6 +243,11 @@ class Compiled:
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]

# extract real vars used in ast
from tinygrad.lazy import vars_from_ast
ast_vars = vars_from_ast(ast)
assert all(v.val is None for v in ast_vars), f"ast contains bound Variable {ast_vars}"

# compilation time
def get_program():
from tinygrad.codegen.linearizer import Linearizer

@@ -262,8 +267,5 @@ class Compiled:

if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)

# extract real var vals
from tinygrad.lazy import var_vals_from_ast
real_var_vals = var_vals_from_ast(ast)
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in real_var_vals})
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in ast_vars})
return output.realized

@@ -24,7 +24,7 @@ def run_schedule(schedule:List[ScheduleItem]):
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
else:
si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.out.var_vals, **si.out._device_extra_args())
si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
del si.out.op
for v in si.out.views: del v.op
assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"

@@ -1,8 +1,8 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, cast
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.ops import MovementOps
from tinygrad.helpers import prod, DEBUG, dedup
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint

@@ -82,13 +82,18 @@ class ShapeTracker:
# this is the real size (ish)
def size(self): return self.views[-1].size()

def var_vals(self) -> List[Variable]:
ret = []
for v in self.views:
for x in v.shape+v.strides+(v.offset,):
if isinstance(x, Node):
ret += x.vars()
return dedup(ret)
def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))

@property
def var_vals(self) -> Dict[Variable, int]:
ret:Dict[Variable, int] = {}
for v in self.vars():
var, val = v.unbind()
assert var not in ret or ret[var] == val, f"{var} has conflicted values {val} and {ret[var]}"
ret[var] = val
return ret

def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))

def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
to_apply:List[Tuple[MovementOps, Tuple]] = []
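
A small sketch of the new ShapeTracker behavior defined above, assuming tinygrad at this commit and import paths as used in the test hunks: var_vals is now a property derived from the bound Variables found in the views, and unbind() returns an equivalent ShapeTracker with the bindings stripped (mirroring TestSymbolicVarVals and TestShapeTrackerUnbind above).

from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable

v = Variable("x", 1, 100)           # unbound
bv = Variable("x", 1, 100).bind(3)  # bound to 3

st = ShapeTracker.from_shape((bv, 4))
assert st.var_vals == {v: 3}                                       # values read back off the bound Variable
assert st.unbind() == ShapeTracker((View.create(shape=(v, 4)),))   # same views with the bindings removed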

@@ -33,6 +33,7 @@ class Node:
yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
# substitute Variables with the values in var_vals
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None

@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")

@@ -149,6 +150,14 @@ class Variable(Node):

def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
self.val:Optional[int] = None
def bind(self, val):
assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self.val = val
return self
def unbind(self) -> Tuple[Variable, int]:
assert self.val is not None, f"cannot unbind {self}"
return Variable(self.expr, self.min, self.max), self.val
def vars(self): return [self]
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self

@@ -157,6 +166,9 @@ class NumNode(Node):
assert isinstance(num, int), f"{num} is not an int"
self.b:int = num
self.min, self.max = num, num
def bind(self, val):
assert self.b == val, f"cannot bind {val} to {self}"
return self
def __eq__(self, other): return self.b == other
def __hash__(self): return self.hash # needed with __eq__ override
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
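
The bind/unbind contract added to Variable above, sketched for illustration under the assumption of tinygrad at this commit: bind() stores the value on the same object and returns it, while unbind() hands back a fresh unbound Variable together with the value.

from tinygrad.shape.symbolic import Variable

v = Variable("i", 1, 10)
bv = v.bind(3)
assert bv is v and bv.val == 3        # bind() sets val in place and returns self

var, val = bv.unbind()
assert val == 3 and var.val is None   # a fresh, unbound copy plus the stored value
assert var == Variable("i", 1, 10)

# binding twice, binding out of range, or unbinding an unbound Variable trips the asserts in bind()/unbind()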

@@ -324,7 +336,7 @@ sint = Union[Node, int]
VariableOrNum = Union[Variable, NumNode]

render_python: Dict[Type, Callable] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
NumNode: lambda self,ops,ctx: f"{self.b}",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",

@@ -1,9 +1,9 @@
from __future__ import annotations
import functools
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional
from tinygrad.helpers import prod, all_int
from tinygrad.shape.symbolic import Node, NumNode, is_sym_int, sint
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.helpers import prod, all_int, dedup
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, is_sym_int, sint

@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:

@@ -33,6 +33,18 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0])

def vars(self) -> List[Variable]:
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], []))

def unbind(self) -> View:
unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
return View.create(new_shape, new_strides, new_offset, new_mask)

# MovementOps live here now

def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:

@@ -88,8 +100,13 @@ class View:
if self.shape == new_shape: return self

assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
assert prod(self.shape) == prod(new_shape) if all_int(self.shape + new_shape) else True, f"can't reshape {self.shape=} -> {new_shape=}"
# check for the same size
if all_int(self.shape):
if all_int(new_shape):
assert prod(self.shape) == prod(new_shape), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
else:
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"

# after the asserts, it's okay to check contiguous
if self.contiguous: return View.create(new_shape)