add dtypes.index (#12015)

* add dtypes.index

* cast shape, stride and mask to dtypes.index in view.create

* move pm_lower_index_dtype to ops

* DEFINE_VAR is dtype.index by default

* merge var_val_using_str

* remove int from commutative

* fix test_rewrite_map

* change that to dtypes.index

* change some int to index

* shorten those

* remove old cast in renderer

* cleanup

* change that back

* add comment

* delete comment

* just delete those

* view doesnt have to cast anymore

* adjust comment
This commit is contained in:
Sieds Lykles
2025-09-06 06:03:44 +02:00
committed by GitHub
parent c6c16b2946
commit 581b2388c2
21 changed files with 181 additions and 155 deletions

View File

@@ -26,6 +26,10 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from tinygrad import Tensor, dtypes, nn from tinygrad import Tensor, dtypes, nn
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv
MOCKGPU = getenv("MOCKGPU")
class TestNaNEdgeCases(unittest.TestCase): class TestNaNEdgeCases(unittest.TestCase):
# we don't need more of these. it's unclear if torch's behavior is desired here # we don't need more of these. it's unclear if torch's behavior is desired here
@@ -167,34 +171,6 @@ class TestZeroFolding(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
(x % x).numpy() (x % x).numpy()
class TestArangeUOpValidationIssue(unittest.TestCase):
# these fail with UOp verification error.
# we don't need more of these involving arange
@unittest.expectedFailure
def test_large_arange_sum(self):
# Summing a huge arange should either succeed or raise a MemoryError.
n = 2**31 + 3
expected = (n - 1) * n // 2
out = Tensor.arange(n).sum().item()
self.assertEqual(out, expected)
@unittest.expectedFailure
def test_large_arange_index(self):
# Indexing a huge arange should return the correct value instead of failing
# with a UOp verification error.
n = 2**31 + 3
out = Tensor.arange(n)[0].item()
self.assertEqual(out, 0)
@unittest.expectedFailure
def test_large_arange_permute(self):
# Permuting a huge tensor should not trigger UOp verification failures.
n = 2**31 + 3
out = Tensor.arange(n).reshape(n, 1).permute(1, 0)
self.assertEqual(out.shape, (1, n))
out.realize()
class TestAssignIssues(unittest.TestCase): class TestAssignIssues(unittest.TestCase):
# these are good failures. i'm not sure we need more, but we need to fix these. # these are good failures. i'm not sure we need more, but we need to fix these.
@@ -230,10 +206,8 @@ class TestUOpValidationIssue(unittest.TestCase):
# these fail with UOp verification error. # these fail with UOp verification error.
# we want more of these with diverse errors! # we want more of these with diverse errors!
@unittest.expectedFailure @unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU, "hangs gpuocelot")
def test_tensor_index_overflow(self): def test_tensor_index_overflow(self):
# Advanced indexing on tensors expanded past int32 should not error, but
# tinygrad fails with a UOp verification error.
val = Tensor([1]) val = Tensor([1])
big = val.expand(2**31 + 3) big = val.expand(2**31 + 3)
idx = Tensor([0, 2**31 + 2]) idx = Tensor([0, 2**31 + 2])
@@ -273,4 +247,4 @@ class TestEdgeCases(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -928,6 +928,7 @@ class TestIdxUpcast(unittest.TestCase):
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32)) self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported") @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
@unittest.expectedFailure # bug in gpu dims limiting
def test_int64_unsupported_overflow(self): def test_int64_unsupported_overflow(self):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
self.do_op_then_assert(dtypes.long, 2048, 2048, 2048) self.do_op_then_assert(dtypes.long, 2048, 2048, 2048)

View File

@@ -212,7 +212,7 @@ class TestUOpGraph(unittest.TestCase):
def test_where_same_fold(self): def test_where_same_fold(self):
v = UOp.variable('tmp', 0, 1) v = UOp.variable('tmp', 0, 1)
c0 = UOp(Ops.CONST, dtypes.int, arg=0) c0 = UOp(Ops.CONST, dtypes.index, arg=0)
vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0)) vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0))
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1)) out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))
@@ -398,7 +398,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1) self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
def test_depth_2_const_fold(self): def test_depth_2_const_fold(self):
v = UOp.variable("tmp", 0, 1) v = UOp.variable("tmp", 0, 1, dtypes.int)
c2 = UOp(Ops.CONST, dtypes.int, arg=2) c2 = UOp(Ops.CONST, dtypes.int, arg=2)
c4 = UOp(Ops.CONST, dtypes.int, arg=4) c4 = UOp(Ops.CONST, dtypes.int, arg=4)
vc = UOp(Ops.ADD, dtypes.int, (v, c2)) vc = UOp(Ops.ADD, dtypes.int, (v, c2))
@@ -417,6 +417,17 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([v.bitcast(dt)]) uops = to_uops_list([v.bitcast(dt)])
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}") self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")
def test_load_idx_becomes_int(self):
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 1)
l0 = UOp(Ops.LOAD, dtypes.long, (d0.index(UOp.const(dtypes.int, 0)),))
idx = l0 * 600
valid = (l0<-1).ne(True)&(l0<3000)
l1 = UOp(Ops.LOAD, dtypes.long, (d1.index(idx, valid),))
uops = to_uops_list([l1])
for u in uops:
if u.op is Ops.INDEX: self.assertEqual(u.src[1].dtype, dtypes.int)
def test_in_out_of_bounds_access(self): def test_in_out_of_bounds_access(self):
with Context(IGNORE_OOB=0): with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
@@ -512,7 +523,7 @@ class TestUOpGraph(unittest.TestCase):
def test_in_out_bounds_access_with_mask(self): def test_in_out_bounds_access_with_mask(self):
with Context(IGNORE_OOB=0): with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 42),), "gidx0") gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),)) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),))
to_uops_list([ld0, ld1]) to_uops_list([ld0, ld1])
@@ -536,9 +547,9 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0): with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0) glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 42),), "gidx0") gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),)) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),)).cast(dtypes.index)
to_uops_list([ld1]) to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),)) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),))

View File

@@ -97,19 +97,22 @@ class TestFoldingAndReduction(unittest.TestCase):
class TestModuloAndDivisionFolding(unittest.TestCase): class TestModuloAndDivisionFolding(unittest.TestCase):
def test_full_graph_rewrite_modulo_folding_with_define_var(self): def test_full_graph_rewrite_modulo_folding_with_define_var(self):
x_var_uop = UOp.variable('x', 0, 100) # index dtype because div-mod rules only work on index
x_var_uop = UOp.variable('x', 0, 100).cast(dtypes.index)
optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4) optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
self.assertEqual(optimized_mod_uop.op, Ops.CONST) self.assertEqual(optimized_mod_uop.op, Ops.CONST)
self.assertEqual(optimized_mod_uop.arg, 2) self.assertEqual(optimized_mod_uop.arg, 2)
def test_full_graph_rewrite_division_folding_with_define_var(self): def test_full_graph_rewrite_division_folding_with_define_var(self):
n_var_uop = UOp.variable('n', 1, 1000) # index dtype because div-mod rules only work on index
n_var_uop = UOp.variable('n', 1, 1000).cast(dtypes.index)
optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3) optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
self.assertEqual(optimized_div_uop.op, Ops.MUL) self.assertEqual(optimized_div_uop.op, Ops.MUL)
self.assertEqual(optimized_div_uop.src[1].arg, 2) self.assertEqual(optimized_div_uop.src[1].arg, 2)
def test_full_graph_rewrite_complex_mod_div_folding(self): def test_full_graph_rewrite_complex_mod_div_folding(self):
k_var_uop = UOp.variable('k', 0, 50) # index dtype because div-mod rules only work on index
k_var_uop = UOp.variable('k', 0, 50).cast(dtypes.index)
optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2) optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
self.assertEqual(optimized_div_uop.op, Ops.CONST) self.assertEqual(optimized_div_uop.op, Ops.CONST)
self.assertEqual(optimized_div_uop.arg, 1) self.assertEqual(optimized_div_uop.arg, 1)
@@ -126,8 +129,9 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
if opt.op is Ops.VECTORIZE: self.assertFalse(all_same(opt.src)) if opt.op is Ops.VECTORIZE: self.assertFalse(all_same(opt.src))
def test_full_graph_rewrite_modulo_large_divisor(self): def test_full_graph_rewrite_modulo_large_divisor(self):
# index dtype because div-mod rules only work on index
x_var_uop = UOp.variable('x', 1, 5) x_var_uop = UOp.variable('x', 1, 5)
self.assertIs(apply_rewrite(x_var_uop % 10), x_var_uop) self.assertIs(apply_rewrite(x_var_uop.cast(dtypes.index) % 10).render(simplify=False), x_var_uop.render(simplify=False))
def test_full_graph_rewrite_division_with_remainder(self): def test_full_graph_rewrite_division_with_remainder(self):
x_var_uop = UOp.variable('x', 7, 9) x_var_uop = UOp.variable('x', 7, 9)

View File

@@ -46,8 +46,8 @@ class TestRewriteMap(unittest.TestCase):
def test_add_zero(self): def test_add_zero(self):
# Build a small graph: add(0, add(const=0, const=5)) # Build a small graph: add(0, add(const=0, const=5))
zero_node = UOp.const(dtypes.int, 0) zero_node = UOp.const(dtypes.index, 0)
five_node = UOp.const(dtypes.int, 5) five_node = UOp.const(dtypes.index, 5)
inner_add = zero_node + five_node inner_add = zero_node + five_node
root_add = zero_node + inner_add root_add = zero_node + inner_add
@@ -67,7 +67,7 @@ class TestRewriteMap(unittest.TestCase):
Test rewriting neg(neg(5)) => 5 using symbolic. Test rewriting neg(neg(5)) => 5 using symbolic.
""" """
# In some versions of TinyGrad, you might do: (-(-five_node)) # In some versions of TinyGrad, you might do: (-(-five_node))
five_node = UOp.const(dtypes.int, 5) five_node = UOp.const(dtypes.index, 5)
# If your code allows UOp(...), do that; else you might do something like: # If your code allows UOp(...), do that; else you might do something like:
# double_neg_five = -(-five_node) # double_neg_five = -(-five_node)
# But let's be explicit: # But let's be explicit:
@@ -85,8 +85,8 @@ class TestRewriteMap(unittest.TestCase):
""" """
Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5 Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5
""" """
zero_node = UOp.const(dtypes.int, 0) zero_node = UOp.const(dtypes.index, 0)
five_node = UOp.const(dtypes.int, 5) five_node = UOp.const(dtypes.index, 5)
neg_five = -five_node neg_five = -five_node
double_neg_five = -neg_five double_neg_five = -neg_five
root_add = zero_node + double_neg_five root_add = zero_node + double_neg_five
@@ -103,7 +103,7 @@ class TestRewriteMap(unittest.TestCase):
def test_multi_var_rewrites(self): def test_multi_var_rewrites(self):
x_var = UOp.variable('x', 0, 10) x_var = UOp.variable('x', 0, 10)
y_var = UOp.variable('y', -5, 5) y_var = UOp.variable('y', -5, 5)
zero_node = UOp.const(dtypes.int, 0) zero_node = UOp.const(dtypes.index, 0)
sum_with_zero = y_var + zero_node # (y + 0) sum_with_zero = y_var + zero_node # (y + 0)
combined = x_var + sum_with_zero # x + (y + 0) combined = x_var + sum_with_zero # x + (y + 0)
@@ -155,8 +155,8 @@ class TestRewriteMap(unittest.TestCase):
x_var = UOp.variable('x', 1, 10) x_var = UOp.variable('x', 1, 10)
y_var = UOp.variable('y', -5, 5) y_var = UOp.variable('y', -5, 5)
z_var = UOp.variable('z', 0, 5) z_var = UOp.variable('z', 0, 5)
zero_node = UOp.const(dtypes.int, 0) zero_node = UOp.const(dtypes.index, 0)
one_node = UOp.const(dtypes.int, 1) one_node = UOp.const(dtypes.index, 1)
# Build sub-expressions # Build sub-expressions
yz_sum = y_var + z_var # (y + z) yz_sum = y_var + z_var # (y + z)

