switch symbolic from old to uops, final PR (#6872)

* switch symbolic from old to uops, final PR

* two wrong answers

* not needed resolves

* symbolic ops passes

* symbolic ops passes

* progress

* tests pass (almost)

* fix last test

* fix some tests

* global binding and unbinding

* Revert "global binding and unbinding"

This reverts commit 9456725630.

* that test works now

* vars on uop doesn't recurse

* fix fuzzer

* update

* fix type

* fix gpt, it's UOp now

* ssimplify symbolics
This commit is contained in:
George Hotz
2024-10-04 16:42:27 +08:00
committed by GitHub
parent 738a5794a9
commit f4ec39fe58
13 changed files with 150 additions and 984 deletions

View File

@@ -4,6 +4,7 @@ import argparse
import numpy as np
import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
from tinygrad.ops import UOp
from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
@@ -75,9 +76,9 @@ class Transformer:
self.lm_head = Linear(dim, vocab_size, bias=False)
self.forward_jit = TinyJit(self.forward)
def forward(self, tokens:Union[Tensor,Variable], start_pos:Variable, temperature:float=0.0):
def forward(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0):
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
if isinstance(tokens, Variable):
if isinstance(tokens, UOp):
seqlen = 1
tok_emb = self.wte.weight.shrink(((tokens, tokens+1), None))
else:
@@ -107,8 +108,8 @@ class Transformer:
ret = (logits / temperature).softmax().multinomial()
return ret.flatten().realize()
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
forward = (self.forward_jit if JIT and (isinstance(tokens, Variable) or tokens.shape[1] == 1) else self.forward)
def __call__(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0) -> Tensor:
forward = (self.forward_jit if JIT and (isinstance(tokens, UOp) or tokens.shape[1] == 1) else self.forward)
return forward(tokens, start_pos, temperature)
VOCAB_SIZE = 50257

View File

@@ -71,8 +71,8 @@ class Attention:
assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if isinstance(start_pos, Variable) or start_pos > 0 else xk
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if isinstance(start_pos, Variable) or start_pos > 0 else xv
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -172,10 +172,10 @@ class Transformer:
return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
# TODO: better way to handle the first call v.s. the rest?
if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
if tokens.shape[0:2] == (1,1) and self.forward_jit is not None and start_pos != 0:
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
# *** helpers ***

View File

@@ -40,6 +40,11 @@ def gt(expr, rng=None):
if rng is None: rng = random.randint(-4,4)
return expr > rng, rng
# NOTE: you have to replace these for this test to pass
from tinygrad.ops import python_alu, BinaryOps
python_alu[BinaryOps.MOD] = lambda x,y: x%y
python_alu[BinaryOps.IDIV] = lambda x,y: x//y
if __name__ == "__main__":
ops = [add_v, div, mul, add_num, mod]
for _ in range(1000):
@@ -65,5 +70,8 @@ if __name__ == "__main__":
rn = 0
for t,r in zip(tape, rngs): rn, _ = t(rn, r)
num = eval(expr.render())
assert num == rn, f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}"
if num != rn:
unsimplified_num = eval(expr.render(simplify=False))
assert unsimplified_num == rn, "UNSIMPLIFIED MISMATCH!"
assert num == rn, f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}\n{expr.render(simplify=False)}"
if DEBUG >= 1: print(f"matched {expr.render()} at {v1=} {v2=} {v3=} = {num} == {rn}")

View File

@@ -4,25 +4,21 @@ from tinygrad import Tensor, Variable
class TestTensorVariable(unittest.TestCase):
def test_add_tvar(self):
vv = Variable("a", 0, 10)
vv.bind(1)
vv = Variable("a", 0, 10).bind(1)
ret = (Tensor(vv) + 3).item()
assert ret == 4
def test_inner_tvar_node(self):
vv = Variable("w", 0, 10)
vv.bind(2)
vv = Variable("w", 0, 10).bind(2)
ret = Tensor.from_node(vv * 4).item()
assert ret == 8
def test_inner_tvar_mul(self):
vv = Variable("w", 0, 10)
vv.bind(2)
vv = Variable("w", 0, 10).bind(2)
assert (Tensor(3) * vv).item() == 6
def test_inner_tvar_mul_node(self):
vv = Variable("w", 0, 10)
vv.bind(2)
vv = Variable("w", 0, 10).bind(2)
assert (Tensor(3) * (vv * 4)).item() == 24
def test_symbolic_mean(self):
@@ -45,14 +41,10 @@ class TestTensorVariable(unittest.TestCase):
ret = t.mean(axis=1).reshape(2, 1).numpy()
assert np.all(ret == 1)
@unittest.expectedFailure
def test_symbolic_mean_2d_add(self):
add_term = Variable("c", 0, 10)
add_term.bind(1)
vv = Variable("a", 1, 10)
vv.bind(1)
vv2 = Variable("b", 1, 10)
vv2.bind(1)
add_term = Variable("c", 0, 10).bind(1)
vv = Variable("a", 1, 10).bind(1)
vv2 = Variable("b", 1, 10).bind(1)
t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term)
ret = t.mean().item()
assert ret == 1
@@ -65,8 +57,7 @@ class TestTensorVariable(unittest.TestCase):
@unittest.skip("symbolic arange isn't supported")
def test_symbolic_arange(self):
vv = Variable("a", 1, 10)
vv.bind(2)
vv = Variable("a", 1, 10).bind(2)
ret = Tensor.arange(0, vv)
ret.realize()

