From 14d1c5fdfd332183282f82d4cf5fadea1102db6f Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 2 Mar 2026 15:21:51 -0500 Subject: [PATCH] assign fusion tests on detach and contiguous_backward (#15092) --- test/null/test_schedule.py | 12 ++++++++++++ test/unit/test_function.py | 8 ++++++++ 2 files changed, 20 insertions(+) diff --git a/test/null/test_schedule.py b/test/null/test_schedule.py index 1a14c1d646..f5f63e19da 100644 --- a/test/null/test_schedule.py +++ b/test/null/test_schedule.py @@ -234,6 +234,18 @@ class TestSchedule(unittest.TestCase): d = Tensor.empty(1).assign(c) check_schedule(d, 1) + def test_detach_assign(self): + a = Tensor.ones(4, 4).contiguous().realize() + buf1, buf2 = Tensor.empty(4, 4).contiguous(), Tensor.empty(4, 4).contiguous() + r = buf2.assign(buf1.assign(a + 1.0) * 2.0) + check_schedule(r.detach().contiguous(), 2) + + def test_contiguous_backward_assign(self): + a = Tensor.ones(4, 4).contiguous().realize() + buf1, buf2 = Tensor.empty(4, 4).contiguous(), Tensor.empty(4, 4).contiguous() + r = buf2.assign(buf1.assign(a + 1.0) * 2.0) + check_schedule(r.contiguous_backward().contiguous(), 2) + def test_mulacc_relu_fusion(self): a = Tensor.empty(10) b = Tensor.empty(10) diff --git a/test/unit/test_function.py b/test/unit/test_function.py index e3cf7335de..216c5763fc 100644 --- a/test/unit/test_function.py +++ b/test/unit/test_function.py @@ -70,6 +70,14 @@ class TestFunction(unittest.TestCase): b = Tensor([4,5,6]) np.testing.assert_equal(f(a, b).numpy(), [5,7,9]) + def test_contiguous_backward(self): + @function + def f(a:Tensor, b:Tensor) -> Tensor: return (a + b).contiguous_backward() + + a = Tensor([1,2,3]) + b = Tensor([4,5,6]) + np.testing.assert_equal(f(a, b).numpy(), [5,7,9]) + def test_method(self): class Foo: def __init__(self): self.w = Tensor([10,20,30])