mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Fix load collapse MAX to ADD (#13406)
* add Ops.ADD to pattern * add test
This commit is contained in:
@@ -2699,6 +2699,9 @@ class TestOps(unittest.TestCase):
|
||||
a = Tensor(3.14)
|
||||
np.testing.assert_allclose(Tensor.stack(a, a).numpy(), Tensor([3.14, 3.14]).numpy())
|
||||
|
||||
def test_stack_max(self):
|
||||
helper_test_op(None, lambda x, y: torch.stack((x, y)).max(axis=0)[0], lambda x, y: Tensor.stack(x, y).max(axis=0), vals=[[1.], [2.]])
|
||||
|
||||
def test_repeat(self):
|
||||
x = Tensor.randn(4, 6, 3)
|
||||
base_repeats = [2, 4, 3]
|
||||
|
||||
@@ -142,7 +142,7 @@ pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([
|
||||
# remove REDUCE on load, comes from indexing a tensor with another tensor
|
||||
def no_load(u:UOp) -> bool: return not any(x.op is Ops.INDEX for x in u.backward_slice_with_self)
|
||||
pm_load_collapse = PatternMatcher([
|
||||
(UPat(Ops.REDUCE, src=(UPat.var("u"), UPat()), name="red"), reduce_load_collapse),
|
||||
(UPat(Ops.REDUCE, arg=Ops.ADD, src=(UPat.var("u"), UPat()), name="red"), reduce_load_collapse),
|
||||
# we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
|
||||
((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user