From 08c9d980dc70cb4c24327857f7ff5191fbf07d7e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 2 Jan 2025 19:05:09 +0200 Subject: [PATCH] use const_like in uop zero folding [pr] (#8470) --- test/test_schedule.py | 2 -- tinygrad/ops.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index a900995a5d..ff8d0c7258 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2054,14 +2054,12 @@ class TestBigGraph(unittest.TestCase): assert UPat(Ops.CONST, arg=0).match(sink, {}), f"expected {sink} to collapse to a const 0" assert sink.shape == a.shape - @unittest.expectedFailure def test_const_folding_ne(self): a = Tensor([1]) sink = tensor_rewrite(a != a) assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False" assert sink.shape == a.shape - @unittest.expectedFailure def test_const_folding_lt(self): a = Tensor([1]) sink = tensor_rewrite(a < a) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 14aa41f8e7..0dff5d3fce 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1192,9 +1192,9 @@ symbolic_simple = PatternMatcher([ (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), # ** zero folding ** - (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False + (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints), - lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints) + lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. # NOTE: this can be wrong for loaded NaN