mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add rendering to index (#12338)
This commit is contained in:
@@ -814,20 +814,6 @@ class TestShapeTrackerSize(unittest.TestCase):
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).flip((True, True))
|
||||
self.assertEqual(st.real_size(), 100)
|
||||
|
||||
class TestRender(unittest.TestCase):
|
||||
def test_render(self):
|
||||
st = ShapeTracker.from_shape((2, 3))
|
||||
valid_idx = st.to_valid_uop()
|
||||
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
|
||||
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
|
||||
self.assertEqual(valid.render(), "True")
|
||||
|
||||
st = st.pad(((0, 1), (0, 0)))
|
||||
valid_idx = st.to_valid_uop()
|
||||
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
|
||||
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
|
||||
self.assertEqual(valid.render(), "(ridx0<2)")
|
||||
|
||||
class TestVariableShrink(unittest.TestCase):
|
||||
def test_shrink(self):
|
||||
st = ShapeTracker.from_shape((10,))
|
||||
|
||||
@@ -71,8 +71,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
idx = ridx0+ridx1+ridx2+ridx3
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"(((ridx0+ridx1)+ridx2)+ridx3)",
|
||||
"((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3)%4)<2))")
|
||||
"(((r0+r1)+r2)+r3)",
|
||||
"((((r0*3)+r1)<8)&((((r2*3)+r3)%4)<2))")
|
||||
|
||||
def test_simplify_within_valid2(self):
|
||||
gidx0 = Special("gidx0", 56)
|
||||
@@ -85,8 +85,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
ridx0 = Range(0, 2)
|
||||
v0 = ridx0<1
|
||||
v1 = ((ridx0*5+1)%6)<5
|
||||
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
|
||||
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
|
||||
self.assertEqual(simplify_valid(v0&v1).render(), "(r0<1)")
|
||||
self.assertEqual(simplify_valid(v1&v0).render(), "(r0<1)")
|
||||
|
||||
def test_valid_order_matters2(self):
|
||||
gidx0 = Special("gidx0", 13)
|
||||
@@ -128,8 +128,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
valid = ((((((ridx2*2)+(ridx3*3))+3)%4)<2)!=True) # noqa: E712
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"(((ridx0*2)+(ridx3*-1))+1)",
|
||||
"(ridx2<1)")
|
||||
"(((r0*2)+(r3*-1))+1)",
|
||||
"(r2<1)")
|
||||
|
||||
def test_load_in_valid(self):
|
||||
# from FUSE_ARANGE=1 python test/test_ops.py TestOps.test_scatter_add
|
||||
@@ -154,8 +154,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
valid = (ridx2<1)&(ridx1<6)
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"(ridx0*1568)",
|
||||
"((ridx2<1)&(ridx1<6))")
|
||||
"(r0*1568)",
|
||||
"((r2<1)&(r1<6))")
|
||||
|
||||
def test_valid_becomes_const1_z3(self):
|
||||
from z3 import Ints, Solver, And, If, Not, unsat
|
||||
@@ -195,7 +195,7 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"1",
|
||||
"((((ridx0+ridx1)<1)!=True)&(((ridx2+ridx3)<1)!=True))")
|
||||
"((((r0+r1)<1)!=True)&(((r2+r3)<1)!=True))")
|
||||
|
||||
def test_valid_with_non_const_rhs(self):
|
||||
ridx0 = Range(0, 2**16)
|
||||
@@ -205,8 +205,8 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
idx = ridx0%1024
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"ridx0",
|
||||
"(ridx0<((ridx1*4)+ridx2))")
|
||||
"r0",
|
||||
"(r0<((r1*4)+r2))")
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
@@ -304,7 +304,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
|
||||
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
self.check(load, None, "((((idx1*48)+(ridx2*6))+ridx0)+-6)", "(((idx2*2)+ridx1)+-1)")
|
||||
self.check(load, None, "((((idx1*48)+(r2*6))+r0)+-6)", "(((idx2*2)+r1)+-1)")
|
||||
|
||||
def test_openpilot_conv2(self):
|
||||
# conv in test/external/external_test_valid_remove.py
|
||||
@@ -325,7 +325,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
|
||||
self.check(load, None, "((((idx1*24)+(ridx2*3))+ridx0)+-3)", "(((idx2*2)+ridx1)+-1)")
|
||||
self.check(load, None, "((((idx1*24)+(r2*3))+r0)+-3)", "(((idx2*2)+r1)+-1)")
|
||||
|
||||
def test_openpilot_conv3(self):
|
||||
# in openpilot 0.9.7
|
||||
@@ -346,9 +346,9 @@ class TestImageSimplification(unittest.TestCase):
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
|
||||
self.check(load,
|
||||
"((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True))",
|
||||
"(((idx0+((idx1*512)+(ridx1*64)))+832)%1024)",
|
||||
"((((idx2*2)+ridx0)+(((idx1+((ridx1+5)//8))+1)//2))+-4)")
|
||||
"((((idx2*2)+r0)<11)&((((idx1*8)+r1)<3)!=True))",
|
||||
"(((idx0+((idx1*512)+(r1*64)))+832)%1024)",
|
||||
"((((idx2*2)+r0)+(((idx1+((r1+5)//8))+1)//2))+-4)")
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
@@ -424,7 +424,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
alu1 = ((idx2*1536)+(ridx4*768)+ridx3+(idx1*24)+(ridx5*3)+-771)//768
|
||||
valid = (((idx2+ridx4)<1)!=1)&(((idx1+ridx5)<1)!=1)
|
||||
load = get_load_image_uop((128, 768, 4), valid, (alu0, alu1))
|
||||
self.check(load, None, "((((idx1*24)+ridx3)+(ridx5*3))+-3)", "(((idx2*2)+ridx4)+-1)")
|
||||
self.check(load, None, "((((idx1*24)+r3)+(r5*3))+-3)", "(((idx2*2)+r4)+-1)")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -246,7 +246,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_range_mod_its_symbolic_bound(self):
|
||||
a = Variable("a", 1, 10, dtypes.index)
|
||||
ridx = UOp.range(a+2, 0)
|
||||
self.helper_test_variable(ridx%(a+2), 0, 11, "ridx0")
|
||||
self.helper_test_variable(ridx%(a+2), 0, 11, "r0")
|
||||
|
||||
def test_div_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 2, 7) // 2, 1, 3, "(a//2)")
|
||||
|
||||
@@ -8,6 +8,7 @@ from tinygrad.uop.mathtraits import MathTrait
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, RANGEIFY, VIZ, SPEC
|
||||
from tinygrad.helpers import strip_parens
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
@@ -1102,7 +1103,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<"
|
||||
renderer = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
|
||||
(UPat((Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg)),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}" if x.arg[0] >= 0 else f"ridxm{-x.arg[0]}")),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"r{x.arg[0]}" if x.arg[0] >= 0 else f"rm{-x.arg[0]}")),
|
||||
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
|
||||
(UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),
|
||||
@@ -1115,6 +1116,8 @@ renderer = PatternMatcher([
|
||||
(UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
|
||||
(UPat(set(syms.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.view({x.arg})")),
|
||||
(UPat(Ops.INDEX, name="x"), lambda x:
|
||||
UOp(Ops.NOOP, arg=''.join([f"[{strip_parens(y.arg)}]" for y in x.src[1:]])) if all(y.op is Ops.NOOP for y in x.src[1:]) else None),
|
||||
])
|
||||
renderer_infer = PatternMatcher([
|
||||
(UPat(Ops.MOD, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cmod({x.src[0].arg}, {x.src[1].arg})")),
|
||||
|
||||
@@ -83,6 +83,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
label += f"\n{shape_to_str(u.shape)}"
|
||||
elif len(rngs:=u.ranges):
|
||||
label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
|
||||
if u.op is Ops.INDEX:
|
||||
label += f"\n{u.render()}"
|
||||
except Exception:
|
||||
label += "\n<ISSUE GETTING LABEL>"
|
||||
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||
|
||||
Reference in New Issue
Block a user