mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
test pyrender (#12005)
* test pyrender * make them print * switch to pyrendered
This commit is contained in:
20
extra/test_pyrender.py
Normal file
20
extra/test_pyrender.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
||||
from tinygrad.helpers import tqdm
|
||||
from tinygrad.uop.ops import pyrender, UOp, Ops
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
inf, nan = float('inf'), float('nan')
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds()
|
||||
for i, ast_str in enumerate(tqdm(ast_strs)):
|
||||
good_ast = ast_str_to_ast(ast_str)
|
||||
code = '\n'.join(pyrender(good_ast))
|
||||
print("\n***************\n\n"+code)
|
||||
exec(code)
|
||||
if str(good_ast) != str(ast):
|
||||
print(code)
|
||||
print("MISMATCH")
|
||||
print(good_ast)
|
||||
print(ast)
|
||||
break
|
||||
@@ -455,13 +455,13 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_skip_unmatching_upcasts(self):
|
||||
Tensor.manual_seed(0)
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(9600), arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(9600), arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=()),)),)),)),))
|
||||
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=())
|
||||
c1 = c0.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
|
||||
c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=())
|
||||
c3 = c2.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))
|
||||
c4 = c3.load()
|
||||
c5 = c1.store(c4)
|
||||
ast = c5.sink()
|
||||
opt = [
|
||||
Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16),
|
||||
Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)
|
||||
@@ -474,13 +474,13 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_skip_unmatching_upcasts_with_gep(self):
|
||||
Tensor.manual_seed(0)
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=()),)),)),)),))
|
||||
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=())
|
||||
c1 = c0.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
|
||||
c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=())
|
||||
c3 = c2.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))
|
||||
c4 = c3.load()
|
||||
c5 = c1.store(c4)
|
||||
ast = c5.sink()
|
||||
opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8),
|
||||
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8),
|
||||
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
|
||||
|
||||
@@ -1053,13 +1053,13 @@ renderer_infer = PatternMatcher([
|
||||
])
|
||||
|
||||
sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce",
|
||||
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2"}
|
||||
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
|
||||
pm_pyrender = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.cast({x.dtype})")),
|
||||
(UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")),
|
||||
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE}, src=UPat(Ops.NOOP), name="x"),
|
||||
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x:
|
||||
UOp(Ops.NOOP, arg=f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])})")),
|
||||
|
||||
Reference in New Issue
Block a user