update test_uop_symbolic to test UOp min and max (#5764)

covers #5750, #5748, #5741
This commit is contained in:
chenyu
2024-07-27 19:53:21 -04:00
committed by GitHub
parent 1903542c2d
commit 80c6475757

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python
import unittest, pickle
from typing import Tuple
#from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
# TODO: fix all the @unittest.expectedFailure
@@ -7,13 +8,13 @@ import unittest, pickle
# *** fake symobilc uops ***
from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.dtype import dtypes, PtrDType, ConstType
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.ops import BinaryOps
import functools
def render(self) -> str:
def render(self) -> Tuple[str, ConstType, ConstType]:
# NOTE: we need STORE so the ALU op has children
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(0,True))
graph = UOpGraph([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))])
@@ -22,8 +23,9 @@ def render(self) -> str:
from tinygrad.renderer.cstyle import CStyleLanguage
class TestRenderer(CStyleLanguage):
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
rewritten_uop = [uop for uop in graph.uops if uop.op is UOps.STORE][0].src[-1]
fxn = TestRenderer().render("", graph)
return fxn.split("data0[0] = ")[1].split(";")[0]
return fxn.split("data0[0] = ")[1].split(";")[0], rewritten_uop.vmin.arg, rewritten_uop.vmax.arg
def NumNode(val): return UOp.const(dtypes.int, val)
def Variable(expr, nmin, nmax):
@@ -52,12 +54,13 @@ class TestSymbolicPickle(unittest.TestCase):
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):
rendered, nmin, nmax = render(v)
if isinstance(s, set):
self.assertIn(render(v), s)
self.assertIn(rendered, s)
else:
self.assertEqual(render(v), s)
#self.assertEqual(v.min, n)
#self.assertEqual(v.max, m)
self.assertEqual(rendered, s)
self.assertEqual(nmin, n)
self.assertEqual(nmax, m)
def test_cmp_simple(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")