View File

@@ -18,7 +18,7 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UO
UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4) UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
)) ))
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int32, (UOp.const(dtypes.int, nmax),), expr) def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, nmax),), expr)
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp.range(nmax, n) def Range(n, nmax): return UOp.range(nmax, n)

View File

@@ -2,22 +2,20 @@
import unittest, pickle, functools, math import unittest, pickle, functools, math
import z3 import z3
from tinygrad.dtype import dtypes, ConstType from tinygrad.dtype import dtypes, ConstType, DType
from tinygrad.codegen import full_rewrite from tinygrad.codegen import full_rewrite
from tinygrad.codegen.late.devectorizer import sym
from tinygrad.helpers import Context from tinygrad.helpers import Context
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer, track_rewrites
from tinygrad import Variable from tinygrad.uop.symbolic import sym
from tinygrad.uop.spec import uops_to_z3 from tinygrad.uop.spec import uops_to_z3
def render(self) -> tuple[str, ConstType, ConstType]: @track_rewrites(name="simplify symbolic uop")
# NOTE: we need STORE so the ALU op has children def render(v) -> UOp:
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0) v_simplified = graph_rewrite(v, sym)
uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink()) return v_simplified
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
def uconst(val): return UOp.const(dtypes.int, val) def Variable(name: str, min_val: ConstType, max_val: ConstType, dtype: DType=dtypes.index): return UOp.variable(name,min_val,max_val,dtype)
def uconst(val): return UOp.const(dtypes.index, val)
def usum(ops): return functools.reduce(lambda x,y: x+y, ops) def usum(ops): return functools.reduce(lambda x,y: x+y, ops)
def uand(ops): return functools.reduce(lambda x,y: x*y, ops) def uand(ops): return functools.reduce(lambda x,y: x*y, ops)
@@ -30,11 +28,12 @@ class TestSymbolicPickle(unittest.TestCase):
class TestSymbolic(unittest.TestCase): class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s, test_z3:bool=True): def helper_test_variable(self, v, n, m, s, test_z3:bool=True):
v_simplified = render(v)
if test_z3: if test_z3:
solver = z3.Solver() solver = z3.Solver()
expr, expr_simplified = uops_to_z3(solver, v, v.simplify()) expr, expr_simplified = uops_to_z3(solver, v, v_simplified)
self.assertEqual(solver.check(expr != expr_simplified), z3.unsat, "simplified expression not equal to original") self.assertEqual(solver.check(expr != expr_simplified), z3.unsat, "simplified expression not equal to original")
rendered, nmin, nmax = render(v) rendered, nmin, nmax = v_simplified.render(simplify=False), v_simplified.vmin, v_simplified.vmax
if isinstance(s, tuple): self.assertIn(rendered, s) if isinstance(s, tuple): self.assertIn(rendered, s)
else: self.assertEqual(rendered, s) else: self.assertEqual(rendered, s)
self.assertEqual(nmin, n) self.assertEqual(nmin, n)
@@ -111,7 +110,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)") self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")
def test_xor_0(self): def test_xor_0(self):
self.helper_test_variable(Variable("a", 0, 8) ^ 0, 0, 8, "a") self.helper_test_variable(Variable("a", 0, 8, dtypes.int) ^ 0, 0, 8, "a")
def test_add_1(self): def test_add_1(self):
self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)") self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)")
@@ -209,12 +208,12 @@ class TestSymbolic(unittest.TestCase):
self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0)) self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0))
def test_range_div_its_symbolic_bound(self): def test_range_div_its_symbolic_bound(self):
a = Variable("a", 1, 10) a = Variable("a", 1, 10, dtypes.index)
ridx0 = UOp.range(a+2, 0) ridx0 = UOp.range(a+2, 0)
self.helper_test_variable(ridx0//(a+2), 0, 0, "0") self.helper_test_variable(ridx0//(a+2), 0, 0, "0")
def test_range_mod_its_symbolic_bound(self): def test_range_mod_its_symbolic_bound(self):
a = Variable("a", 1, 10) a = Variable("a", 1, 10, dtypes.index)
ridx = UOp.range(a+2, 0) ridx = UOp.range(a+2, 0)
self.helper_test_variable(ridx%(a+2), 0, 11, "ridx0") self.helper_test_variable(ridx%(a+2), 0, 11, "ridx0")
@@ -463,8 +462,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)") self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)")
def test_nest_div_negative_factor(self): def test_nest_div_negative_factor(self):
ridx0=UOp.variable("ridx0", 0, 9) ridx0=Variable("ridx0", 0, 9)
ridx1=UOp.variable("ridx1", 0, 6) ridx1=Variable("ridx1", 0, 6)
self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "(((ridx0//5)*-1)+1)") self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "(((ridx0//5)*-1)+1)")
def test_div_into_mod(self): def test_div_into_mod(self):
@@ -533,8 +532,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(x//y, 2, 2, "2") self.helper_test_variable(x//y, 2, 2, "2")
self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))") self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))")
# ensure all 4 corners are checked # ensure all 4 corners are checked
x = Variable("x", -10, 10) x = Variable("x", -10, 10, dtypes.int)
y = Variable("y", -8, 9) y = Variable("y", -8, 9, dtypes.int)
self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)") self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)")
self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)") self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)")
@@ -587,6 +586,12 @@ class TestSymbolic(unittest.TestCase):
unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4 unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4
self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)") self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)")
def test_arange_unrolled4_with_cast(self):
gidx = Variable("gidx", 0, 2559, dtypes.index)
dt = dtypes.int
unrolled_div = ((gidx+2561)//4 + 2).cast(dt)+((gidx+2562)//4).cast(dt)+((gidx+2560)//4).cast(dt)+((gidx+2559)//4).cast(dt)
self.helper_test_variable(unrolled_div, 2561, 5120, "((int)(gidx)+2561)")
def test_arange_unrolled4_mul(self): def test_arange_unrolled4_mul(self):
gidx = Variable("gidx", 0, 2559) gidx = Variable("gidx", 0, 2559)
unrolled_div = 2*((gidx+2561)//4)+2*((gidx+2562)//4)+2*((gidx+2560)//4)+2*((gidx+2559)//4) unrolled_div = 2*((gidx+2561)//4)+2*((gidx+2562)//4)+2*((gidx+2560)//4)+2*((gidx+2559)//4)
@@ -688,10 +693,10 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(-a<-b, False, True, "(b<a)") self.helper_test_variable(-a<-b, False, True, "(b<a)")
def test_where_cast(self): def test_where_cast(self):
s = Variable("s", 0, 3) s = Variable("s", 0, 3, dtypes.int)
cond = s < 2 cond = s < 2
a = Variable("a", 0, 3) a = Variable("a", 0, 3, dtypes.int)
b = Variable("b", 0, 3) b = Variable("b", 0, 3, dtypes.int)
expr = cond.where(a, b).cast(dtypes.half) expr = cond.where(a, b).cast(dtypes.half)
# TODO: copied from render, render does not support cast # TODO: copied from render, render does not support cast
@@ -709,6 +714,7 @@ class TestSymbolic(unittest.TestCase):
expr = cond1.where(cond2.where(a, b), b) expr = cond1.where(cond2.where(a, b), b)
self.helper_test_variable(expr, 0, 3, "(a if ((s<6)&(2<s)) else b)") self.helper_test_variable(expr, 0, 3, "(a if ((s<6)&(2<s)) else b)")
@unittest.expectedFailure # needs simplify_valid which is not in render anymore
def test_where_merge_branches2(self): def test_where_merge_branches2(self):
cond1 = Variable("s", 0, 10) < 5 cond1 = Variable("s", 0, 10) < 5
cond2 = Variable("s", 0, 10) < 6 cond2 = Variable("s", 0, 10) < 6
@@ -738,8 +744,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(a.trunc(), 1, 10, "a", test_z3=False) self.helper_test_variable(a.trunc(), 1, 10, "a", test_z3=False)
def test_do_math_in_int32(self): def test_do_math_in_int32(self):
a = Variable("a", 1, 10) a = Variable("a", 1, 10, dtypes.int)
b = Variable("b", 1, 10) b = Variable("b", 1, 10, dtypes.int)
self.helper_test_variable(a.cast(dtypes.long)+b.cast(dtypes.long), 2, 20, "(long)((a+b))") self.helper_test_variable(a.cast(dtypes.long)+b.cast(dtypes.long), 2, 20, "(long)((a+b))")
self.helper_test_variable(a.cast(dtypes.long)*b.cast(dtypes.long), 1, 100, "(long)((a*b))") self.helper_test_variable(a.cast(dtypes.long)*b.cast(dtypes.long), 1, 100, "(long)((a*b))")

View File

@@ -142,7 +142,7 @@ class TestViz(BaseTestViz):
def test_const_node_visibility(self): def test_const_node_visibility(self):
a = UOp.variable("a", 0, 10) a = UOp.variable("a", 0, 10)
z = UOp.const(dtypes.int, 0) z = UOp.const(dtypes.index, 0)
alu = a*z alu = a*z
exec_rewrite(alu, [sym]) exec_rewrite(alu, [sym])
lst = get_viz_list() lst = get_viz_list()

View File

@@ -2,7 +2,7 @@ from typing import Any, Callable
import functools import functools
from dataclasses import dataclass from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
@@ -10,7 +10,7 @@ from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import pm_lowerer, get_index from tinygrad.codegen.lowerer import pm_lowerer, get_index
from tinygrad.codegen.quantize import pm_quant from tinygrad.codegen.quantize import pm_quant
from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, cast_folding
from tinygrad.uop.decompositions import get_late_rewrite_patterns from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
@@ -93,6 +93,9 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
supported_ops = tuple(opts.code_for_op.keys()) supported_ops = tuple(opts.code_for_op.keys())
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([]) extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
# lower the index dtype to a concrete int
ret.append(RewriteStep(pm_lower_index_dtype+cast_folding+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))
# optional pre matcher # optional pre matcher
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher")) if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

View File

@@ -34,7 +34,7 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}") if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
# try to split up dims: (a,) -> (b, c) # try to split up dims: (a,) -> (b, c)
if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)] ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.index, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
if len(limited) < len(dims): if len(limited) < len(dims):
ret = [] ret = []
if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}") if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat from tinygrad.uop.symbolic import uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, cast_folding
from tinygrad.helpers import getenv, flatten, AMX, prod, partition from tinygrad.helpers import getenv, flatten, AMX, prod, partition
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
@@ -57,6 +57,9 @@ load_store_indexing = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# index True is just Index # index True is just Index
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
# remove hanging cast
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
# delete_redundant_gates (after expand) # delete_redundant_gates (after expand)
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")), (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates), UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
@@ -316,12 +319,12 @@ def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparent
pm_reduce_collapse = PatternMatcher([ pm_reduce_collapse = PatternMatcher([
# lift x+y out of reduce on lt # lift x+y out of reduce on lt
((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None), ((UPat.var("x")+UPat.var("y")).or_casted() < UPat.var("c"), lambda x,y,c: (x < (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None),
# lift x*y out of reduce # lift x*y out of reduce
((UPat.var("x")*UPat.var("y")) < UPat.var("c"), ((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None), lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
# lift x+y out of reduce on ne # lift x+y out of reduce on ne
((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None), ((UPat.var("x")+UPat.var("y")).or_casted() != UPat.var("c"), lambda x,y,c: (x != (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None),
# fold the range # fold the range
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val), lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
@@ -351,7 +354,7 @@ pm_reduce_collapse = PatternMatcher([
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce), (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
# index/load/where. TODO: this is more aggressive than needed # index/load/where. TODO: this is more aggressive than needed
(UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu), (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
])+sym ])+sym+cast_folding
def reduce_collapse(red:UOp): def reduce_collapse(red:UOp):
included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:])) included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))

View File

@@ -151,7 +151,7 @@ def fix_group_for_reduce(x:UOp):
pm_pre_expander = PatternMatcher([ pm_pre_expander = PatternMatcher([
# rewrite UPCAST/UNROLL range to something to be expanded # rewrite UPCAST/UNROLL range to something to be expanded
(UPat(Ops.RANGE, name="r"), (UPat(Ops.RANGE, name="r"),
lambda r: UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \ lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \
if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None), if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None),
# fix REDUCEs with UNROLLs # fix REDUCEs with UNROLLs
(UPat(Ops.REDUCE, name="x"), fix_reduce_unroll), (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),

View File

@@ -298,6 +298,7 @@ class Compiled:
# TODO: move this to each Device # TODO: move this to each Device
def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if dtype == dtypes.index: return False
if device is None: device = Device.DEFAULT if device is None: device = Device.DEFAULT
if dtype == dtypes.bfloat16: if dtype == dtypes.bfloat16:
if device == "METAL": return not CI if device == "METAL": return not CI

View File

@@ -89,7 +89,7 @@ class dtypes:
def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType) def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
@functools.cache @functools.cache
def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints + (dtypes.index,)
@staticmethod @staticmethod
@functools.cache @functools.cache
def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
@@ -128,6 +128,7 @@ class dtypes:
@staticmethod @staticmethod
def fields() -> dict[str, DType]: return DTYPES_DICT def fields() -> dict[str, DType]: return DTYPES_DICT
void: Final[DType] = DType.new(-1, 0, "void", None) void: Final[DType] = DType.new(-1, 0, "void", None)
index: Final[DType] = DType.new(-1,100, "index", None)
bool: Final[DType] = DType.new(0, 1, "bool", '?') bool: Final[DType] = DType.new(0, 1, "bool", '?')
int8: Final[DType] = DType.new(1, 1, "signed char", 'b') int8: Final[DType] = DType.new(1, 1, "signed char", 'b')
uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B') uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
@@ -164,7 +165,7 @@ class dtypes:
uints = (uint8, uint16, uint32, uint64) uints = (uint8, uint16, uint32, uint64)
sints = (int8, int16, int32, int64) sints = (int8, int16, int32, int64)
ints = uints + sints ints = uints + sints
all = floats + ints + (bool,) all = floats + ints + (bool, index)
if (env_default_float := getenv("DEFAULT_FLOAT", "")): if (env_default_float := getenv("DEFAULT_FLOAT", "")):
dtypes.default_float = getattr(dtypes, env_default_float.lower()) dtypes.default_float = getattr(dtypes, env_default_float.lower())
@@ -186,11 +187,12 @@ def _get_recursive_parents(dtype:DType) -> set[DType]:
return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64} return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
@functools.cache @functools.cache
def least_upper_dtype(*ds:DType) -> DType: def least_upper_dtype(*ds:DType) -> DType:
return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] return min(set.intersection(*[_get_recursive_parents(d.scalar()) for d in ds])) \
if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float) def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))} DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index"))}
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"} INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}
@functools.cache @functools.cache
def can_safe_cast(dt0:DType, dt1:DType) -> bool: def can_safe_cast(dt0:DType, dt1:DType) -> bool:
@@ -198,6 +200,7 @@ def can_safe_cast(dt0:DType, dt1:DType) -> bool:
# https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
if dt0 == dt1 or dt0 == dtypes.bool: return True if dt0 == dt1 or dt0 == dtypes.bool: return True
match dt1: match dt1:
case dtypes.index: return dt0 in dtypes.ints
case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16, case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16,
dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8) dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8) case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
@@ -315,4 +318,4 @@ def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-de
except TypeError: return None except TypeError: return None
@functools.cache @functools.cache
def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821 def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype] return {v:k for k in DTYPES_DICT.values() if (v:=_to_torch_dtype(k)) is not None}[torchdtype]

