From cb7c76a3bd99b2f4152a7c1cc47cd04a893de7b4 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 1 Jan 2026 15:09:58 -0500 Subject: [PATCH] update test_fuzz_failure to not contruct full UOp (#13960) --- test/unit/test_symbolic_failures.py | 80 +++-------------------------- 1 file changed, 6 insertions(+), 74 deletions(-) diff --git a/test/unit/test_symbolic_failures.py b/test/unit/test_symbolic_failures.py index ce8a2016b3..8587bf2659 100644 --- a/test/unit/test_symbolic_failures.py +++ b/test/unit/test_symbolic_failures.py @@ -1,7 +1,6 @@ import unittest -from tinygrad import Variable, dtypes +from tinygrad import Variable from tinygrad.helpers import Context -from tinygrad.uop.ops import Ops, UOp class TestFuzzFailure(unittest.TestCase): @@ -108,54 +107,9 @@ class TestFuzzFailure(unittest.TestCase): v1=Variable("v1", 0, 256) v2=Variable("v2", 0, 32) v3=Variable("v3", 0, 32) - expr = UOp(Ops.MUL, dtypes.int, arg=None, src=( - UOp(Ops.MAX, dtypes.int, arg=None, src=( - UOp(Ops.MUL, dtypes.int, arg=None, src=( - UOp(Ops.WHERE, dtypes.int, arg=None, src=( - UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( - UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( - x5:=UOp(Ops.IDIV, dtypes.int, arg=None, src=( - UOp(Ops.WHERE, dtypes.int, arg=None, src=( - UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( - UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( - x9:=UOp(Ops.CONST, dtypes.int, arg=9, src=()), - x10:=UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v1', 0, 256), src=()),)), - x11:=UOp(Ops.CONST, dtypes.bool, arg=True, src=()),)), - UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.MUL, dtypes.int, arg=None, src=( - x10, - x14:=UOp(Ops.CONST, dtypes.int, arg=-4, src=()),)), - x14,)), - UOp(Ops.IDIV, dtypes.int, arg=None, src=( - x10, - x9,)),)), - x9,)), - x14,)), - x11,)), - x5, - UOp(Ops.IDIV, dtypes.int, arg=None, src=( - UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.MOD, dtypes.int, arg=None, src=( - x19:=UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v2', 0, 32), src=()), - UOp(Ops.CONST, dtypes.int, arg=3, src=()),)), - x19,)), - UOp(Ops.CONST, dtypes.int, arg=5, src=()),)),)), - x22:=UOp(Ops.CONST, dtypes.int, arg=-1, src=()),)), - UOp(Ops.MUL, dtypes.int, arg=None, src=( - UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.MOD, dtypes.int, arg=None, src=( - UOp(Ops.MUL, dtypes.int, arg=None, src=( - x10, - UOp(Ops.CONST, dtypes.int, arg=-2, src=()),)), - UOp(Ops.CONST, dtypes.int, arg=6, src=()),)), - UOp(Ops.MOD, dtypes.int, arg=None, src=( - UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v3', 0, 32), src=()), - UOp(Ops.CONST, dtypes.int, arg=1, src=()),)),)), - UOp(Ops.CONST, dtypes.int, arg=0, src=()),)), - x22,)),)), - x22,)) - v1_val, v2_val, v3_val = UOp.const(dtypes.int, 9), UOp.const(dtypes.int, 0),UOp.const(dtypes.int, 0) + x5 = (v1 <= 9).where(v1 * -4 - 4, v1 // 9) // 9 + expr = ((x5 >= -4).where(x5, (v2 % 3 + v2) // 5) * -1).maximum(((v1 * -2) % 6 + v3 % 1) * -1) * -1 + v1_val, v2_val, v3_val = v1.const_like(9), v2.const_like(0), v3.const_like(0) num = expr.simplify().substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify() rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify() self.assertEqual(num, rn) @@ -164,30 +118,8 @@ class TestFuzzFailure(unittest.TestCase): v1=Variable("v1", 0, 16) v2=Variable("v2", 0, 128) v3=Variable("v3", 0, 5) - expr = UOp(Ops.MOD, dtypes.index, arg=None, src=( - UOp(Ops.ADD, dtypes.index, arg=None, src=( - UOp(Ops.MOD, dtypes.index, arg=None, src=( - UOp(Ops.ADD, dtypes.index, arg=None, src=( - UOp(Ops.MAX, dtypes.index, arg=None, src=( - UOp(Ops.MUL, dtypes.index, arg=None, src=( - x5:=UOp(Ops.DEFINE_VAR, dtypes.index, arg=('v2', 0, 128), src=()), - UOp(Ops.CONST, dtypes.index, arg=0, src=()),)), - UOp(Ops.CONST, dtypes.index, arg=8, src=()),)), - UOp(Ops.MUL, dtypes.index, arg=None, src=( - x5, - UOp(Ops.CONST, dtypes.index, arg=-2, src=()),)),)), - x10:=UOp(Ops.CONST, dtypes.index, arg=5, src=()),)), - UOp(Ops.ADD, dtypes.index, arg=None, src=( - UOp(Ops.ADD, dtypes.index, arg=None, src=( - UOp(Ops.IDIV, dtypes.index, arg=None, src=( - x14:=UOp(Ops.DEFINE_VAR, dtypes.index, arg=('v1', 0, 16), src=()), - UOp(Ops.CONST, dtypes.index, arg=6, src=()),)), - UOp(Ops.CONST, dtypes.index, arg=4, src=()),)), - UOp(Ops.ADD, dtypes.index, arg=None, src=( - x14, - UOp(Ops.CONST, dtypes.index, arg=1, src=()),)),)),)), - x10,)) - v1_val, v2_val, v3_val = UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 7),UOp.const(dtypes.int, 0) + expr = (((v2 * 0).maximum(8) - v2 * 2) % 5 + v1 // 6 + v1 + 5) % 5 + v1_val, v2_val, v3_val = v1.const_like(0), v2.const_like(7), v3.const_like(0) num = expr.simplify().substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify() rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify() self.assertEqual(num, rn)