test pyrender (#12005)

* test pyrender

* make them print

* switch to pyrendered
This commit is contained in:
George Hotz
2025-09-04 11:48:40 -07:00
committed by GitHub
parent 560df206cc
commit 70ce29b630
3 changed files with 36 additions and 16 deletions

20
extra/test_pyrender.py Normal file
View File

@@ -0,0 +1,20 @@
from extra.optimization.helpers import load_worlds, ast_str_to_ast
from tinygrad.helpers import tqdm
from tinygrad.uop.ops import pyrender, UOp, Ops
from tinygrad import dtypes
from tinygrad.shape.shapetracker import ShapeTracker, View
inf, nan = float('inf'), float('nan')
if __name__ == "__main__":
ast_strs = load_worlds()
for i, ast_str in enumerate(tqdm(ast_strs)):
good_ast = ast_str_to_ast(ast_str)
code = '\n'.join(pyrender(good_ast))
print("\n***************\n\n"+code)
exec(code)
if str(good_ast) != str(ast):
print(code)
print("MISMATCH")
print(good_ast)
print(ast)
break

View File

@@ -455,13 +455,13 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts(self):
Tensor.manual_seed(0)
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.VIEW, dtypes.float.ptr(9600), arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float.ptr(9600), arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=()),)),)),)),))
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=())
c1 = c0.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=())
c3 = c2.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))
c4 = c3.load()
c5 = c1.store(c4)
ast = c5.sink()
opt = [
Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16),
Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)
@@ -474,13 +474,13 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts_with_gep(self):
Tensor.manual_seed(0)
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=()),)),)),)),))
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=())
c1 = c0.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=())
c3 = c2.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))
c4 = c3.load()
c5 = c1.store(c4)
ast = c5.sink()
opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8),
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8),
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]

View File

@@ -1053,13 +1053,13 @@ renderer_infer = PatternMatcher([
])
sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce",
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2"}
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
pm_pyrender = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
(UPat(Ops.CAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.cast({x.dtype})")),
(UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")),
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE}, src=UPat(Ops.NOOP), name="x"),
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")),
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x:
UOp(Ops.NOOP, arg=f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])})")),