View File

@@ -119,7 +119,7 @@ def map_reshape(idx:UOp, r:UOp):
for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]: for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]:
to_sum.append(acc*src) to_sum.append(acc*src)
acc *= s acc *= s
mish = sum(to_sum, start=UOp.const(dtypes.int, 0)) mish = sum(to_sum, start=UOp.const(dtypes.index, 0))
ret:list[UOp] = [] ret:list[UOp] = []
for s in r.src[0].shape[::-1]: for s in r.src[0].shape[::-1]:
ret.append(mish % s) # NOTE: simplify will turn this to CONST ret.append(mish % s) # NOTE: simplify will turn this to CONST
@@ -186,7 +186,7 @@ def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
ranges.append(idx.src[1+i]) ranges.append(idx.src[1+i])
continue continue
passthrough_idx.append(idx.src[1+i]) passthrough_idx.append(idx.src[1+i])
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0)) ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.index, 0))
new_ranges.append(ranges[-1]) new_ranges.append(ranges[-1])
ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=x.device) ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=x.device)
return ret.index(*passthrough_idx) return ret.index(*passthrough_idx)
@@ -195,7 +195,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
if x.arg is not None: return None if x.arg is not None: return None
ranges = [] ranges = []
for s in x.shape[len(x.src)-1:]: for s in x.shape[len(x.src)-1:]:
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0)) ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.index, 0))
ret = x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device) ret = x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
return ret.shrink(((0, prod(x.shape)),)).forced_reshape(x.shape) return ret.shrink(((0, prod(x.shape)),)).forced_reshape(x.shape)

