Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
switch symbolic from old to uops, final PR (#6872)
* switch symbolic from old to uops, final PR
* two wrong answers
* not needed resolves
* symbolic ops passes
* symbolic ops passes
* progress
* tests pass (almost)
* fix last test
* fix some tests
* global binding and unbinding
* Revert "global binding and unbinding"
This reverts commit 9456725630.
* that test works now
* vars on uop doesn't recurse
* fix fuzzer
* update
* fix type
* fix gpt, it's UOp now
* ssimplify symbolics
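
As context for the diff below, here is a minimal sketch of the API this PR converges on, pieced together from the updated tests and the new Variable shim (illustrative, not authoritative):

  from tinygrad import Tensor, Variable

  # Variable is now a thin subclass of UOp (a DEFINE_VAR uop), and bind() no
  # longer mutates in place: it returns a new ASSIGN uop pairing the variable
  # with its value, so creation and binding must be chained.
  vv = Variable("a", 0, 10).bind(1)
  assert (Tensor(vv) + 3).item() == 4

  # unbind() recovers the (unbound Variable, value) pair from a bound uop.
  var, val = vv.unbind()
  assert val == 1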
@@ -4,6 +4,7 @@ import argparse
 import numpy as np
 import tiktoken
 from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
+from tinygrad.ops import UOp
 from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
 from tinygrad.nn import Embedding, Linear, LayerNorm
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
@@ -75,9 +76,9 @@ class Transformer:
     self.lm_head = Linear(dim, vocab_size, bias=False)
     self.forward_jit = TinyJit(self.forward)

-  def forward(self, tokens:Union[Tensor,Variable], start_pos:Variable, temperature:float=0.0):
+  def forward(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0):
     if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
-    if isinstance(tokens, Variable):
+    if isinstance(tokens, UOp):
       seqlen = 1
       tok_emb = self.wte.weight.shrink(((tokens, tokens+1), None))
     else:
@@ -107,8 +108,8 @@ class Transformer:
     ret = (logits / temperature).softmax().multinomial()
     return ret.flatten().realize()

-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
-    forward = (self.forward_jit if JIT and (isinstance(tokens, Variable) or tokens.shape[1] == 1) else self.forward)
+  def __call__(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0) -> Tensor:
+    forward = (self.forward_jit if JIT and (isinstance(tokens, UOp) or tokens.shape[1] == 1) else self.forward)
     return forward(tokens, start_pos, temperature)

 VOCAB_SIZE = 50257

@@ -71,8 +71,8 @@ class Attention:
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
     self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()

-    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if isinstance(start_pos, Variable) or start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if isinstance(start_pos, Variable) or start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv

     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -172,10 +172,10 @@ class Transformer:

     return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()

-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
+  def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None and start_pos != 0:
+      return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)

 # *** helpers ***

test/external/fuzz_symbolic.py (vendored)
@@ -40,6 +40,11 @@ def gt(expr, rng=None):
   if rng is None: rng = random.randint(-4,4)
   return expr > rng, rng

+# NOTE: you have to replace these for this test to pass
+from tinygrad.ops import python_alu, BinaryOps
+python_alu[BinaryOps.MOD] = lambda x,y: x%y
+python_alu[BinaryOps.IDIV] = lambda x,y: x//y
+
 if __name__ == "__main__":
   ops = [add_v, div, mul, add_num, mod]
   for _ in range(1000):
@@ -65,5 +70,8 @@ if __name__ == "__main__":
       rn = 0
       for t,r in zip(tape, rngs): rn, _ = t(rn, r)
       num = eval(expr.render())
-      assert num == rn, f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}"
+      if num != rn:
+        unsimplified_num = eval(expr.render(simplify=False))
+        assert unsimplified_num == rn, "UNSIMPLIFIED MISMATCH!"
+        assert num == rn, f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}\n{expr.render(simplify=False)}"
       if DEBUG >= 1: print(f"matched {expr.render()} at {v1=} {v2=} {v3=} = {num} == {rn}")

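The python_alu override above is needed because Python's // and % floor toward negative infinity, while the default IDIV/MOD emulation presumably mirrors the C-style truncating semantics of generated code (the NOTE in the hunk implies the defaults disagree with plain x//y and x%y). A standalone illustration of the mismatch on negative operands:

  # Python floor semantics vs C-style truncating semantics
  x, y = -7, 4
  print(x // y, x % y)  # Python: -2 1 (mod takes the sign of the divisor)
  q = int(x / y)        # C-like truncation toward zero
  print(q, x - q*y)     # C-like: -1 -3 (mod takes the sign of the dividend)

Since the fuzzer evals the rendered expression with Python operators, both sides must agree on one convention for the comparison to be meaningful.
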
@@ -4,25 +4,21 @@ from tinygrad import Tensor, Variable

 class TestTensorVariable(unittest.TestCase):
   def test_add_tvar(self):
-    vv = Variable("a", 0, 10)
-    vv.bind(1)
+    vv = Variable("a", 0, 10).bind(1)
     ret = (Tensor(vv) + 3).item()
     assert ret == 4

   def test_inner_tvar_node(self):
-    vv = Variable("w", 0, 10)
-    vv.bind(2)
+    vv = Variable("w", 0, 10).bind(2)
     ret = Tensor.from_node(vv * 4).item()
     assert ret == 8

   def test_inner_tvar_mul(self):
-    vv = Variable("w", 0, 10)
-    vv.bind(2)
+    vv = Variable("w", 0, 10).bind(2)
     assert (Tensor(3) * vv).item() == 6

   def test_inner_tvar_mul_node(self):
-    vv = Variable("w", 0, 10)
-    vv.bind(2)
+    vv = Variable("w", 0, 10).bind(2)
     assert (Tensor(3) * (vv * 4)).item() == 24

   def test_symbolic_mean(self):
@@ -45,14 +41,10 @@ class TestTensorVariable(unittest.TestCase):
     ret = t.mean(axis=1).reshape(2, 1).numpy()
     assert np.all(ret == 1)

-  @unittest.expectedFailure
   def test_symbolic_mean_2d_add(self):
-    add_term = Variable("c", 0, 10)
-    add_term.bind(1)
-    vv = Variable("a", 1, 10)
-    vv.bind(1)
-    vv2 = Variable("b", 1, 10)
-    vv2.bind(1)
+    add_term = Variable("c", 0, 10).bind(1)
+    vv = Variable("a", 1, 10).bind(1)
+    vv2 = Variable("b", 1, 10).bind(1)
     t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term)
     ret = t.mean().item()
     assert ret == 1
@@ -65,8 +57,7 @@ class TestTensorVariable(unittest.TestCase):

   @unittest.skip("symbolic arange isn't supported")
   def test_symbolic_arange(self):
-    vv = Variable("a", 1, 10)
-    vv.bind(2)
+    vv = Variable("a", 1, 10).bind(2)
     ret = Tensor.arange(0, vv)
     ret.realize()

@@ -1,575 +0,0 @@
-#!/usr/bin/env python
-import unittest, pickle
-from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
-
-class TestSymbolicPickle(unittest.TestCase):
-  def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
-  def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
-  def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
-
-class TestSymbolic(unittest.TestCase):
-  def helper_test_variable(self, v, n, m, s):
-    self.assertEqual(v.render(), s)
-    self.assertEqual(v.min, n)
-    self.assertEqual(v.max, m)
-
-  def test_ge(self):
-    self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
-    self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
-    self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a*-1)<-7)")
-    self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a*-1)<-3)")
-    self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
-    self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
-
-  def test_lt(self):
-    self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
-    self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
-    self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
-    self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
-    self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "0")
-    self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "0")
-
-  def test_ge_divides(self):
-    expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
-    self.helper_test_variable(expr, 0, 1, "(idx<128)")
-
-  def test_ge_divides_and(self):
-    expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
-                      create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
-    self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
-    expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
-                      create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
-    self.helper_test_variable(expr//4, 0, 0, "0")
-
-  def test_lt_factors(self):
-    expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
-    self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
-
-  def test_div_reduction(self):
-    self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
-
-  def test_var_becomes_num(self):
-    self.helper_test_variable(Variable("a", 2, 2), 2, 2, "2")
-
-  def test_equality(self):
-    idx1 = Variable("idx1", 0, 3)
-    idx2 = Variable("idx2", 0, 3)
-    assert idx1 == idx1
-    assert idx1 != idx2
-    assert idx1*4 == idx1*4
-    assert idx1*4 != idx1*3
-    assert idx1*4 != idx1+4
-    assert idx1*4 != idx2*4
-    assert idx1+idx2 == idx1+idx2
-    assert idx1+idx2 == idx2+idx1
-    assert idx1+idx2 != idx2
-    assert idx1*idx2 == idx2*idx1
-
-  def test_numnode_eq_int(self):
-    n1 = NumNode(1)
-    n2 = NumNode(2)
-    assert n1 == 1
-    assert n2 == 2
-    assert n1 != n2
-    assert hash(n1) == hash(1)
-    assert hash(n2) == hash(2)
-
-  def test_factorize(self):
-    a = Variable("a", 0, 8)
-    self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
-
-  def test_factorize_no_mul(self):
-    a = Variable("a", 0, 8)
-    self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
-
-  def test_neg(self):
-    self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")
-
-  def test_add_1(self):
-    self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(1+a)")
-
-  def test_add_num_1(self):
-    self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, "(1+a)")
-
-  def test_sub_1(self):
-    self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(-1+a)")
-
-  def test_sub_num_1(self):
-    self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(-1+a)")
-
-  def test_add_self(self):
-    a = Variable("a", 0, 8)
-    self.helper_test_variable(a+a, 0, 16, "(a*2)")
-
-  def test_sub_self(self):
-    a = Variable("a", 0, 8)
-    self.helper_test_variable(a-a, 0, 0, "0")
-    self.helper_test_variable(a*3-a, 0, 16, "(a*2)")
-
-  def test_mul_0(self):
-    self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0")
-
-  def test_mul_1(self):
-    self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a")
-
-  def test_mul_neg_1(self):
-    self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((((a*-1)+3)//3)+-1)")
-
-  def test_mul_2(self):
-    self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)")
-
-  def test_div_1(self):
-    self.helper_test_variable(Variable("a", 0, 8)//1, 0, 8, "a")
-
-  def test_mod_1(self):
-    self.helper_test_variable(Variable("a", 0, 8)%1, 0, 0, "0")
-
-  def test_add_min_max(self):
-    self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)")
-
-  def test_div_remove(self):
-    self.helper_test_variable(Variable("a", 0, 7) // 20, 0, 0, "0")
-
-  def test_div_min_max(self):
-    self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
-
-  def test_div_neg_min_max(self):
-    self.helper_test_variable(Variable("a", 0, 7) // -2, -4, 0, "((((a*-1)+8)//2)+-4)")
-    self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((((a*-1)+6)//2)+-3)")
-
-  def test_sum_div_remove(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
-
-  def test_sum_div_min_max(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
-
-  def test_sum_div_mod_factor(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
-
-  def test_sum_div_some_factor(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
-
-  def test_sum_div_trim_const(self):
-    self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "((1+a+b)//4)")
-
-  def test_sum_div_some_partial_factor(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
-    self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
-    self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)")
-
-  def test_sum_div_no_factor(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
-
-  def test_mod_factor(self):
-    # NOTE: even though the mod max is 50, it can't know this without knowing about the mul
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
-
-  def test_mod_to_sub(self):
-    # This is mod reduction
-    self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(-1+a)")
-
-  def test_sum_div_const(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
-
-  def test_sum_div_const_big(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
-
-  def test_sum_lt_fold(self):
-    self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]), 16), 0, 1, "(a<4)")
-    self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]), 16), 0, 1, "(((a*4)+b)<16)")
-    self.helper_test_variable(create_lt_node(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]), (4 * 67)), 0, 1, "(a<23)")
-
-  def test_mul_mod_large(self):
-    self.helper_test_variable((Variable("a", 0, 20)*10)%9, 0, 8, "(a%9)")
-
-  def test_mul_mod_small(self):
-    self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
-
-  def test_mod_mod(self):
-    self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)")
-    self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0")
-    self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, "(((a*5)%12)%5)")
-    self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)")
-
-  def test_mul_mul(self):
-    self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
-
-  def test_mul_lt(self):
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)")
-    self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a*-1)<-2)")
-    self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a*-1)<-3)")
-
-  def test_div_div(self):
-    self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
-
-  def test_distribute_mul(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
-    self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, "((a*-2)+10)")
-
-  def test_mod_mul_sum(self):
-    self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
-
-  def test_sum_0(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a")
-
-  def test_mod_remove(self):
-    self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
-
-  def test_big_mod(self):
-    # NOTE: we no longer support negative variables
-    #self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
-    #self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
-    #self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
-    self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
-    #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
-
-  def test_ge_remove(self):
-    self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
-
-  def test_lt_remove(self):
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "1")
-
-  def test_lt_sum_remove(self):
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
-
-  def test_lt_simple_factor(self):
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6+Variable("b", 0, 6)*6, 8), 0, 1,
-                              "(((a*3)+(b*3))<4)")
-
-  def test_lt_sum_factor_rhs_partial(self):
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 4), 0, 1,
-                              "(((a*3)+(b*2)+(c*4))<2)")
-
-  def test_lt_sum_factor_rhs_all(self):
-    self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 2), 0, 1,
-                              "(((a*3)+(b*2)+(c*4))<1)")
-
-  def test_and_fold(self):
-    self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
-
-  def test_and_remove(self):
-    self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
-
-  def test_mod_factor_negative(self):
-    self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
-    self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
-
-  def test_sum_combine_num(self):
-    self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)")
-
-  def test_sum_num_hoisted_and_factors_cancel_out(self):
-    self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
-
-  def test_div_cancel(self):
-    self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
-
-  def test_mod_cancel(self):
-    self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
-
-  def test_mul_div(self):
-    self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
-
-  def test_add_div(self):
-    # careful about the lower bounds and upper bounds
-    self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, "(((2+a)//4)+-1)")
-    self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, "(((3+a)//4)+-1)")
-    self.helper_test_variable((Variable("a", 0, 5))//4, 0, 1, "(a//4)")
-    self.helper_test_variable((Variable("a", 0, 5)+1)//4, 0, 1, "((1+a)//4)")
-    self.helper_test_variable((Variable("a", 0, 5)+2)//4, 0, 1, "((2+a)//4)")
-    self.helper_test_variable((Variable("a", 0, 5)+3)//4, 0, 2, "((3+a)//4)")
-    self.helper_test_variable((Variable("a", 0, 5)+4)//4, 1, 2, "((a//4)+1)")
-    self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "(((1+a)//4)+1)")
-
-  def test_mul_div_factor_mul(self):
-    self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
-
-  def test_mul_div_factor_div(self):
-    self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
-
-  def test_sum_div_partial_remove(self):
-    self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
-
-  def test_div_numerator_negative(self):
-    self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
-
-  def test_div_into_mod(self):
-    self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
-
-  def test_div_neg_cancel(self):
-    self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 0, 25, "((1+idx)//4)")
-    self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "(idx//4)")
-    self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "(((3+idx)//4)+-1)")
-
-  def test_sum_div_big_const(self):
-    gidx0 = Variable("gidx0", 0, 24)
-    self.helper_test_variable((gidx0+19)//20, 0, 2, "((19+gidx0)//20)")
-    self.helper_test_variable((gidx0+20)//20, 1, 2, "((gidx0//20)+1)")
-    self.helper_test_variable((gidx0+21)//20, 1, 2, "(((1+gidx0)//20)+1)")
-
-  def test_sum_div_complex1(self):
-    gidx0 = Variable("gidx0", 0, 24)
-    gidx1 = Variable("gidx1", 0, 1)
-    gidx2 = Variable("gidx2", 0, 255)
-    lidx0 = Variable("lidx0", 0, 1)
-    lidx1 = Variable("lidx1", 0, 15)
-    lidx2 = Variable("lidx2", 0, 3)
-    alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
-    self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, "((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(gidx1*8)+(gidx2*32)+(lidx0*16))")
-
-  def test_sum_div_complex2(self):
-    gidx0 = Variable("gidx0", 0, 7)
-    lidx2 = Variable("lidx2", 0, 1)
-    lidx3 = Variable("lidx3", 0, 1)
-    self.helper_test_variable((gidx0*4+lidx2*2+1)//10, 0, 3, "(((gidx0*2)+lidx2)//5)")
-    self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//10, 0, 3, "(((gidx0*2)+lidx2)//5)")
-    self.helper_test_variable((gidx0*2+lidx2)//10, 0, 1, "(gidx0//5)")
-
-  def test_sum_div_complex3(self):
-    gidx0 = Variable("gidx0", 0, 7)
-    lidx2 = Variable("lidx2", 0, 12)
-    lidx3 = Variable("lidx3", 0, 1)
-    self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
-    self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
-
-  def test_sum_mul_distribute(self):
-    gidx0 = Variable("gidx0", 0, 7)
-    lidx2 = Variable("lidx2", 0, 12)
-    lidx3 = Variable("lidx3", 0, 1)
-    self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
-
-  def test_variable_divmod(self):
-    start_pos = Variable("start_pos", 0, 127)
-    v = start_pos + 1
-    idx0 = Variable("idx0", 0, 2)
-    idx1 = Variable("idx1", 0, start_pos)
-    self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
-    self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
-
-class TestSymbolicNumeric(unittest.TestCase):
-  def helper_test_numeric(self, f):
-    # TODO: why are the negative tests broken? (even if we did support negative variables)
-    #MIN, MAX = -10, 10
-    MIN, MAX = 0, 10
-    # one number
-    for i in range(MIN, MAX):
-      v = f(NumNode(i))
-      #print(i, f(i), v.min, v.max)
-      self.assertEqual(v.min, v.max)
-      self.assertEqual(v.min, f(i))
-    for kmin in range(MIN, MAX):
-      for kmax in range(MIN, MAX):
-        if kmin > kmax: continue
-        v = f(Variable("tmp", kmin, kmax))
-        values = [f(rv) for rv in range(kmin, kmax+1)]
-        # the min and max may not be exact
-        self.assertLessEqual(v.min, min(values))
-        self.assertGreaterEqual(v.max, max(values))
-
-  def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
-  def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
-  def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2)
-  def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2)
-  def test_times_2(self): self.helper_test_numeric(lambda x: x*2)
-  def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3)
-  def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4)
-  def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
-  def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
-
-class TestSymbolicVars(unittest.TestCase):
-  def test_simple(self):
-    z = NumNode(0)
-    a = Variable("a", 0, 10)
-    b = Variable("b", 0, 10)
-    c = Variable("c", 0, 10)
-    assert z.vars() == z.vars() == set()
-    assert a.vars() == a.vars() == {a}
-    m = MulNode(a, 3)
-    assert m.vars() == {a}
-    s = SumNode([a, b, c])
-    assert s.vars() == {a, b, c}
-
-  def test_compound(self):
-    a = Variable("a", 0, 10)
-    b = Variable("b", 0, 10)
-    c = Variable("c", 0, 10)
-    assert (a + b * c).vars() == {a, b, c}
-    assert (a % 3 + b // 5).vars() == {a, b}
-    assert (a + b + c - a).vars() == {b, c}
-
-  def test_dedup(self):
-    a = Variable("a", 0, 10)
-    assert (a * a).vars() == {a}
-    assert (a//4 + a//6).vars() == {a}
-
-class TestSymbolicMinMax(unittest.TestCase):
-  def test_min_max_known(self):
-    a = Variable("a", 1, 8)
-    assert max(1, a, key=lambda x:x if isinstance(x, int) else x.max) == max(a, 1, key=lambda x:x if isinstance(x, int) else x.max) == a
-    assert min(1, a, key=lambda x:x if isinstance(x, int) else x.max) == min(a, 1, key=lambda x:x if isinstance(x, int) else x.max) == 1
-
-class TestSymRender(unittest.TestCase):
-  def test_sym_render(self):
-    a = Variable("a", 1, 8)
-    b = Variable("b", 1, 10)
-    assert sym_render(a) == "a"
-    assert sym_render(1) == "1"
-    assert sym_render(a+1) == "(1+a)"
-    assert sym_render(a*b) == "(a*b)"
-
-class TestSymInfer(unittest.TestCase):
-  def test_sym_infer(self):
-    a = Variable("a", 0, 10)
-    b = Variable("b", 0, 10)
-    c = Variable("c", 0, 10)
-    var_vals = {a: 2, b: 3, c: 4}
-    assert sym_infer(5, var_vals) == 5
-    assert sym_infer(a, var_vals) == 2
-    assert sym_infer(b, var_vals) == 3
-    assert sym_infer(a+b, var_vals) == 5
-    assert sym_infer(a-b, var_vals) == -1
-    assert sym_infer(a+b+c, var_vals) == 9
-    assert sym_infer(a*b, var_vals) == 6
-    assert sym_infer(a*b+c, var_vals) == 10
-
-class TestSymbolicSymbolicOps(unittest.TestCase):
-  def test_node_divmod_node(self):
-    i = Variable("i", 1, 10)
-    idx0 = Variable("idx0", 0, i*3-1)
-    assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
-    assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
-    assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
-    assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
-    assert 127 // (Variable("i", 1, 10)*128) == 0
-    assert 127 % (Variable("i", 1, 10)*128) == 127
-    assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
-    assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
-    assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
-    assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
-    assert 0 // (Variable("i", 1, 10)*128) == 0
-    assert 0 % (Variable("i", 1, 10)*128) == 0
-    assert idx0 // (i*3) == 0
-    assert idx0 % (i*3) == idx0
-    assert i // i == 1
-    assert i % i == 0
-    assert 128 // NumNode(4) == 32
-    assert 128 % NumNode(4) == 0
-    assert NumNode(128) // NumNode(4) == 32
-    assert NumNode(128) % NumNode(4) == 0
-
-  def test_mulnode_divmod_node(self):
-    i = Variable("i", 1, 10)
-    idx0 = Variable("idx0", 0, 31)
-    # assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
-    # assert (idx0*(i*4+4)) % (i+1) == 0
-    assert (idx0*i) % i == 0
-
-  def test_sumnode_divmod_sumnode(self):
-    i = Variable("i", 1, 10)
-    # idx0 = Variable("idx0", 0, 7)
-    # idx1 = Variable("idx1", 0, 3)
-    # idx2 = Variable("idx2", 0, i)
-    # assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
-    # assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
-    assert (i+1) // (i*128+128) == 0
-    assert (i+1) % (i*128+128) == (i+1)
-    # assert (i+1+idx2) // (i+1) == 1
-    # assert (i+1+idx2) % (i+1) == idx2
-    # assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
-    # assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
-    # assert (i*128+128)*2 // (i*128+128) == 2
-    # assert (i*128+128)*2 % (i*128+128) == 0
-
-  def test_sumnode_div_numnode_no_factoring(self):
-    gid = Variable("gid", 0, 1023)
-    lid = Variable("lid", 0, 3)
-    expr_before_div = NumNode(-1019)-4*lid-gid
-    unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False)
-    factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True)
-    self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
-    self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
-
-  def test_mod_node_max(self):
-    i = Variable("i", 1, 128)
-    gidx0 = Variable("gidx0", 0, i)
-    mod = gidx0 % 8
-    assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
-    mod = gidx0 % 2
-    assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
-
-    gidx0 = Variable("gidx0", 0, i*8+7)
-    mod = gidx0 % 8
-    assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
-    mod = gidx0 % 2
-    assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
-
-  def test_node_lt_node(self):
-    a = Variable("a", 1, 5)
-    b = Variable("b", 6, 9)
-    c = Variable("c", 1, 10)
-    d = Variable("d", 5, 10)
-    # if the comparison output is always the same, it folds to num
-    assert create_lt_node(a, b) == NumNode(1)
-    assert create_lt_node(b, a) == NumNode(0)
-    assert create_lt_node(d, a) == NumNode(0)
-    assert create_lt_node(a, a) == NumNode(0)
-    assert create_lt_node(a, a) == NumNode(0)
-    # if it remains as a LtNode, (min, max) == (0, 1)
-    a_lt_c = create_lt_node(a, c)
-    assert isinstance(a_lt_c, LtNode) and a_lt_c.min == 0 and a_lt_c.max == 1
-    # same when comparing with a constant
-    a_lt_3 = create_lt_node(a, 3)
-    assert a_lt_3.min == 0 and a_lt_3.max == 1
-
-  def test_sumnode_mulnode_lt(self):
-    a = Variable("a", 1, 2)
-    b = Variable("b", 1, 2)
-    c = Variable("c", 1, 2)
-    x = SumNode([MulNode(a, b), c])
-    with self.assertRaises(AssertionError):
-      create_lt_node(x, 3)
-
-  def test_nested_variable_mod(self):
-    i = Variable("i", 1, 5)
-    idx0 = Variable("idx0", 0, i)
-    with self.assertRaises(AssertionError):
-      assert idx0 % 2 == idx0
-
-  def test_num_node_mul_node(self):
-    a = Variable("a", 1, 5)
-    b = NumNode(2) * a
-    assert b == a * 2
-    assert isinstance(b, MulNode)
-    b = NumNode(1) * a
-    assert b == a
-    assert isinstance(b, Variable)
-    b = NumNode(0) * a
-    assert b == 0
-    assert isinstance(b, NumNode)
-
-  def test_substitute(self):
-    a = Variable("idx0", 1, 3)
-    b = a + 1
-    c = b.substitute({a: NumNode(1)})
-    assert c == NumNode(2)
-
-class TestSymbolicRealWorld(unittest.TestCase):
-  def test_resnet_half(self):
-    gidx0 = Variable("gidx0", 0, 3)
-    gidx1 = Variable("gidx1", 0, 127)
-    gidx2 = Variable("gidx2", 0, 7)
-    lidx3 = Variable("lidx3", 0, 7)
-    lidx4 = Variable("lidx4", 0, 1)
-    lidx5 = Variable("lidx5", 0, 15)
-
-    idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
-    print(idx.render())
-    # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
-    assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
-
-if __name__ == '__main__':
-  unittest.main()

@@ -6,6 +6,7 @@ from tinygrad.lazy import LazyBuffer
 from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
+from tinygrad.ops import UOp, ssimplify
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, sint, sym_infer
 from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner, _internal_memory_planner
@@ -97,7 +98,7 @@ class GraphRunner(Runner):  # pylint: disable=abstract-method
       if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)

    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
-                     op_estimate, mem_estimate, lds_estimate)
+                     ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))

  def updated_vars(self, var_vals):
    vals = [var_vals[v] for v in self.vars]
@@ -179,7 +180,7 @@ def _prepare_jit_inputs(args, kwargs):
  input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
  var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
-    [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
+    [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, UOp))])
  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
  return input_buffers, var_vals, names, st_vars_dtype_device

@@ -146,7 +146,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
     if buf.op is MetaOps.CONST:
-      if isinstance(val:=buf.arg, Variable): var_vals.update([val.unbind()])
+      if isinstance(val:=buf.arg, UOp): var_vals.update([val.unbind()])
       return ubuf.swizzle(unbound_st)
     if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
         ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):

@@ -3,8 +3,8 @@ from typing import Union, Optional, Any, Tuple, List, get_args
 from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
 from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
 from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
-from tinygrad.ops import identity_element, MathTrait, resolve
-from tinygrad.shape.symbolic import sint, Variable
+from tinygrad.ops import identity_element, MathTrait, resolve, UOp
+from tinygrad.shape.symbolic import sint
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
 from weakref import ref, ReferenceType, WeakValueDictionary
@@ -14,7 +14,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[
                       base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
  if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
  dtype = to_dtype(dtype)
-  if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
+  if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True

  cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
  if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
@@ -111,7 +111,7 @@ class LazyBuffer(MathTrait):
     cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))

-  def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
+  def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, UOp)
  def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)

  def _copy(self, device:str) -> LazyBuffer:

@@ -7,8 +7,8 @@ from dataclasses import dataclass, field
 from weakref import WeakValueDictionary
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
 from tinygrad.helpers import ContextVar, pretty_print, prod, getenv, all_same
-from tinygrad.shape.symbolic import Variable, sint
 if TYPE_CHECKING:
+  from tinygrad.shape.symbolic import Variable, sint
   from tinygrad.shape.shapetracker import ShapeTracker
   from tinygrad.codegen.kernel import Kernel

@@ -152,9 +152,12 @@ END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.E
 # With True as the default, this matches the old symbolic behavior
 # python3 -c 'from tinygrad.shape.symbolic import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
 def resolve(x, default:bool=True):
-  try: return bool(x)
-  except ValueError: return default
+  if not isinstance(x, UOp): return bool(x)
+  assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
+  # NOTE: generating the text for the exception is expensive, so we do this
+  return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
 def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.vmax)
+def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop

 ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
 class UOp(MathTrait):
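
A small usage sketch of the new resolve, relying on the Variable shim that appears later in this diff (outputs follow from the vmin/vmax logic above):

  from tinygrad.ops import resolve
  from tinygrad.shape.symbolic import Variable

  a = Variable("a", 1, 10)
  print(resolve(a < 11))        # True: the comparison simplifies to a constant 1
  print(resolve(a < 1))         # False: it simplifies to a constant 0
  print(resolve(a < 4))         # True: undecidable (vmin != vmax), falls back to the default
  print(resolve(a < 4, False))  # False: same expression, caller-supplied default
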
@@ -200,7 +203,7 @@ class UOp(MathTrait):
     assert self.dtype in dtype, f"eval with wrong dtype {self}"
     simple_self = self.simplify()
     vmin, vmax = simple_self._min_max
-    if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self}")
+    if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
     assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
     return vmin
   def __bool__(self): return self._eval((dtypes.bool,), bool)
@@ -252,11 +255,22 @@ class UOp(MathTrait):
   @staticmethod
   def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
     # TODO: fix dtype of b.max after Variable is just an UOp
-    if isinstance(b, Variable): return UOp.define_var(b.expr, dtype, b.min, cast(int, b.max))
+    #if isinstance(b, Variable): return UOp.define_var(b.expr, dtype, b.min, cast(int, b.max))
+    if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.ASSIGN else b
     if isinstance(b, tuple) and all_same(b): b = b[0]  # doesn't have to be a VCONST if they are all the same
     return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)  # type: ignore
   @staticmethod
   def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
+  def unbind(self) -> Tuple[Variable, int]:
+    assert self.op is UOps.ASSIGN and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
+    from tinygrad.shape.symbolic import Variable
+    return cast(Variable, self.src[0]), self.src[1].arg
+  @property
+  def val(self) -> int: return self.unbind()[1]
+  # TODO: this is context rewrite
+  def substitute(self, dvars:Dict[UOp, UOp]):
+    if self in dvars: return dvars[self]
+    return self.replace(src=tuple(x.substitute(dvars) for x in self.src))
   @staticmethod
   def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
     return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
@@ -268,10 +282,15 @@ class UOp(MathTrait):
   @functools.cached_property
   def full_shape(self) -> Tuple[sint, ...]:
     return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
-  def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
+  def vars(self) -> Set[UOp]:
+    bound_vars = set([x for x in self.sparents if x.op is UOps.ASSIGN and x.src[0].op is UOps.DEFINE_VAR])
+    bound_var_base = set(x.src[0] for x in bound_vars)
+    all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
+    return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
   def variables(self) -> List[Variable]:
     st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
-    return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.arg)
+    from tinygrad.shape.symbolic import Variable
+    return sorted(set.union(*st_vars, [x.unbind()[0] if not isinstance(x, Variable) else x for x in self.vars()]), key=lambda v: v.arg)
   def const_factor(self) -> int:
     """largest known int that divides self"""
     if self.op is UOps.CONST: return self.arg
@@ -332,6 +351,9 @@ class UOp(MathTrait):
       if self.arg is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax
       if self.arg is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax
     return dtypes.min(self.dtype), dtypes.max(self.dtype)
+  def render(self, simplify=True) -> str:
+    ret = graph_rewrite(self.simplify() if simplify else self, renderer)
+    return ret.arg if ret.op is UOps.NOOP else str(ret)

 @dataclass(frozen=True)
 class KernelInfo:
@@ -365,7 +387,8 @@ def exec_alu(op:Op, dtype:DType, operands):

 def uop_alu_resolve(u:UOp) -> sint:
   if u.op is UOps.CONST: return u.arg
-  if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
+  if u.op is UOps.DEFINE_VAR: return u
+  #if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
   if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src)))
   raise RuntimeError(f"ALU resolve fail @ {u.op}")
@@ -708,6 +731,8 @@ def type_verify(uops:List[UOp]):
     chk = cast(bool, spec.rewrite(u))
     assert chk is True, f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"

+# *** most of symbolic lives here now ***
+
 simple_pm = PatternMatcher([
   # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
   (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
@@ -780,3 +805,16 @@ simple_pm = PatternMatcher([
   (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x*c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
   (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
 ])

+# for debug
+renderer = PatternMatcher([
+  (UPat(UOps.DEFINE_VAR, name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
+  (UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
+  (UPat(UOps.ASSIGN, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.ADD, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}+{x.src[1].arg})")),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.MUL, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg})")),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.IDIV, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}//{x.src[1].arg})")),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.MOD, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}%{x.src[1].arg})")),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPLT, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}<{x.src[1].arg})")),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPNE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}!={x.src[1].arg})")),
+])
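
For a sense of what this debug renderer produces, a hedged example; the expected strings are taken from render() assertions elsewhere in this PR (e.g. the fuzzer above), so the exact formatting is an assumption:

  from tinygrad.shape.symbolic import Variable

  a = Variable("a", 0, 8)
  print((a*2 + 12).render())                 # "((a*2)+12)"
  print((a*2 + a*3).render())                # "(a*5)", simplify folds the sum first
  print((a*2 + a*3).render(simplify=False))  # leaves the unsimplified ADD/MUL tree
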
@@ -1,26 +1,15 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
-import functools
 from dataclasses import dataclass
-from typing import Tuple, List, Optional, Dict, Set, Any
+from typing import Tuple, List, Optional, Dict, Set
 from tinygrad.helpers import merge_dicts, getenv
-from tinygrad.shape.symbolic import Variable, MulNode, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint
+from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.dtype import dtypes
 from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, resolve
 from tinygrad.codegen.uopgraph import sym, _get_chain

 # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x.render(render_ops, ctx)
-render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self.b),
-                    MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
-                    DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
-                    ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
-                    LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
-                    Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else UOp.define_var(self.expr, dtypes.int, self.min, self.max),
-                    SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
-                    AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
-
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
 def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
   # TODO: dtypes.realint
   iexpr = variable_to_uop(view.offset)

@@ -1,341 +1,39 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from math import gcd
|
||||
from tinygrad.helpers import partition, all_int
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
|
||||
from typing import Union, Optional, Dict, cast
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, exec_alu, ConstType
|
||||
|
||||
# NOTE: Python has different behavior for negative mod and floor div than c
|
||||
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
||||
sint = Union[int, UOp]
|
||||
|
||||
class Node:
|
||||
b: Union[Node, int]
|
||||
min: int
|
||||
max: sint
|
||||
# broken
|
||||
Node = UOp
|
||||
MulNode = UOp
|
||||
SumNode = UOp
|
||||
DivNode = UOp
|
||||
ModNode = UOp
|
||||
LtNode = UOp
|
||||
AndNode = UOp
|
||||
def NumNode(val:int): return UOp.const(dtypes.int, val)
|
||||
|
||||
# helpers for the migration
|
||||
class Variable(UOp):
|
||||
def __reduce__(self): return Variable, self.arg
|
||||
def __new__(cls, expr:str, nmin:ConstType, nmax:ConstType): # pylint: disable=signature-differs
|
||||
return super().__new__(cls, UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
|
||||
def __init__(self, expr:str, nmin:ConstType, nmax:ConstType):
|
||||
super().__init__(UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
|
||||
def bind(self, val:int):
|
||||
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}"
|
||||
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
|
||||
return UOp(UOps.ASSIGN, self.dtype, (self, self.const_like(val)))
|
||||
@property
|
||||
def vmin(self): return self.min
|
||||
@property
|
||||
def vmax(self): return self.max
|
||||
def expr(self): return self.arg[0]
|
||||
|
||||
def render(self, ops=None, ctx=None) -> Any:
|
||||
if ops is None: ops = render_python
|
||||
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
def vars(self) -> Set[Variable]: return set()
|
||||
# substitute Variables with the values in var_vals
|
||||
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
|
||||
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
|
||||
|
||||
@functools.cached_property
|
||||
def key(self) -> str: return self.render(ctx="DEBUG")
|
||||
def __repr__(self): return self.render(ctx="REPR")
|
||||
def __str__(self): return "<"+self.key+">"
|
||||
def __hash__(self): return hash(self.key)
|
||||
def __bool__(self):
|
||||
if self.max == self.min: return self.max == 1
|
||||
raise ValueError(f"couldn't resolve boolean expression {self}")
|
||||
def __eq__(self, other:object) -> bool:
|
||||
if not isinstance(other, Node): return NotImplemented
|
||||
return self.key == other.key
|
||||
def __neg__(self): return self*-1
|
||||
def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
|
||||
def __radd__(self, b:int): return self+b
|
||||
def __sub__(self, b:Union[Node,int]): return self+-b
|
||||
def __rsub__(self, b:int): return -self+b
|
||||
def __le__(self, b:Union[Node,int]): return self < (b+1)
|
||||
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
|
||||
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
|
||||
def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
|
||||
def __mul__(self, b:Union[Node, int]):
|
||||
if b == 0: return NumNode(0)
|
||||
if b == 1: return self
|
||||
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
|
||||
def __rmul__(self, b:int): return self*b
|
||||
def __lshift__(self, b:int): return self*2**b
|
||||
|
||||
# *** complex ops ***
|
||||
|
||||
def __rfloordiv__(self, b:int): return NumNode(b) // self
|
||||
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
||||
if isinstance(b, Node):
|
||||
if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
|
||||
if self == b: return NumNode(1)
|
||||
# if isinstance(m:=(b-self).min, int) and m > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
|
||||
if isinstance(m:=(b-self.max).min, int) and m > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} // {b}")
|
||||
assert b != 0
|
||||
if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
|
||||
if b == 1: return self
|
||||
|
||||
# the numerator of div is not allowed to be negative
|
||||
if self.min < 0:
|
||||
offset = self.min//b
|
||||
# factor out an "offset" to make the numerator positive. don't allowing factoring again
|
||||
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
|
||||
return create_node(DivNode(self, b))
|
||||
|
||||
def __rmod__(self, b:int): return NumNode(b) % self
|
||||
def __mod__(self, b:Union[Node,int]):
|
||||
if isinstance(b, Node):
|
||||
if b.__class__ is NumNode: return self % b.b
|
||||
if self == b: return NumNode(0)
|
||||
# if isinstance(m:=(b-self).min, int) and m > 0 and self.min >= 0: return self # b - self simplifies the node
|
||||
if isinstance(m:=(b-self.max).min, int) and m > 0 and self.min >= 0: return self # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} % {b}")
|
||||
assert b > 0
|
||||
if b == 1: return NumNode(0)
|
||||
if isinstance(self.max, int) and isinstance(self.min, int):
|
||||
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
|
||||
if self.min < 0: return (self - ((self.min//b)*b)) % b
|
||||
return create_node(ModNode(self, b))
|
||||
|
||||
@staticmethod
|
||||
def sum(nodes:List[Node]) -> Node:
|
||||
nodes = [x for x in nodes if not (x.max==x.min==0)]
|
||||
if not nodes: return NumNode(0)
|
||||
if len(nodes) == 1: return nodes[0]
|
||||
|
||||
mul_groups: Dict[Node, int] = {}
|
||||
num_node_sum = 0
|
||||
for node in SumNode(nodes).flat_components:
|
||||
if node.__class__ is NumNode: num_node_sum += node.b
|
||||
elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
|
||||
else: mul_groups[node] = mul_groups.get(node, 0) + 1
|
||||
new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
|
||||
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
|
||||
return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
|
||||
|
||||
  @staticmethod
  def ands(nodes:List[Node]) -> Node:
    if not nodes: return NumNode(1)
    if len(nodes) == 1: return nodes[0]
    if any(x.max==0 for x in nodes): return NumNode(0)

    # filter 1s
    nodes = [x for x in nodes if x.min != x.max]
    return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
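# Hedged illustration (not from the patch): Node.sum groups like terms by their
# non-constant factor, assuming Variable as defined below and that Node.__neg__
# is self*-1 as elsewhere in this file:
#   x = Variable("x", 0, 10)
#   x*2 + x*3 + 1  ->  SumNode([x*5, 1])   # min 1, max 51
#   x - x          ->  NumNode(0)          # zero coefficients are filtered out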
# 4 basic node types

class Variable(Node):
  def __new__(cls, *args):
    expr, nmin, nmax = args
    assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
    if nmin == nmax: return NumNode(nmin)
    return super().__new__(cls)

  def __getnewargs__(self): return (self.expr, self.min, self.max)  # args passed to __new__ when unpickling
  @property
  def arg(self): return self.expr

  def __init__(self, expr:str, nmin:int, nmax:sint):
    self.expr, self.min, self.max = expr, nmin, nmax
    self._val: Optional[int] = None
  @property
  def val(self):
    assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
    return self._val
  def bind(self, val):
    assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
    self._val = val
    return self
  def unbind(self) -> Tuple[Variable, int]:
    assert self.val is not None, f"cannot unbind {self}"
    return Variable(self.expr, self.min, self.max), self.val
  def vars(self): return {self}
  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    # try both the bound and unbound Variable
    return var_vals.get(self, var_vals.get(Variable(self.expr, self.min, self.max), self))
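# Hedged illustration (not from the patch): binding pins a runtime value onto a
# symbolic Variable without changing its range:
#   v = Variable("i", 0, 255).bind(42)   # v.val == 42, bounds stay (0, 255)
#   v.unbind()                           # -> (Variable('i', 0, 255), 42)
#   Variable("k", 3, 3)                  # degenerate range folds to NumNode(3)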
class NumNode(Node):
  def __init__(self, num:int):
    assert isinstance(num, int), f"{num} is not an int"
    self.b:int = num
    self.min, self.max = num, num
  def bind(self, val):
    assert self.b == val, f"cannot bind {val} to {self}"
    return self
  def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
  def __eq__(self, other): return self.b == other
  def __hash__(self): return hash(self.b)  # needed with __eq__ override
  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
def create_node(ret:Node):
  assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
  if ret.min == ret.max: return NumNode(ret.min)
  return ret
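# Hedged illustration (not from the patch): create_node folds any node whose
# bounds pin a single value, so always-true/false comparisons become constants:
#   Variable("x", 0, 3) < 10   # LtNode bounds (1, 1) -> NumNode(1)
#   Variable("x", 5, 9) < 3    # LtNode bounds (0, 0) -> NumNode(0)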
def create_lt_node(lhs:Node, b:Union[Node, int]):
  if isinstance(lhs, SumNode):
    if isinstance(b, int):
      new_sum = []
      for x in lhs.nodes:
        # TODO: should we just force the last one to always be the number
        if isinstance(x, NumNode): b -= x.b
        else: new_sum.append(x)
      lhs = Node.sum(new_sum)
      nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
      assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
      muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
      if muls:
        # NOTE: gcd in python 3.8 takes exactly 2 args
        mul_gcd = b
        for x in muls: mul_gcd = gcd(mul_gcd, x.b)  # type: ignore # mypy cannot tell that x.b is int here due to assert above
        all_others = Node.sum(others)
        if all_others.min >= 0 and all_others.max < mul_gcd:
          lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
    return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
  if isinstance(lhs, MulNode):
    if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
    sgn = 1 if lhs.b > 0 else -1
    return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
  return create_node(LtNode(lhs, b))
def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
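# Hedged illustration (not from the patch) of the gcd factoring path above:
#   x, y = Variable("x", 0, 10), Variable("y", 0, 10)
#   create_lt_node(x*4 + y*8, 4)   # gcd 4 divides through -> ((x+(y*2))<1)
#   create_lt_node(x*4 + 2, 4)     # const folds into rhs, then ceildiv -> (x<1)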
class OpNode(Node):
  def __init__(self, a:Node, b:Union[Node, int]):
    self.a, self.b = a, b
    self.min, self.max = self.get_bounds()
  def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
  def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
class LtNode(OpNode):
  def get_bounds(self) -> Tuple[int, int]:
    if self.a == self.b: return (0, 0)
    if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
    if all_int([self.a.max, self.b.min]) and self.a.max < self.b.min: return (1, 1)
    if all_int([self.a.min, self.b.max]) and self.a.min >= self.b.max: return (0, 0)
    return (0, 1)
  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
class MulNode(OpNode):
  def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b)  # two muls in one mul
  def __floordiv__(self, b: Union[Node, int], factoring_allowed=False):  # NOTE: mod negative isn't handled right
    if self.b % b == 0: return self.a*(self.b//b)
    if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
    return Node.__floordiv__(self, b, factoring_allowed)
  def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
  def get_bounds(self) -> Tuple[int, sint]:
    assert self.a.min >= 0
    if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
    return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
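# Hedged illustration (not from the patch): MulNode folds its coefficient
# eagerly rather than building deeper trees:
#   x = Variable("x", 0, 9)
#   (x*4)*2   -> x*8          # __mul__ merges the two coefficients
#   (x*6)//3  -> x*2          # __floordiv__ when the coefficient divides
#   (x*6)%3   -> NumNode(0)   # __mod__ reduces the coefficient mod 3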
class DivNode(OpNode):
  def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b)  # two divs is one div
  def get_bounds(self) -> Tuple[int, sint]:
    assert self.a.min >= 0 and isinstance(self.b, int)
    return self.a.min//self.b, self.a.max//self.b
  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
class ModNode(OpNode):
  def __mod__(self, b: Union[Node, int]):
    if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
    return Node.__mod__(self, b)
  def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
    return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
  def get_bounds(self) -> Tuple[int, sint]:
    assert self.a.min >= 0 and isinstance(self.b, int)
    if all_int([self.a.max, self.a.min, self.b]):
      if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
      return (self.a.min%self.b, self.a.max%self.b)
    return (0, self.b-1)  # conservative fallback when the bounds are symbolic: a%b always lies in [0, b-1]
class RedNode(Node):
  def __init__(self, nodes:List[Node]):
    self.nodes = nodes
    self.min, self.max = self.get_bounds()
  def vars(self) -> Set[Variable]: return set().union(*[x.vars() for x in self.nodes])
  def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
class SumNode(RedNode):
  def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes])  # distribute mul into sum

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
    if self == b: return NumNode(1)
    fully_divided: List[Node] = []
    rest: List[Node] = []
    if isinstance(b, Node):
      for x in self.flat_components:
        if x % b == 0: fully_divided.append(x // b)
        else: rest.append(x)
      if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
      return Node.__floordiv__(self, b, False)
    if b == 1: return self
    if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
    _gcd = b
    divisor = 1
    for x in self.flat_components:
      if x.__class__ in (NumNode, MulNode):
        if x.b % b == 0: fully_divided.append(x // b)
        else:
          if x.__class__ is NumNode and (div := x.b // b):
            fully_divided.append(NumNode(div))
            x = NumNode(x.b - b * div)
          rest.append(x)
          if isinstance(x.b, int):
            _gcd = gcd(_gcd, x.b)
            if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
          else:
            _gcd = 1
      else:
        rest.append(x)
        _gcd = 1
    if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
    if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
    return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def __mod__(self, b: Union[Node, int]):
    if self == b: return NumNode(0)
    if isinstance(b, Node) and (b - self).min > 0: return self  # b - self simplifies the node
    new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
    return Node.__mod__(new_sum, b)

  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    return Node.sum([node.substitute(var_vals) for node in self.nodes])

  # recursively expand sumnode components
  # TODO: can remove this if there's no SumNode inside SumNode
  @property
  def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
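# Hedged illustration (not from the patch): SumNode distributes and factors
# across its terms:
#   x = Variable("x", 0, 9)
#   (x + 1)*3     -> (x*3)+3      # __mul__ distributes into the sum
#   (x*4 + 2)//2  -> (x*2)+1      # gcd 2 divides every term
#   (x*8 + 3)%4   -> NumNode(3)   # term-wise mod eliminates the x*8 term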
class AndNode(RedNode):
  def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
  def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
    return Node.ands([node.substitute(var_vals) for node in self.nodes])
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
  if isinstance(a, (int, float)): return a
  ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
  assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
  return ret.b
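# Hedged usage sketch (not from the patch) for the Node-based sym_infer:
#   x = Variable("x", 0, 100)
#   sym_infer(x*2 + 1, {x: 10})  -> 21   # substitute, then read NumNode.b
#   sym_infer(7, None)           -> 7    # plain ints pass straight through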
# symbolic int, these are allowed in a Tensor shape
sint = Union[int, Variable, MulNode, SumNode]
def render_mulnode(node:MulNode, ops, ctx):
  # TODO: add ProdNode and remove this case
  if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
    return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
  return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
render_python: Dict[Type, Callable[..., str]] = {
  Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
    else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
    else f"{self.expr}"),
  NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
  MulNode: render_mulnode,
  DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
  ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
  LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
  SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
  AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
}
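# Hedged illustration (not from the patch), assuming Node.__str__/__repr__
# dispatch through render_python with ctx=None/"REPR" as elsewhere in the file:
#   str(Variable("x", 0, 3)*2)  -> "(x*2)"
#   repr(NumNode(4))            -> "NumNode(4)"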
def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
  if isinstance(uop, (int, float)): return uop  # TODO: ugh, the float is a hack for qcom
  if uop.op == UOps.CONST: return uop.arg
  if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[cast(Variable, uop)]
  if uop.op == UOps.ASSIGN: return uop.src[1].arg  # bound variable returns bound value
  if uop.op == UOps.ALU:
    src_values = [sym_infer(src, var_vals) for src in uop.src]
    return exec_alu(uop.arg, uop.dtype, src_values)
  raise NotImplementedError(f"Unsupported UOp {uop.op}")
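# Hedged usage sketch (not from the patch) for the UOp-based replacement,
# assuming, per this PR, that Variable now builds a DEFINE_VAR UOp and that
# bind() produces an ASSIGN of a CONST onto it:
#   v = Variable("v", 1, 10)
#   sym_infer(v*2, {v: 3})      -> 6   # ALU MUL evaluated via exec_alu
#   sym_infer(v.bind(4), None)  -> 4   # ASSIGN returns the bound CONST's arg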
@@ -1,8 +1,8 @@
from __future__ import annotations
import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast
from tinygrad.ops import resolve
from typing import Tuple, List, Optional, Dict, Set, cast, Union
from tinygrad.ops import resolve, UOp
from tinygrad.helpers import prod, all_int, argsort
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
@@ -114,6 +114,13 @@ class View:
        strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
      offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
      strides = tuple(0 if e else st for st,e in zip(strides, elim))
    # simplify as we go
    if isinstance(offset, UOp): offset = cast(Union[UOp, int], offset.ssimplify())
    """
    shape = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in shape)
    strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
    if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
    """
    contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
    return View(shape, strides, offset, mask, contiguous)
@@ -161,7 +168,7 @@ class View:
    merged_size, merged_term = 1, NumNode(0)
    extents: List[Tuple[sint, Node]] = []
    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
      merged_term += sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
      merged_size *= s
      if not resolve(merged_term >= merged_size) and not resolve(merged_term < 0):
        extents.append((merged_size, merged_term))
@@ -220,7 +227,7 @@ class View:
    mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
    shape = [y-x for x,y in arg]
    if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
    return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
    return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
@@ -246,9 +253,12 @@ class View:
    if 0 in self.shape:
      assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
      return View.create(new_shape)
    assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
    # TODO: this resolve might be wrong
    assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
    # NOTE: can the mask ever be (0,0)?
    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
    # TODO: this resolve may not be needed, but it's hard because vars need to be sorted
    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
                  for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
    return View.create(new_shape, self.strides, self.offset, mask)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
@@ -278,8 +288,8 @@ class View:
      return View.create(new_shape)
    # check for the same size
    if (self_all_int := all_int(self.shape)):
      assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
      assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
      if resolve(prod(self.shape) != prod(new_shape), False):
        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

    if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
@@ -10,9 +10,9 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, truncate, smax, resolve
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
from tinygrad.shape.symbolic import sint, Variable, Node
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
@@ -109,7 +109,7 @@ class Tensor:
  training: ClassVar[bool] = False
  no_grad: ClassVar[bool] = False

  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, Variable, pathlib.Path],  # type: ignore [name-defined] # noqa: F821
  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, UOp, pathlib.Path],  # type: ignore [name-defined] # noqa: F821
               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
    if dtype is not None: dtype = to_dtype(dtype)
    assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
@@ -129,7 +129,7 @@ class Tensor:
    # create a LazyBuffer from the different types of inputs
    if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
    elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
    elif isinstance(data, Variable): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
    elif isinstance(data, UOp): data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
    elif isinstance(data, (list, tuple)):
      if dtype is None:
@@ -367,11 +367,15 @@ class Tensor:
    return self

  @staticmethod
  def from_node(y:Node, **kwargs) -> Tensor:
    if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False)
    if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
    if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
    if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
  def from_node(y:UOp, **kwargs) -> Tensor:
    # NOTE: we only support Tensors from DEFINE_VAR or CONST
    if y.op is UOps.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
    if y.op is UOps.ASSIGN:
      assert y.src[0].op is UOps.DEFINE_VAR
      return Tensor(y, **kwargs, requires_grad=False)
    if y.op is UOps.ALU:
      if y.arg is BinaryOps.MUL: return Tensor.from_node(y.src[0]) * Tensor.from_node(y.src[1])
      if y.arg is BinaryOps.ADD: return Tensor.from_node(y.src[0]) + Tensor.from_node(y.src[1])
    raise RuntimeError(f"unhandled Node {y}")
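  # Hedged usage sketch (not from the patch): from_node now walks the UOp
  # graph, so a bound shape variable still lifts into a Tensor, assuming bind()
  # produces the ASSIGN-of-CONST form handled above:
  #   Tensor.from_node(Variable("i", 1, 10).bind(3) * 2)   # Tensor carrying 3*2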
  # ***** creation entrypoint *****
@@ -2680,7 +2684,8 @@ class Tensor:
    # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
    padded, _ = _pad_left(self.shape, shape)
    # for each dimension, check either from_ is 1, or it does not change
    if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
    if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
      raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
    return F.Expand.apply(self.reshape(padded), shape=shape)

  def _broadcasted(self, y:Union[Tensor, Node, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]: