add rendering to index (#12338)

This commit is contained in:
George Hotz
2025-09-30 09:18:05 +08:00
committed by GitHub
parent baf3b60cfb
commit cdfa0f29fd
5 changed files with 24 additions and 33 deletions

View File

@@ -814,20 +814,6 @@ class TestShapeTrackerSize(unittest.TestCase):
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).flip((True, True))
self.assertEqual(st.real_size(), 100)
class TestRender(unittest.TestCase):
def test_render(self):
st = ShapeTracker.from_shape((2, 3))
valid_idx = st.to_valid_uop()
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
self.assertEqual(valid.render(), "True")
st = st.pad(((0, 1), (0, 0)))
valid_idx = st.to_valid_uop()
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
self.assertEqual(valid.render(), "(ridx0<2)")
class TestVariableShrink(unittest.TestCase):
def test_shrink(self):
st = ShapeTracker.from_shape((10,))

View File

@@ -71,8 +71,8 @@ class TestValidIdxSimplification(unittest.TestCase):
idx = ridx0+ridx1+ridx2+ridx3
load = get_gated_load_uop(valid, idx)
self.check(load,
"(((ridx0+ridx1)+ridx2)+ridx3)",
"((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3)%4)<2))")
"(((r0+r1)+r2)+r3)",
"((((r0*3)+r1)<8)&((((r2*3)+r3)%4)<2))")
def test_simplify_within_valid2(self):
gidx0 = Special("gidx0", 56)
@@ -85,8 +85,8 @@ class TestValidIdxSimplification(unittest.TestCase):
ridx0 = Range(0, 2)
v0 = ridx0<1
v1 = ((ridx0*5+1)%6)<5
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
self.assertEqual(simplify_valid(v0&v1).render(), "(r0<1)")
self.assertEqual(simplify_valid(v1&v0).render(), "(r0<1)")
def test_valid_order_matters2(self):
gidx0 = Special("gidx0", 13)
@@ -128,8 +128,8 @@ class TestValidIdxSimplification(unittest.TestCase):
valid = ((((((ridx2*2)+(ridx3*3))+3)%4)<2)!=True) # noqa: E712
load = get_gated_load_uop(valid, idx)
self.check(load,
"(((ridx0*2)+(ridx3*-1))+1)",
"(ridx2<1)")
"(((r0*2)+(r3*-1))+1)",
"(r2<1)")
def test_load_in_valid(self):
# from FUSE_ARANGE=1 python test/test_ops.py TestOps.test_scatter_add
@@ -154,8 +154,8 @@ class TestValidIdxSimplification(unittest.TestCase):
valid = (ridx2<1)&(ridx1<6)
load = get_gated_load_uop(valid, idx)
self.check(load,
"(ridx0*1568)",
"((ridx2<1)&(ridx1<6))")
"(r0*1568)",
"((r2<1)&(r1<6))")
def test_valid_becomes_const1_z3(self):
from z3 import Ints, Solver, And, If, Not, unsat
@@ -195,7 +195,7 @@ class TestValidIdxSimplification(unittest.TestCase):
load = get_gated_load_uop(valid, idx)
self.check(load,
"1",
"((((ridx0+ridx1)<1)!=True)&(((ridx2+ridx3)<1)!=True))")
"((((r0+r1)<1)!=True)&(((r2+r3)<1)!=True))")
def test_valid_with_non_const_rhs(self):
ridx0 = Range(0, 2**16)
@@ -205,8 +205,8 @@ class TestValidIdxSimplification(unittest.TestCase):
idx = ridx0%1024
load = get_gated_load_uop(valid, idx)
self.check(load,
"ridx0",
"(ridx0<((ridx1*4)+ridx2))")
"r0",
"(r0<((r1*4)+r2))")
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
@@ -304,7 +304,7 @@ class TestImageSimplification(unittest.TestCase):
idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
load = get_load_image_uop(shape, valid, idx)
self.check(load, None, "((((idx1*48)+(ridx2*6))+ridx0)+-6)", "(((idx2*2)+ridx1)+-1)")
self.check(load, None, "((((idx1*48)+(r2*6))+r0)+-6)", "(((idx2*2)+r1)+-1)")
def test_openpilot_conv2(self):
# conv in test/external/external_test_valid_remove.py
@@ -325,7 +325,7 @@ class TestImageSimplification(unittest.TestCase):
idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
load = get_load_image_uop(shape, valid, idx)
self.check(load, None, "((((idx1*24)+(ridx2*3))+ridx0)+-3)", "(((idx2*2)+ridx1)+-1)")
self.check(load, None, "((((idx1*24)+(r2*3))+r0)+-3)", "(((idx2*2)+r1)+-1)")
def test_openpilot_conv3(self):
# in openpilot 0.9.7
@@ -346,9 +346,9 @@ class TestImageSimplification(unittest.TestCase):
load = get_load_image_uop(shape, valid, idx)
self.check(load,
"((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True))",
"(((idx0+((idx1*512)+(ridx1*64)))+832)%1024)",
"((((idx2*2)+ridx0)+(((idx1+((ridx1+5)//8))+1)//2))+-4)")
"((((idx2*2)+r0)<11)&((((idx1*8)+r1)<3)!=True))",
"(((idx0+((idx1*512)+(r1*64)))+832)%1024)",
"((((idx2*2)+r0)+(((idx1+((r1+5)//8))+1)//2))+-4)")
def test_simplify1(self):
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
@@ -424,7 +424,7 @@ class TestImageSimplification(unittest.TestCase):
alu1 = ((idx2*1536)+(ridx4*768)+ridx3+(idx1*24)+(ridx5*3)+-771)//768
valid = (((idx2+ridx4)<1)!=1)&(((idx1+ridx5)<1)!=1)
load = get_load_image_uop((128, 768, 4), valid, (alu0, alu1))
self.check(load, None, "((((idx1*24)+ridx3)+(ridx5*3))+-3)", "(((idx2*2)+ridx4)+-1)")
self.check(load, None, "((((idx1*24)+r3)+(r5*3))+-3)", "(((idx2*2)+r4)+-1)")
if __name__ == '__main__':
unittest.main()