View File

@@ -5,22 +5,8 @@ import functools
from typing import Callable from typing import Callable
from tinygrad.helpers import merge_dicts, getenv from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, unravel from tinygrad.shape.view import View, unravel
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp
from tinygrad.uop.symbolic import symbolic_flat, uop_given_valid, simplify_valid from tinygrad.uop.symbolic import symbolic_flat, uop_given_valid, simplify_valid
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
def handle_upcast(u: UOp) -> UOp|None:
dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
# check for overflow, upcast this to int64
if u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int):
return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src]))
# if any inputs are int64 and this *doesn't* overflow, cast back to int
if any(x.dtype == dtypes.int64 for x in u.src):
return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])).cast(u.dtype)
return None
pm_upcast = PatternMatcher([(UPat(GroupOp.ALU, dtype=dtypes.int, name="u"), handle_upcast),])
@functools.cache @functools.cache
def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]: def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
@@ -34,8 +20,8 @@ def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=No
# simplify # simplify
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx
# symbolic again, upcast if needed # symbolic again
return graph_rewrite(UOp.sink(idx, valid), symbolic_flat+pm_upcast, name="indexing sym @ 2").src return graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 2").src
@functools.cache @functools.cache
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]: def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]:

View File

@@ -204,15 +204,15 @@ class View:
# Merge dimensions in vm2 if required. # Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required. # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)] idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1, dtypes.index) for i,s in enumerate(vm1.shape)]
merged_size, merged_term = 1, UOp.const(dtypes.int, 0) merged_size, merged_term = 1, UOp.const(dtypes.index, 0)
extents: list[tuple[sint, UOp]] = [] extents: list[tuple[sint, UOp]] = []
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)): for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size
merged_size *= s merged_size *= s
if resolve(merged_term < merged_size, False) and resolve(0 <= merged_term, False): if resolve(merged_term < merged_size, False) and resolve(0 <= merged_term, False):
extents.append((merged_size, merged_term)) extents.append((merged_size, merged_term))
merged_size, merged_term = 1, UOp.const(dtypes.int, 0) merged_size, merged_term = 1, UOp.const(dtypes.index, 0)
if resolve(merged_term != 0): return None if resolve(merged_term != 0): return None
if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape: if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None

View File

@@ -8,7 +8,7 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
from tinygrad.gradient import compute_gradient from tinygrad.gradient import compute_gradient
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int
from tinygrad.uop.spec import tensor_uop_spec, type_verify from tinygrad.uop.spec import tensor_uop_spec, type_verify
from tinygrad.device import Device, Buffer from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule from tinygrad.engine.realize import run_schedule
@@ -139,6 +139,8 @@ class Tensor(MathTrait):
# create a UOp from the different types of inputs # create a UOp from the different types of inputs
if isinstance(data, UOp): if isinstance(data, UOp):
assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
# if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
if data.dtype==dtypes.index: data = index_to_concrete_int(data)
if data.op is Ops.BIND: if data.op is Ops.BIND:
var, val = data.unbind() var, val = data.unbind()
# give the bound constant a device # give the bound constant a device

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.mathtraits import MathTrait from tinygrad.uop.mathtraits import MathTrait
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -307,7 +307,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def range(end:sint, *arg): def range(end:sint, *arg):
if len(arg) == 0: raise RuntimeError("range needs an arg") if len(arg) == 0: raise RuntimeError("range needs an arg")
if len(arg) == 1: arg = arg+(AxisType.LOOP,) if len(arg) == 1: arg = arg+(AxisType.LOOP,)
return UOp(Ops.RANGE, dtype=dtypes.int, src=(sint_to_uop(end),), arg=arg) return UOp(Ops.RANGE, dtype=dtypes.index, src=(sint_to_uop(end),), arg=arg)
def r(self, op:Ops, axis:tuple[int, ...]): def r(self, op:Ops, axis:tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)])) axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self if len(axis) == 0: return self
@@ -482,7 +482,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop Variable stuff *** # *** uop Variable stuff ***
@staticmethod @staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int) -> UOp: def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.index) -> UOp:
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}" assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property @property
@@ -573,7 +573,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg)) if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
if self.op is Ops.GEP: return self.src[0]._min_max if self.op is Ops.GEP: return self.src[0]._min_max
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
if self.op is Ops.CAST and self.dtype in (dtypes.floats+dtypes.sints): if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype)) return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
return dtypes.min(self.dtype), dtypes.max(self.dtype) return dtypes.min(self.dtype), dtypes.max(self.dtype)
@@ -1025,7 +1025,30 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na
for k,v in input_map.items(): new_map[k] = new_map.get(v,v) for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
return new_map return new_map
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.index, x) if isinstance(x, int) else x.cast(dtypes.index)
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
pm_lower_index_dtype = PatternMatcher([
# There are no Unary ops at this point in symbolic, those are introduced later
(UPat(GroupOp.Binary, dtypes.index, name="u", src=(UPat.var("x"), UPat.var("y"))), lambda u,x,y:
x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt))),
# comparison ops might now have different dtypes in their sources
(UPat(GroupOp.Comparison, name="u", src=(UPat.var("x",dtypes.ints), UPat.var("y", dtypes.ints))), lambda u,x,y:
x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)).alu(u.op, y.cast(dt)) if x.dtype!=y.dtype else None),
(UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat.var("cond"), UPat.var("x"), UPat.var("y")), name="u"), lambda cond,u,x,y:
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt))),
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u))),
(UPat((Ops.RANGE,), dtype=dtypes.index, src=(UPat.var("end")), name="r"), lambda ctx,r,end:
r.replace(dtype=(dt:=select_dtype(r)), src=(end.cast(dt),))),
(UPat(Ops.CAST, dtype=dtypes.index, src=(UPat.var("x", dtypes.ints),), name="u"), lambda u,x: x),
(UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(
dtype=(dt:=least_upper_dtype(*[x.dtype for x in u.src])).vec(u.dtype.count), src=tuple(x.cast(dt) for x in u.src))),
(UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=(dt:=(dtypes.long if any(v.overflows(dtypes.int) for v in u.src)
else dtypes.long)).vec(u.dtype.count),src=tuple(x.cast(dt) for x in u.src))),
(UPat((Ops.SPECIAL,Ops.DEFINE_VAR), dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int)),
(UPat((Ops.BIND), dtypes.index, name="u"), lambda u: u.replace(dtype=u.src[0].dtype)),
])
def index_to_concrete_int(u:UOp): return graph_rewrite(u, pm_lower_index_dtype)
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])