View File

@@ -1,575 +0,0 @@
#!/usr/bin/env python
import unittest, pickle
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
class TestSymbolicPickle(unittest.TestCase):
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):
self.assertEqual(v.render(), s)
self.assertEqual(v.min, n)
self.assertEqual(v.max, m)
def test_ge(self):
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a*-1)<-7)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a*-1)<-3)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
def test_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "0")
def test_ge_divides(self):
expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
self.helper_test_variable(expr, 0, 1, "(idx<128)")
def test_ge_divides_and(self):
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
self.helper_test_variable(expr//4, 0, 0, "0")
def test_lt_factors(self):
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
def test_div_reduction(self):
self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
def test_var_becomes_num(self):
self.helper_test_variable(Variable("a", 2, 2), 2, 2, "2")
def test_equality(self):
idx1 = Variable("idx1", 0, 3)
idx2 = Variable("idx2", 0, 3)
assert idx1 == idx1
assert idx1 != idx2
assert idx1*4 == idx1*4
assert idx1*4 != idx1*3
assert idx1*4 != idx1+4
assert idx1*4 != idx2*4
assert idx1+idx2 == idx1+idx2
assert idx1+idx2 == idx2+idx1
assert idx1+idx2 != idx2
assert idx1*idx2 == idx2*idx1
def test_numnode_eq_int(self):
n1 = NumNode(1)
n2 = NumNode(2)
assert n1 == 1
assert n2 == 2
assert n1 != n2
assert hash(n1) == hash(1)
assert hash(n2) == hash(2)
def test_factorize(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
def test_factorize_no_mul(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
def test_neg(self):
self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")
def test_add_1(self):
self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(1+a)")
def test_add_num_1(self):
self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, "(1+a)")
def test_sub_1(self):
self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(-1+a)")
def test_sub_num_1(self):
self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(-1+a)")
def test_add_self(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a+a, 0, 16, "(a*2)")
def test_sub_self(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a-a, 0, 0, "0")
self.helper_test_variable(a*3-a, 0, 16, "(a*2)")
def test_mul_0(self):
self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0")
def test_mul_1(self):
self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a")
def test_mul_neg_1(self):
self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((((a*-1)+3)//3)+-1)")
def test_mul_2(self):
self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)")
def test_div_1(self):
self.helper_test_variable(Variable("a", 0, 8)//1, 0, 8, "a")
def test_mod_1(self):
self.helper_test_variable(Variable("a", 0, 8)%1, 0, 0, "0")
def test_add_min_max(self):
self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)")
def test_div_remove(self):
self.helper_test_variable(Variable("a", 0, 7) // 20, 0, 0, "0")
def test_div_min_max(self):
self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
def test_div_neg_min_max(self):
self.helper_test_variable(Variable("a", 0, 7) // -2, -4, 0, "((((a*-1)+8)//2)+-4)")
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((((a*-1)+6)//2)+-3)")
def test_sum_div_remove(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
def test_sum_div_min_max(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
def test_sum_div_mod_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
def test_sum_div_some_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
def test_sum_div_trim_const(self):
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "((1+a+b)//4)")
def test_sum_div_some_partial_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)")
def test_sum_div_no_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
def test_mod_factor(self):
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
def test_mod_to_sub(self):
# This is mod reduction
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(-1+a)")
def test_sum_div_const(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
def test_sum_div_const_big(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
def test_sum_lt_fold(self):
self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]), 16), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]), 16), 0, 1, "(((a*4)+b)<16)")
self.helper_test_variable(create_lt_node(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]), (4 * 67)), 0, 1, "(a<23)")
def test_mul_mod_large(self):
self.helper_test_variable((Variable("a", 0, 20)*10)%9, 0, 8, "(a%9)")
def test_mul_mod_small(self):
self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
def test_mod_mod(self):
self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)")
self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0")
self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, "(((a*5)%12)%5)")
self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)")
def test_mul_mul(self):
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
def test_mul_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a*-1)<-2)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a*-1)<-3)")
def test_div_div(self):
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
def test_distribute_mul(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, "((a*-2)+10)")
def test_mod_mul_sum(self):
self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
def test_sum_0(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a")
def test_mod_remove(self):
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
def test_big_mod(self):
# NOTE: we no longer support negative variables
#self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
#self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
#self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
def test_ge_remove(self):
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
def test_lt_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "1")
def test_lt_sum_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
def test_lt_simple_factor(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6+Variable("b", 0, 6)*6, 8), 0, 1,
"(((a*3)+(b*3))<4)")
def test_lt_sum_factor_rhs_partial(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 4), 0, 1,
"(((a*3)+(b*2)+(c*4))<2)")
def test_lt_sum_factor_rhs_all(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 2), 0, 1,
"(((a*3)+(b*2)+(c*4))<1)")
def test_and_fold(self):
self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
def test_and_remove(self):
self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
def test_mod_factor_negative(self):
self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
def test_sum_combine_num(self):
self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)")
def test_sum_num_hoisted_and_factors_cancel_out(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
def test_div_cancel(self):
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
def test_mod_cancel(self):
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
def test_add_div(self):
# careful about the lower bounds and upper bounds
self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, "(((2+a)//4)+-1)")
self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, "(((3+a)//4)+-1)")
self.helper_test_variable((Variable("a", 0, 5))//4, 0, 1, "(a//4)")
self.helper_test_variable((Variable("a", 0, 5)+1)//4, 0, 1, "((1+a)//4)")
self.helper_test_variable((Variable("a", 0, 5)+2)//4, 0, 1, "((2+a)//4)")
self.helper_test_variable((Variable("a", 0, 5)+3)//4, 0, 2, "((3+a)//4)")
self.helper_test_variable((Variable("a", 0, 5)+4)//4, 1, 2, "((a//4)+1)")
self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "(((1+a)//4)+1)")
def test_mul_div_factor_mul(self):
self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
def test_mul_div_factor_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
def test_sum_div_partial_remove(self):
self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
def test_div_numerator_negative(self):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
def test_div_into_mod(self):
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
def test_div_neg_cancel(self):
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 0, 25, "((1+idx)//4)")
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "(idx//4)")
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "(((3+idx)//4)+-1)")
def test_sum_div_big_const(self):
gidx0 = Variable("gidx0", 0, 24)
self.helper_test_variable((gidx0+19)//20, 0, 2, "((19+gidx0)//20)")
self.helper_test_variable((gidx0+20)//20, 1, 2, "((gidx0//20)+1)")
self.helper_test_variable((gidx0+21)//20, 1, 2, "(((1+gidx0)//20)+1)")
def test_sum_div_complex1(self):
gidx0 = Variable("gidx0", 0, 24)
gidx1 = Variable("gidx1", 0, 1)
gidx2 = Variable("gidx2", 0, 255)
lidx0 = Variable("lidx0", 0, 1)
lidx1 = Variable("lidx1", 0, 15)
lidx2 = Variable("lidx2", 0, 3)
alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, "((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(gidx1*8)+(gidx2*32)+(lidx0*16))")
def test_sum_div_complex2(self):
gidx0 = Variable("gidx0", 0, 7)
lidx2 = Variable("lidx2", 0, 1)
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0*4+lidx2*2+1)//10, 0, 3, "(((gidx0*2)+lidx2)//5)")
self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//10, 0, 3, "(((gidx0*2)+lidx2)//5)")
self.helper_test_variable((gidx0*2+lidx2)//10, 0, 1, "(gidx0//5)")
def test_sum_div_complex3(self):
gidx0 = Variable("gidx0", 0, 7)
lidx2 = Variable("lidx2", 0, 12)
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
def test_sum_mul_distribute(self):
gidx0 = Variable("gidx0", 0, 7)
lidx2 = Variable("lidx2", 0, 12)
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
def test_variable_divmod(self):
start_pos = Variable("start_pos", 0, 127)
v = start_pos + 1
idx0 = Variable("idx0", 0, 2)
idx1 = Variable("idx1", 0, start_pos)
self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken? (even if we did support negative variables)
#MIN, MAX = -10, 10
MIN, MAX = 0, 10
# one number
for i in range(MIN, MAX):
v = f(NumNode(i))
#print(i, f(i), v.min, v.max)
self.assertEqual(v.min, v.max)
self.assertEqual(v.min, f(i))
for kmin in range(MIN, MAX):
for kmax in range(MIN, MAX):
if kmin > kmax: continue
v = f(Variable("tmp", kmin, kmax))
values = [f(rv) for rv in range(kmin, kmax+1)]
# the min and max may not be exact
self.assertLessEqual(v.min, min(values))
self.assertGreaterEqual(v.max, max(values))
def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2)
def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2)
def test_times_2(self): self.helper_test_numeric(lambda x: x*2)
def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3)
def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4)
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
class TestSymbolicVars(unittest.TestCase):
def test_simple(self):
z = NumNode(0)
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
assert z.vars() == z.vars() == set()
assert a.vars() == a.vars() == {a}
m = MulNode(a, 3)
assert m.vars() == {a}
s = SumNode([a, b, c])
assert s.vars() == {a, b, c}
def test_compound(self):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
assert (a + b * c).vars() == {a, b, c}
assert (a % 3 + b // 5).vars() == {a, b}
assert (a + b + c - a).vars() == {b, c}
def test_dedup(self):
a = Variable("a", 0, 10)
assert (a * a).vars() == {a}
assert (a//4 + a//6).vars() == {a}
class TestSymbolicMinMax(unittest.TestCase):
def test_min_max_known(self):
a = Variable("a", 1, 8)
assert max(1, a, key=lambda x:x if isinstance(x, int) else x.max) == max(a, 1, key=lambda x:x if isinstance(x, int) else x.max) == a
assert min(1, a, key=lambda x:x if isinstance(x, int) else x.max) == min(a, 1, key=lambda x:x if isinstance(x, int) else x.max) == 1
class TestSymRender(unittest.TestCase):
def test_sym_render(self):
a = Variable("a", 1, 8)
b = Variable("b", 1, 10)
assert sym_render(a) == "a"
assert sym_render(1) == "1"
assert sym_render(a+1) == "(1+a)"
assert sym_render(a*b) == "(a*b)"
class TestSymInfer(unittest.TestCase):
def test_sym_infer(self):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
var_vals = {a: 2, b: 3, c: 4}
assert sym_infer(5, var_vals) == 5
assert sym_infer(a, var_vals) == 2
assert sym_infer(b, var_vals) == 3
assert sym_infer(a+b, var_vals) == 5
assert sym_infer(a-b, var_vals) == -1
assert sym_infer(a+b+c, var_vals) == 9
assert sym_infer(a*b, var_vals) == 6
assert sym_infer(a*b+c, var_vals) == 10
class TestSymbolicSymbolicOps(unittest.TestCase):
def test_node_divmod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, i*3-1)
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
assert 127 // (Variable("i", 1, 10)*128) == 0
assert 127 % (Variable("i", 1, 10)*128) == 127
assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
assert 0 // (Variable("i", 1, 10)*128) == 0
assert 0 % (Variable("i", 1, 10)*128) == 0
assert idx0 // (i*3) == 0
assert idx0 % (i*3) == idx0
assert i // i == 1
assert i % i == 0
assert 128 // NumNode(4) == 32
assert 128 % NumNode(4) == 0
assert NumNode(128) // NumNode(4) == 32
assert NumNode(128) % NumNode(4) == 0
def test_mulnode_divmod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, 31)
# assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
# assert (idx0*(i*4+4)) % (i+1) == 0
assert (idx0*i) % i == 0
def test_sumnode_divmod_sumnode(self):
i = Variable("i", 1, 10)
# idx0 = Variable("idx0", 0, 7)
# idx1 = Variable("idx1", 0, 3)
# idx2 = Variable("idx2", 0, i)
# assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
# assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
assert (i+1) // (i*128+128) == 0
assert (i+1) % (i*128+128) == (i+1)
# assert (i+1+idx2) // (i+1) == 1
# assert (i+1+idx2) % (i+1) == idx2
# assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
# assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
# assert (i*128+128)*2 // (i*128+128) == 2
# assert (i*128+128)*2 % (i*128+128) == 0
def test_sumnode_div_numnode_no_factoring(self):
gid = Variable("gid", 0, 1023)
lid = Variable("lid", 0, 3)
expr_before_div = NumNode(-1019)-4*lid-gid
unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False)
factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True)
self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
def test_mod_node_max(self):
i = Variable("i", 1, 128)
gidx0 = Variable("gidx0", 0, i)
mod = gidx0 % 8
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
mod = gidx0 % 2
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
gidx0 = Variable("gidx0", 0, i*8+7)
mod = gidx0 % 8
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
mod = gidx0 % 2
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
def test_node_lt_node(self):
a = Variable("a", 1, 5)
b = Variable("b", 6, 9)
c = Variable("c", 1, 10)
d = Variable("d", 5, 10)
# if the comparison output is always the same, it folds to num
assert create_lt_node(a, b) == NumNode(1)
assert create_lt_node(b, a) == NumNode(0)
assert create_lt_node(d, a) == NumNode(0)
assert create_lt_node(a, a) == NumNode(0)
assert create_lt_node(a, a) == NumNode(0)
# if it remains as a LtNode, (min, max) == (0, 1)
a_lt_c = create_lt_node(a, c)
assert isinstance(a_lt_c, LtNode) and a_lt_c.min == 0 and a_lt_c.max == 1
# same when comparing with a constant
a_lt_3 = create_lt_node(a, 3)
assert a_lt_3.min == 0 and a_lt_3.max == 1
def test_sumnode_mulnode_lt(self):
a = Variable("a", 1, 2)
b = Variable("b", 1, 2)
c = Variable("c", 1, 2)
x = SumNode([MulNode(a, b), c])
with self.assertRaises(AssertionError):
create_lt_node(x, 3)
def test_nested_variable_mod(self):
i = Variable("i", 1, 5)
idx0 = Variable("idx0", 0, i)
with self.assertRaises(AssertionError):
assert idx0 % 2 == idx0
def test_num_node_mul_node(self):
a = Variable("a", 1, 5)
b = NumNode(2) * a
assert b == a * 2
assert isinstance(b, MulNode)
b = NumNode(1) * a
assert b == a
assert isinstance(b, Variable)
b = NumNode(0) * a
assert b == 0
assert isinstance(b, NumNode)
def test_substitute(self):
a = Variable("idx0", 1, 3)
b = a + 1
c = b.substitute({a: NumNode(1)})
assert c == NumNode(2)
class TestSymbolicRealWorld(unittest.TestCase):
def test_resnet_half(self):
gidx0 = Variable("gidx0", 0, 3)
gidx1 = Variable("gidx1", 0, 127)
gidx2 = Variable("gidx2", 0, 7)
lidx3 = Variable("lidx3", 0, 7)
lidx4 = Variable("lidx4", 0, 1)
lidx5 = Variable("lidx5", 0, 15)
idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
print(idx.render())
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
if __name__ == '__main__':
unittest.main()