View File

@@ -246,7 +246,7 @@ class TestSymbolic(unittest.TestCase):
def test_range_mod_its_symbolic_bound(self):
a = Variable("a", 1, 10, dtypes.index)
ridx = UOp.range(a+2, 0)
self.helper_test_variable(ridx%(a+2), 0, 11, "ridx0")
self.helper_test_variable(ridx%(a+2), 0, 11, "r0")
def test_div_min_max(self):
self.helper_test_variable(Variable("a", 2, 7) // 2, 1, 3, "(a//2)")

View File

@@ -8,6 +8,7 @@ from tinygrad.uop.mathtraits import MathTrait
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, RANGEIFY, VIZ, SPEC
from tinygrad.helpers import strip_parens
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer, MultiBuffer
@@ -1102,7 +1103,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<"
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
(UPat((Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg)),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}" if x.arg[0] >= 0 else f"ridxm{-x.arg[0]}")),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"r{x.arg[0]}" if x.arg[0] >= 0 else f"rm{-x.arg[0]}")),
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
(UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),
@@ -1115,6 +1116,8 @@ renderer = PatternMatcher([
(UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(set(syms.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
(UPat(Ops.VIEW, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.view({x.arg})")),
(UPat(Ops.INDEX, name="x"), lambda x:
UOp(Ops.NOOP, arg=''.join([f"[{strip_parens(y.arg)}]" for y in x.src[1:]])) if all(y.op is Ops.NOOP for y in x.src[1:]) else None),
])
renderer_infer = PatternMatcher([
(UPat(Ops.MOD, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cmod({x.src[0].arg}, {x.src[1].arg})")),

View File

@@ -83,6 +83,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
label += f"\n{shape_to_str(u.shape)}"
elif len(rngs:=u.ranges):
label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
if u.op is Ops.INDEX:
label += f"\n{u.render()}"
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"