diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index e9b3d9a524..c1c636f120 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1331,7 +1331,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] - helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["CUDA", "NV", "METAL", "AMD", "HIP"]) + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 9824766b09..739e88b5b1 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -64,6 +64,7 @@ class TestValidIdxSimplification(unittest.TestCase): "0", "(((lidx0+(gidx0*4))<19)!=True)") + @unittest.expectedFailure # FIXME def test_simplify_within_valid(self): ridx0 = Range(0, 4) ridx1 = Range(1, 4) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 4b15265e0e..92c2ad851a 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -4,7 +4,7 @@ import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat -from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid +from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -483,8 +483,8 @@ reducer = PatternMatcher([ (UPat(UOps.STORE, name="root"), delete_redundant_gates), # late fixup of unfoldable image loads (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), - # simplify valid - (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), + # # simplify valid + # (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), # image load valid idx simplification (UPat(UOps.LOAD, name="load"), simplify_image_load), # buffer load valid idx simplification