disable simplify_valid (#7312)

fixed test_failure_55. will reenable it later after fixing the bug
This commit is contained in:
chenyu
2024-10-26 12:42:48 -04:00
committed by GitHub
parent 293714610a
commit 0a4d01f6d4
3 changed files with 5 additions and 4 deletions

View File

@@ -1331,7 +1331,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["CUDA", "NV", "METAL", "AMD", "HIP"]) helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -64,6 +64,7 @@ class TestValidIdxSimplification(unittest.TestCase):
"0", "0",
"(((lidx0+(gidx0*4))<19)!=True)") "(((lidx0+(gidx0*4))<19)!=True)")
@unittest.expectedFailure # FIXME
def test_simplify_within_valid(self): def test_simplify_within_valid(self):
ridx0 = Range(0, 4) ridx0 = Range(0, 4)
ridx1 = Range(1, 4) ridx1 = Range(1, 4)

View File

@@ -4,7 +4,7 @@ import functools, itertools, operator
from collections import defaultdict from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType from tinygrad.dtype import dtypes, PtrDType, ImageDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -483,8 +483,8 @@ reducer = PatternMatcher([
(UPat(UOps.STORE, name="root"), delete_redundant_gates), (UPat(UOps.STORE, name="root"), delete_redundant_gates),
# late fixup of unfoldable image loads # late fixup of unfoldable image loads
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
# simplify valid # # simplify valid
(UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), # (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
# image load valid idx simplification # image load valid idx simplification
(UPat(UOps.LOAD, name="load"), simplify_image_load), (UPat(UOps.LOAD, name="load"), simplify_image_load),
# buffer load valid idx simplification # buffer load valid idx simplification