View File

@@ -28,17 +28,17 @@ try:
# float loads only become a variable when they get cast to int/bool # float loads only become a variable when they get cast to int/bool
(UPat(Ops.LOAD, dtypes.ints, name="x"), (UPat(Ops.LOAD, dtypes.ints, name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,), name="x"), (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
# z3 can cast from bool to int automatically # z3 can cast from bool to int automatically
(UPat(Ops.CAST, dtype=dtypes.ints, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))), (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
# if the source of the cast is not a noop it means that it is a float and so we create a new variable # if the source of the cast is not a noop it means that it is a float and so we create a new variable
(UPat(Ops.CAST, dtype=dtypes.ints, name="x"), lambda x,ctx: (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))), UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx: (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))), UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
(UPat(Ops.XOR, src=UPat(Ops.NOOP), name="x"), (UPat(Ops.XOR, dtype=dtypes.ints+(dtypes.bool, ), src=UPat(Ops.NOOP), name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(s.arg[1], x.dtype.itemsize*8) for s in x.src)))))), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(s.arg[1], x.dtype.itemsize*8) for s in x.src)))))),
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))), (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
# A comparison between floats introduces a new bool variable # A comparison between floats introduces a new bool variable
@@ -95,7 +95,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
(UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}), (UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}),
# Tensor variable bindings # Tensor variable bindings
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True), (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
# Tensor const has a device and an unmasked ShapeTracker of stride 0 # Tensor const has a device and an unmasked ShapeTracker of stride 0
# NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum # NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
@@ -120,6 +120,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
# ***** uop type spec ***** # ***** uop type spec *****
def validate_index(idx:UOp, gate:UOp=UOp.const(dtypes.bool, True)): def validate_index(idx:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
# TODO: check for overflow
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
@@ -175,6 +176,9 @@ spec = PatternMatcher([
# **** new style load/store **** # **** new style load/store ****
# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
# INDEX is used in new style load/store # INDEX is used in new style load/store
# INDEX takes a <buf, alu, gate?> # INDEX takes a <buf, alu, gate?>
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True), (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),

View File

@@ -26,7 +26,7 @@ symbolic_simple = PatternMatcher([
# ** self folding ** # ** self folding **
(UPat.var("x") + 0, lambda x: x), # x+0 -> x (UPat.var("x") + 0, lambda x: x), # x+0 -> x
(UPat.var("x") * 1, lambda x: x), # x*1 -> x (UPat.var("x") * 1, lambda x: x), # x*1 -> x
(UPat.var("x", dtype=dtypes.ints) ^ 0, lambda x: x), # x^0 -> x (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x), # x^0 -> x
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1 (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
(UPat.var("x") // 1, lambda x: x), # x//1 -> x (UPat.var("x") // 1, lambda x: x), # x//1 -> x
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
@@ -49,11 +49,11 @@ symbolic_simple = PatternMatcher([
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool,)).trunc(), lambda x: x), (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x),
# ** zero folding ** # ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0 (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints), (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# x*0 -> 0 or 0*x -> 0 # x*0 -> 0 or 0*x -> 0
# if x is nan or inf it should render the nan value. # if x is nan or inf it should render the nan value.
@@ -108,6 +108,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
if fac!=1: if fac!=1:
if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
u = u.src[0] u = u.src[0]
if u.op is Ops.CAST and u.src[0].dtype == dtypes.index: u = u.src[0]
if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
if denominator != u.src[1].arg: return None if denominator != u.src[1].arg: return None
if (s0:=u.src[0]).vmin < 0: return None if (s0:=u.src[0]).vmin < 0: return None
@@ -123,7 +124,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
for i in range(denominator-len(seen_const)): for i in range(denominator-len(seen_const)):
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i) if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
if sorted(seen_const)==list(range(denominator)): if sorted(seen_const)==list(range(denominator)):
return fac*ans return (fac*ans).cast(divs.dtype)
return None return None
def lt_folding(x:UOp, c:int) -> UOp|None: def lt_folding(x:UOp, c:int) -> UOp|None:
@@ -270,10 +271,19 @@ gep_pushing = PatternMatcher([
(UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma), (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
]) ])
cast_folding = PatternMatcher([
(UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
# try to do math in int instead of long
(UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
])
commutative = PatternMatcher([ commutative = PatternMatcher([
# ** COMMUTATIVE flipping (only for ints) ** # ** COMMUTATIVE flipping (only for index) **
# NOTE: this can break merging vector math by only flipping some of them # NOTE: this can break merging vector math by only flipping some of them
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), (UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
]) ])
symbolic = symbolic_simple+commutative+PatternMatcher([ symbolic = symbolic_simple+commutative+PatternMatcher([
@@ -289,10 +299,6 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2), ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3) ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
(UPat.var('x', dtypes.ints).cast(dtypes.ints, name="a").cast(name="b"),
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
(UPat(GroupOp.Binary, src=(UPat.var("x",dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
# a conditional with the same results either way is a noop, also fold const conditionals # a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
@@ -317,27 +323,27 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2) ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
# ** lt ** # ** lt **
# c0*x<c1 for positive int c0,c1 # c0*x<c1 for positive int c0,c1
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False), ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None), lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
# c0*x<c1 for negative int c0 and non-positive c1 # c0*x<c1 for negative int c0 and non-positive c1
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False), ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None), lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# x//d<c # x//d<c
((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False), ((UPat.var("x", dtype=dtypes.index)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None), lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) ** # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
# *** rules from symbolic *** # *** rules from symbolic ***
# unrolled arange div folding # unrolled arange div folding
((UPat() + UPat()//UPat.cvar("d", vec=False)).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)), ((UPat()+(UPat()//UPat.cvar("d", vec=False)).or_casted()).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)),
((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)), ((UPat()+((UPat()//UPat.cvar("d", vec=False)).or_casted()*UPat.cvar("c"))).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)),
# generic lt folding # generic lt folding
(UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), (UPat.var("x", dtypes.index)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
(UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*-1, lambda x,y: y<x), (UPat.var("x", dtypes.index)*-1 < UPat.var("y")*-1, lambda x,y: y<x),
# canonicalize a simplex with positive coefficients > 0 # canonicalize a simplex with positive coefficients > 0
# not x < 1 -> X > 0 # not x < 1 -> X > 0
((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), ((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
# ** div ** # ** div **
# div folding # div folding
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d) ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
@@ -345,30 +351,28 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
# a range mod its own upper bound is just the range # a range mod its own upper bound is just the range
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r), (UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r),
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)), (UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod), (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator), (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence), (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd), (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd),
(UPat(Ops.MOD, dtypes.sints, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod), (UPat(Ops.MOD, dtypes.index, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
(UPat((Ops.IDIV), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor), (UPat((Ops.IDIV), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder), (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder),
(UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None), (UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
(UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None), (UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None), lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
# ** mod ** # ** mod **
# mod folding # mod folding
(UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None), (UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
(UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None), (UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
# up + x//c*c + x%c ])+gep_pushing+cast_folding
(UPat.var("up") + UPat.var("x", dtypes.ints)//UPat.cvar("c")*UPat.cvar("c") + UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda up,x,c: up+x),
])+gep_pushing
symbolic_flat = symbolic+PatternMatcher([ symbolic_flat = symbolic+PatternMatcher([
# ** combine terms (opinionated) ** # ** combine terms (opinionated) **
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
]) ])
# ******** we take a small aside to "simplify_valid" to rewrite valids ******** # ******** we take a small aside to "simplify_valid" to rewrite valids ********
@@ -400,6 +404,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
# simplify uop given that valid is True # simplify uop given that valid is True
for expr,v in bounds.items(): for expr,v in bounds.items():
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop
# some expr has lower bound > upper bound -> valid is an empty set and we return None # some expr has lower bound > upper bound -> valid is an empty set and we return None
if v0 > v1: return None if v0 > v1: return None
# whole node became a const # whole node became a const