View File

@@ -6,6 +6,7 @@ from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, ssimplify
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint, sym_infer
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner, _internal_memory_planner
@@ -97,7 +98,7 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
op_estimate, mem_estimate, lds_estimate)
ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
def updated_vars(self, var_vals):
vals = [var_vals[v] for v in self.vars]
@@ -179,7 +180,7 @@ def _prepare_jit_inputs(args, kwargs):
input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, UOp))])
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
return input_buffers, var_vals, names, st_vars_dtype_device

View File

@@ -146,7 +146,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if buf.op is MetaOps.CONST:
if isinstance(val:=buf.arg, Variable): var_vals.update([val.unbind()])
if isinstance(val:=buf.arg, UOp): var_vals.update([val.unbind()])
return ubuf.swizzle(unbound_st)
if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):

View File

@@ -3,8 +3,8 @@ from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
from tinygrad.ops import identity_element, MathTrait, resolve
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.ops import identity_element, MathTrait, resolve, UOp
from tinygrad.shape.symbolic import sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
from weakref import ref, ReferenceType, WeakValueDictionary
@@ -14,7 +14,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
dtype = to_dtype(dtype)
if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
@@ -111,7 +111,7 @@ class LazyBuffer(MathTrait):
cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, UOp)
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
def _copy(self, device:str) -> LazyBuffer:

View File

@@ -7,8 +7,8 @@ from dataclasses import dataclass, field
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, pretty_print, prod, getenv, all_same
from tinygrad.shape.symbolic import Variable, sint
if TYPE_CHECKING:
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.codegen.kernel import Kernel
@@ -152,9 +152,12 @@ END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.E
# With True as the default, this matches the old symbolic behavior
# python3 -c 'from tinygrad.shape.symbolic import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
def resolve(x, default:bool=True):
try: return bool(x)
except ValueError: return default
if not isinstance(x, UOp): return bool(x)
assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
# NOTE: generating the text for the exception is expensive, so we do this
return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.vmax)
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
class UOp(MathTrait):
@@ -200,7 +203,7 @@ class UOp(MathTrait):
assert self.dtype in dtype, f"eval with wrong dtype {self}"
simple_self = self.simplify()
vmin, vmax = simple_self._min_max
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self}")
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
return vmin
def __bool__(self): return self._eval((dtypes.bool,), bool)
@@ -252,11 +255,22 @@ class UOp(MathTrait):
@staticmethod
def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return UOp.define_var(b.expr, dtype, b.min, cast(int, b.max))
#if isinstance(b, Variable): return UOp.define_var(b.expr, dtype, b.min, cast(int, b.max))
if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.ASSIGN else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
@staticmethod
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
def unbind(self) -> Tuple[Variable, int]:
assert self.op is UOps.ASSIGN and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
from tinygrad.shape.symbolic import Variable
return cast(Variable, self.src[0]), self.src[1].arg
@property
def val(self) -> int: return self.unbind()[1]
# TODO: this is context rewrite
def substitute(self, dvars:Dict[UOp, UOp]):
if self in dvars: return dvars[self]
return self.replace(src=tuple(x.substitute(dvars) for x in self.src))
@staticmethod
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
@@ -268,10 +282,15 @@ class UOp(MathTrait):
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def vars(self) -> Set[UOp]:
bound_vars = set([x for x in self.sparents if x.op is UOps.ASSIGN and x.src[0].op is UOps.DEFINE_VAR])
bound_var_base = set(x.src[0] for x in bound_vars)
all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.arg)
from tinygrad.shape.symbolic import Variable
return sorted(set.union(*st_vars, [x.unbind()[0] if not isinstance(x, Variable) else x for x in self.vars()]), key=lambda v: v.arg)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
@@ -332,6 +351,9 @@ class UOp(MathTrait):
if self.arg is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax
if self.arg is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax
return dtypes.min(self.dtype), dtypes.max(self.dtype)
def render(self, simplify=True) -> str:
ret = graph_rewrite(self.simplify() if simplify else self, renderer)
return ret.arg if ret.op is UOps.NOOP else str(ret)
@dataclass(frozen=True)
class KernelInfo:
@@ -365,7 +387,8 @@ def exec_alu(op:Op, dtype:DType, operands):
def uop_alu_resolve(u:UOp) -> sint:
if u.op is UOps.CONST: return u.arg
if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
if u.op is UOps.DEFINE_VAR: return u
#if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src)))
raise RuntimeError(f"ALU resolve fail @ {u.op}")
@@ -708,6 +731,8 @@ def type_verify(uops:List[UOp]):
chk = cast(bool, spec.rewrite(u))
assert chk is True, f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
# *** most of symbolic lives here now ***
simple_pm = PatternMatcher([
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
@@ -780,3 +805,16 @@ simple_pm = PatternMatcher([
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x*c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
])
# for debug
renderer = PatternMatcher([
(UPat(UOps.DEFINE_VAR, name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
(UPat(UOps.ASSIGN, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.ADD, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}+{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.MUL, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.IDIV, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}//{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.MOD, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}%{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPLT, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}<{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPNE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}!={x.src[1].arg})")),
])

View File

@@ -1,26 +1,15 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, Any
from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, resolve
from tinygrad.codegen.uopgraph import sym, _get_chain
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x.render(render_ops, ctx)
render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self.b),
MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else UOp.define_var(self.expr, dtypes.int, self.min, self.max),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
# TODO: dtypes.realint
iexpr = variable_to_uop(view.offset)

View File

@@ -1,341 +1,39 @@
from __future__ import annotations
import functools
from math import gcd
from tinygrad.helpers import partition, all_int
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
from typing import Union, Optional, Dict, cast
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, exec_alu, ConstType
# NOTE: Python has different behavior for negative mod and floor div than c
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
sint = Union[int, UOp]
class Node:
b: Union[Node, int]
min: int
max: sint
# broken
Node = UOp
MulNode = UOp
SumNode = UOp
DivNode = UOp
ModNode = UOp
LtNode = UOp
AndNode = UOp
def NumNode(val:int): return UOp.const(dtypes.int, val)
# helpers for the migration
class Variable(UOp):
def __reduce__(self): return Variable, self.arg
def __new__(cls, expr:str, nmin:ConstType, nmax:ConstType): # pylint: disable=signature-differs
return super().__new__(cls, UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
def __init__(self, expr:str, nmin:ConstType, nmax:ConstType):
super().__init__(UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
def bind(self, val:int):
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}"
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
return UOp(UOps.ASSIGN, self.dtype, (self, self.const_like(val)))
@property
def vmin(self): return self.min
@property
def vmax(self): return self.max
def expr(self): return self.arg[0]
def render(self, ops=None, ctx=None) -> Any:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
return ops[type(self)](self, ops, ctx)
def vars(self) -> Set[Variable]: return set()
# substitute Variables with the values in var_vals
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
def __repr__(self): return self.render(ctx="REPR")
def __str__(self): return "<"+self.key+">"
def __hash__(self): return hash(self.key)
def __bool__(self):
if self.max == self.min: return self.max == 1
raise ValueError(f"couldn't resolve boolean expression {self}")
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key
def __neg__(self): return self*-1
def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
def __radd__(self, b:int): return self+b
def __sub__(self, b:Union[Node,int]): return self+-b
def __rsub__(self, b:int): return -self+b
def __le__(self, b:Union[Node,int]): return self < (b+1)
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
def __mul__(self, b:Union[Node, int]):
if b == 0: return NumNode(0)
if b == 1: return self
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
def __rmul__(self, b:int): return self*b
def __lshift__(self, b:int): return self*2**b
# *** complex ops ***
def __rfloordiv__(self, b:int): return NumNode(b) // self
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
if self == b: return NumNode(1)
# if isinstance(m:=(b-self).min, int) and m > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
if isinstance(m:=(b-self.max).min, int) and m > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
if b == 1: return self
# the numerator of div is not allowed to be negative
if self.min < 0:
offset = self.min//b
# factor out an "offset" to make the numerator positive. don't allowing factoring again
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
return create_node(DivNode(self, b))
def __rmod__(self, b:int): return NumNode(b) % self
def __mod__(self, b:Union[Node,int]):
if isinstance(b, Node):
if b.__class__ is NumNode: return self % b.b
if self == b: return NumNode(0)
# if isinstance(m:=(b-self).min, int) and m > 0 and self.min >= 0: return self # b - self simplifies the node
if isinstance(m:=(b-self.max).min, int) and m > 0 and self.min >= 0: return self # b - self simplifies the node
raise RuntimeError(f"not supported: {self} % {b}")
assert b > 0
if b == 1: return NumNode(0)
if isinstance(self.max, int) and isinstance(self.min, int):
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
if self.min < 0: return (self - ((self.min//b)*b)) % b
return create_node(ModNode(self, b))
@staticmethod
def sum(nodes:List[Node]) -> Node:
nodes = [x for x in nodes if not (x.max==x.min==0)]
if not nodes: return NumNode(0)
if len(nodes) == 1: return nodes[0]
mul_groups: Dict[Node, int] = {}
num_node_sum = 0
for node in SumNode(nodes).flat_components:
if node.__class__ is NumNode: num_node_sum += node.b
elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
else: mul_groups[node] = mul_groups.get(node, 0) + 1
new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
@staticmethod
def ands(nodes:List[Node]) -> Node:
if not nodes: return NumNode(1)
if len(nodes) == 1: return nodes[0]
if any(x.max==0 for x in nodes): return NumNode(0)
# filter 1s
nodes = [x for x in nodes if x.min != x.max]
return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
# 4 basic node types
class Variable(Node):
def __new__(cls, *args):
expr, nmin, nmax = args
assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
if nmin == nmax: return NumNode(nmin)
return super().__new__(cls)
def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
@property
def arg(self): return self.expr
def __init__(self, expr:str, nmin:int, nmax:sint):
self.expr, self.min, self.max = expr, nmin, nmax
self._val: Optional[int] = None
@property
def val(self):
assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
return self._val
def bind(self, val):
assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self._val = val
return self
def unbind(self) -> Tuple[Variable, int]:
assert self.val is not None, f"cannot unbind {self}"
return Variable(self.expr, self.min, self.max), self.val
def vars(self): return {self}
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
# try both the bound and unbound Variable
return var_vals.get(self, var_vals.get(Variable(self.expr, self.min, self.max), self))
class NumNode(Node):
def __init__(self, num:int):
assert isinstance(num, int), f"{num} is not an int"
self.b:int = num
self.min, self.max = num, num
def bind(self, val):
assert self.b == val, f"cannot bind {val} to {self}"
return self
def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
def __eq__(self, other): return self.b == other
def __hash__(self): return hash(self.b) # needed with __eq__ override
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
def create_node(ret:Node):
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
if ret.min == ret.max: return NumNode(ret.min)
return ret
def create_lt_node(lhs:Node, b:Union[Node, int]):
if isinstance(lhs, SumNode):
if isinstance(b, int):
new_sum = []
for x in lhs.nodes:
# TODO: should we just force the last one to always be the number
if isinstance(x, NumNode): b -= x.b
else: new_sum.append(x)
lhs = Node.sum(new_sum)
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
if muls:
# NOTE: gcd in python 3.8 takes exactly 2 args
mul_gcd = b
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
all_others = Node.sum(others)
if all_others.min >= 0 and all_others.max < mul_gcd:
lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
if isinstance(lhs, MulNode):
if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
sgn = 1 if lhs.b > 0 else -1
return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
return create_node(LtNode(lhs, b))
def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
class OpNode(Node):
def __init__(self, a:Node, b:Union[Node, int]):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
class LtNode(OpNode):
def get_bounds(self) -> Tuple[int, int]:
if self.a == self.b: return (0, 0)
if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
if all_int([self.a.max, self.b.min]) and self.a.max < self.b.min: return (1, 1)
if all_int([self.a.min, self.b.max]) and self.a.min >= self.b.max: return (0, 0)
return (0, 1)
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
class MulNode(OpNode):
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
if self.b % b == 0: return self.a*(self.b//b)
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
def get_bounds(self) -> Tuple[int, sint]:
assert self.a.min >= 0
if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
class DivNode(OpNode):
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
def get_bounds(self) -> Tuple[int, sint]:
assert self.a.min >= 0 and isinstance(self.b, int)
return self.a.min//self.b, self.a.max//self.b
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
class ModNode(OpNode):
def __mod__(self, b: Union[Node, int]):
if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
return Node.__mod__(self, b)
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, sint]:
assert self.a.min >= 0 and isinstance(self.b, int)
if all_int([self.a.max, self.a.min, self.b]):
if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
return (self.a.min%self.b, self.a.max%self.b)
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
class RedNode(Node):
def __init__(self, nodes:List[Node]):
self.nodes = nodes
self.min, self.max = self.get_bounds()
def vars(self) -> Set[Variable]: return set().union(*[x.vars() for x in self.nodes])
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
class SumNode(RedNode):
def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
if self == b: return NumNode(1)
fully_divided: List[Node] = []
rest: List[Node] = []
if isinstance(b, Node):
for x in self.flat_components:
if x % b == 0: fully_divided.append(x // b)
else: rest.append(x)
if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
return Node.__floordiv__(self, b, False)
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
_gcd = b
divisor = 1
for x in self.flat_components:
if x.__class__ in (NumNode, MulNode):
if x.b % b == 0: fully_divided.append(x // b)
else:
if x.__class__ is NumNode and (div := x.b // b):
fully_divided.append(NumNode(div))
x = NumNode(x.b - b * div)
rest.append(x)
if isinstance(x.b, int):
_gcd = gcd(_gcd, x.b)
if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
else:
_gcd = 1
else:
rest.append(x)
_gcd = 1
if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __mod__(self, b: Union[Node, int]):
if self == b: return NumNode(0)
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
return Node.__mod__(new_sum, b)
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
return Node.sum([node.substitute(var_vals) for node in self.nodes])
# recursively expand sumnode components
# TODO: can remove this if there's no SumNode inside SumNode
@property
def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
class AndNode(RedNode):
def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
return Node.ands([node.substitute(var_vals) for node in self.nodes])
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
if isinstance(a, (int, float)): return a
ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
return ret.b
# symbolic int, these are allowed in a Tensor shape
sint = Union[int, Variable, MulNode, SumNode]
def render_mulnode(node:MulNode, ops, ctx):
# TODO: add ProdNode and remove this case
if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
render_python: Dict[Type, Callable[..., str]] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
else f"{self.expr}"),
NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
MulNode: render_mulnode,
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
}
def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
if isinstance(uop, (int, float)): return uop # TODO: ugh, the float is a hack for qcom
if uop.op == UOps.CONST: return uop.arg
if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[cast(Variable, uop)]
if uop.op == UOps.ASSIGN: return uop.src[1].arg # bound variable returns bound value
if uop.op == UOps.ALU:
src_values = [sym_infer(src, var_vals) for src in uop.src]
return exec_alu(uop.arg, uop.dtype, src_values)
raise NotImplementedError(f"Unsupported UOp {uop.op}")

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast
from tinygrad.ops import resolve
from typing import Tuple, List, Optional, Dict, Set, cast, Union
from tinygrad.ops import resolve, UOp
from tinygrad.helpers import prod, all_int, argsort
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
@@ -114,6 +114,13 @@ class View:
strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
strides = tuple(0 if e else st for st,e in zip(strides, elim))
# simplify as we go
if isinstance(offset, UOp): offset = cast(Union[UOp, int], offset.ssimplify())
"""
shape = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in shape)
strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
"""
contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
return View(shape, strides, offset, mask, contiguous)
@@ -161,7 +168,7 @@ class View:
merged_size, merged_term = 1, NumNode(0)
extents: List[Tuple[sint, Node]] = []
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
merged_term += sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
merged_size *= s
if not resolve(merged_term >= merged_size) and not resolve(merged_term < 0):
extents.append((merged_size, merged_term))
@@ -220,7 +227,7 @@ class View:
mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
shape = [y-x for x,y in arg]
if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
@@ -246,9 +253,12 @@ class View:
if 0 in self.shape:
assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
return View.create(new_shape)
assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
# TODO: this resolve might be wrong
assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
# NOTE: can the mask ever be (0,0)?
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
# TODO: this resolve may not be needed, but it's hard because vars need to be sorted
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
return View.create(new_shape, self.strides, self.offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -278,8 +288,8 @@ class View:
return View.create(new_shape)
# check for the same size
if (self_all_int := all_int(self.shape)):
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
if resolve(prod(self.shape) != prod(new_shape), False):
raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None

View File

@@ -10,9 +10,9 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, truncate, smax, resolve
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
from tinygrad.shape.symbolic import sint, Variable, Node
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
@@ -109,7 +109,7 @@ class Tensor:
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, Variable, pathlib.Path], # type: ignore [name-defined] # noqa: F821
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, UOp, pathlib.Path], # type: ignore [name-defined] # noqa: F821
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
if dtype is not None: dtype = to_dtype(dtype)
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
@@ -129,7 +129,7 @@ class Tensor:
# create a LazyBuffer from the different types of inputs
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, Variable): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
elif isinstance(data, UOp): data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
elif isinstance(data, (list, tuple)):
if dtype is None:
@@ -367,11 +367,15 @@ class Tensor:
return self
@staticmethod
def from_node(y:Node, **kwargs) -> Tensor:
if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False)
if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
def from_node(y:UOp, **kwargs) -> Tensor:
# NOTE: we only support Tensors from DEFINE_VAR or CONST
if y.op is UOps.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is UOps.ASSIGN:
assert y.src[0].op is UOps.DEFINE_VAR
return Tensor(y, **kwargs, requires_grad=False)
if y.op is UOps.ALU:
if y.arg is BinaryOps.MUL: return Tensor.from_node(y.src[0]) * Tensor.from_node(y.src[1])
if y.arg is BinaryOps.ADD: return Tensor.from_node(y.src[0]) + Tensor.from_node(y.src[1])
raise RuntimeError(f"unhandled Node {y}")
# ***** creation entrypoint *****
@@ -2680,7 +2684,8 @@ class Tensor:
# first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
padded, _ = _pad_left(self.shape, shape)
# for each dimension, check either from_ is 1, or it does not change
if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
return F.Expand.apply(self.reshape(padded), shape=shape)
def _broadcasted(self, y:Union[Tensor, Node, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]: