mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update test_uop_symbolic to test UOp min and max (#5764)
covers #5750, #5748, #5741
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest, pickle
|
||||
from typing import Tuple
|
||||
#from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
|
||||
|
||||
# TODO: fix all the @unittest.expectedFailure
|
||||
@@ -7,13 +8,13 @@ import unittest, pickle
|
||||
# *** fake symobilc uops ***
|
||||
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.dtype import dtypes, PtrDType, ConstType
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
from tinygrad.codegen.uopgraph import UOpGraph
|
||||
from tinygrad.ops import BinaryOps
|
||||
import functools
|
||||
|
||||
def render(self) -> str:
|
||||
def render(self) -> Tuple[str, ConstType, ConstType]:
|
||||
# NOTE: we need STORE so the ALU op has children
|
||||
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(0,True))
|
||||
graph = UOpGraph([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))])
|
||||
@@ -22,8 +23,9 @@ def render(self) -> str:
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
class TestRenderer(CStyleLanguage):
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
|
||||
rewritten_uop = [uop for uop in graph.uops if uop.op is UOps.STORE][0].src[-1]
|
||||
fxn = TestRenderer().render("", graph)
|
||||
return fxn.split("data0[0] = ")[1].split(";")[0]
|
||||
return fxn.split("data0[0] = ")[1].split(";")[0], rewritten_uop.vmin.arg, rewritten_uop.vmax.arg
|
||||
|
||||
def NumNode(val): return UOp.const(dtypes.int, val)
|
||||
def Variable(expr, nmin, nmax):
|
||||
@@ -52,12 +54,13 @@ class TestSymbolicPickle(unittest.TestCase):
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
rendered, nmin, nmax = render(v)
|
||||
if isinstance(s, set):
|
||||
self.assertIn(render(v), s)
|
||||
self.assertIn(rendered, s)
|
||||
else:
|
||||
self.assertEqual(render(v), s)
|
||||
#self.assertEqual(v.min, n)
|
||||
#self.assertEqual(v.max, m)
|
||||
self.assertEqual(rendered, s)
|
||||
self.assertEqual(nmin, n)
|
||||
self.assertEqual(nmax, m)
|
||||
|
||||
def test_cmp_simple(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
|
||||
|
||||
Reference in New Issue
Block a user