mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
update test_fuzz_failure to not contruct full UOp (#13960)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Variable, dtypes
|
||||
from tinygrad import Variable
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
|
||||
|
||||
class TestFuzzFailure(unittest.TestCase):
|
||||
@@ -108,54 +107,9 @@ class TestFuzzFailure(unittest.TestCase):
|
||||
v1=Variable("v1", 0, 256)
|
||||
v2=Variable("v2", 0, 32)
|
||||
v3=Variable("v3", 0, 32)
|
||||
expr = UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MAX, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
|
||||
x5:=UOp(Ops.IDIV, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
|
||||
x9:=UOp(Ops.CONST, dtypes.int, arg=9, src=()),
|
||||
x10:=UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v1', 0, 256), src=()),)),
|
||||
x11:=UOp(Ops.CONST, dtypes.bool, arg=True, src=()),)),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
x10,
|
||||
x14:=UOp(Ops.CONST, dtypes.int, arg=-4, src=()),)),
|
||||
x14,)),
|
||||
UOp(Ops.IDIV, dtypes.int, arg=None, src=(
|
||||
x10,
|
||||
x9,)),)),
|
||||
x9,)),
|
||||
x14,)),
|
||||
x11,)),
|
||||
x5,
|
||||
UOp(Ops.IDIV, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MOD, dtypes.int, arg=None, src=(
|
||||
x19:=UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v2', 0, 32), src=()),
|
||||
UOp(Ops.CONST, dtypes.int, arg=3, src=()),)),
|
||||
x19,)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=5, src=()),)),)),
|
||||
x22:=UOp(Ops.CONST, dtypes.int, arg=-1, src=()),)),
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MOD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
x10,
|
||||
UOp(Ops.CONST, dtypes.int, arg=-2, src=()),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=6, src=()),)),
|
||||
UOp(Ops.MOD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.DEFINE_VAR, dtypes.int, arg=('v3', 0, 32), src=()),
|
||||
UOp(Ops.CONST, dtypes.int, arg=1, src=()),)),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=0, src=()),)),
|
||||
x22,)),)),
|
||||
x22,))
|
||||
v1_val, v2_val, v3_val = UOp.const(dtypes.int, 9), UOp.const(dtypes.int, 0),UOp.const(dtypes.int, 0)
|
||||
x5 = (v1 <= 9).where(v1 * -4 - 4, v1 // 9) // 9
|
||||
expr = ((x5 >= -4).where(x5, (v2 % 3 + v2) // 5) * -1).maximum(((v1 * -2) % 6 + v3 % 1) * -1) * -1
|
||||
v1_val, v2_val, v3_val = v1.const_like(9), v2.const_like(0), v3.const_like(0)
|
||||
num = expr.simplify().substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
|
||||
rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
|
||||
self.assertEqual(num, rn)
|
||||
@@ -164,30 +118,8 @@ class TestFuzzFailure(unittest.TestCase):
|
||||
v1=Variable("v1", 0, 16)
|
||||
v2=Variable("v2", 0, 128)
|
||||
v3=Variable("v3", 0, 5)
|
||||
expr = UOp(Ops.MOD, dtypes.index, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.index, arg=None, src=(
|
||||
UOp(Ops.MOD, dtypes.index, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.index, arg=None, src=(
|
||||
UOp(Ops.MAX, dtypes.index, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.index, arg=None, src=(
|
||||
x5:=UOp(Ops.DEFINE_VAR, dtypes.index, arg=('v2', 0, 128), src=()),
|
||||
UOp(Ops.CONST, dtypes.index, arg=0, src=()),)),
|
||||
UOp(Ops.CONST, dtypes.index, arg=8, src=()),)),
|
||||
UOp(Ops.MUL, dtypes.index, arg=None, src=(
|
||||
x5,
|
||||
UOp(Ops.CONST, dtypes.index, arg=-2, src=()),)),)),
|
||||
x10:=UOp(Ops.CONST, dtypes.index, arg=5, src=()),)),
|
||||
UOp(Ops.ADD, dtypes.index, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.index, arg=None, src=(
|
||||
UOp(Ops.IDIV, dtypes.index, arg=None, src=(
|
||||
x14:=UOp(Ops.DEFINE_VAR, dtypes.index, arg=('v1', 0, 16), src=()),
|
||||
UOp(Ops.CONST, dtypes.index, arg=6, src=()),)),
|
||||
UOp(Ops.CONST, dtypes.index, arg=4, src=()),)),
|
||||
UOp(Ops.ADD, dtypes.index, arg=None, src=(
|
||||
x14,
|
||||
UOp(Ops.CONST, dtypes.index, arg=1, src=()),)),)),)),
|
||||
x10,))
|
||||
v1_val, v2_val, v3_val = UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 7),UOp.const(dtypes.int, 0)
|
||||
expr = (((v2 * 0).maximum(8) - v2 * 2) % 5 + v1 // 6 + v1 + 5) % 5
|
||||
v1_val, v2_val, v3_val = v1.const_like(0), v2.const_like(7), v3.const_like(0)
|
||||
num = expr.simplify().substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
|
||||
rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
|
||||
self.assertEqual(num, rn)
|
||||
|
||||
Reference in New Issue
Block a user