update test_fuzz_failure to not contruct full UOp (#13960)

This commit is contained in:
chenyu
2026-01-01 15:09:58 -05:00
committed by GitHub
parent 51398edf9c
commit cb7c76a3bd

View File

@@ -1,7 +1,6 @@
import unittest
from tinygrad import Variable, dtypes
from tinygrad import Variable
from tinygrad.helpers import Context
from tinygrad.uop.ops import Ops, UOp
class TestFuzzFailure(unittest.TestCase):
@@ -108,54 +107,9 @@ class TestFuzzFailure(unittest.TestCase):
v1=Variable("v1", 0, 256)
v2=Variable("v2", 0, 32)
v3=Variable("v3", 0, 32)
expr = UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.MAX, dtypes.int, arg=None, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x5:=UOp(Ops.IDIV, dtypes.int, arg=None, src=(
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x9:=UOp(Ops.CONST, dtypes.int, arg=9, src=()),
x10:=UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v1', 0, 256), src=()),)),
x11:=UOp(Ops.CONST, dtypes.bool, arg=True, src=()),)),
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
x10,
x14:=UOp(Ops.CONST, dtypes.int, arg=-4, src=()),)),
x14,)),
UOp(Ops.IDIV, dtypes.int, arg=None, src=(
x10,
x9,)),)),
x9,)),
x14,)),
x11,)),
x5,
UOp(Ops.IDIV, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.MOD, dtypes.int, arg=None, src=(
x19:=UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v2', 0, 32), src=()),
UOp(Ops.CONST, dtypes.int, arg=3, src=()),)),
x19,)),
UOp(Ops.CONST, dtypes.int, arg=5, src=()),)),)),
x22:=UOp(Ops.CONST, dtypes.int, arg=-1, src=()),)),
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.MOD, dtypes.int, arg=None, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
x10,
UOp(Ops.CONST, dtypes.int, arg=-2, src=()),)),
UOp(Ops.CONST, dtypes.int, arg=6, src=()),)),
UOp(Ops.MOD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v3', 0, 32), src=()),
UOp(Ops.CONST, dtypes.int, arg=1, src=()),)),)),
UOp(Ops.CONST, dtypes.int, arg=0, src=()),)),
x22,)),)),
x22,))
v1_val, v2_val, v3_val = UOp.const(dtypes.int, 9), UOp.const(dtypes.int, 0),UOp.const(dtypes.int, 0)
x5 = (v1 <= 9).where(v1 * -4 - 4, v1 // 9) // 9
expr = ((x5 >= -4).where(x5, (v2 % 3 + v2) // 5) * -1).maximum(((v1 * -2) % 6 + v3 % 1) * -1) * -1
v1_val, v2_val, v3_val = v1.const_like(9), v2.const_like(0), v3.const_like(0)
num = expr.simplify().substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
self.assertEqual(num, rn)
@@ -164,30 +118,8 @@ class TestFuzzFailure(unittest.TestCase):
v1=Variable("v1", 0, 16)
v2=Variable("v2", 0, 128)
v3=Variable("v3", 0, 5)
expr = UOp(Ops.MOD, dtypes.index, arg=None, src=(
UOp(Ops.ADD, dtypes.index, arg=None, src=(
UOp(Ops.MOD, dtypes.index, arg=None, src=(
UOp(Ops.ADD, dtypes.index, arg=None, src=(
UOp(Ops.MAX, dtypes.index, arg=None, src=(
UOp(Ops.MUL, dtypes.index, arg=None, src=(
x5:=UOp(Ops.DEFINE_VAR, dtypes.index, arg=('v2', 0, 128), src=()),
UOp(Ops.CONST, dtypes.index, arg=0, src=()),)),
UOp(Ops.CONST, dtypes.index, arg=8, src=()),)),
UOp(Ops.MUL, dtypes.index, arg=None, src=(
x5,
UOp(Ops.CONST, dtypes.index, arg=-2, src=()),)),)),
x10:=UOp(Ops.CONST, dtypes.index, arg=5, src=()),)),
UOp(Ops.ADD, dtypes.index, arg=None, src=(
UOp(Ops.ADD, dtypes.index, arg=None, src=(
UOp(Ops.IDIV, dtypes.index, arg=None, src=(
x14:=UOp(Ops.DEFINE_VAR, dtypes.index, arg=('v1', 0, 16), src=()),
UOp(Ops.CONST, dtypes.index, arg=6, src=()),)),
UOp(Ops.CONST, dtypes.index, arg=4, src=()),)),
UOp(Ops.ADD, dtypes.index, arg=None, src=(
x14,
UOp(Ops.CONST, dtypes.index, arg=1, src=()),)),)),)),
x10,))
v1_val, v2_val, v3_val = UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 7),UOp.const(dtypes.int, 0)
expr = (((v2 * 0).maximum(8) - v2 * 2) % 5 + v1 // 6 + v1 + 5) % 5
v1_val, v2_val, v3_val = v1.const_like(0), v2.const_like(7), v3.const_like(0)
num = expr.simplify().substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
self.assertEqual(num, rn)