mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
assign fusion tests on detach and contiguous_backward (#15092)
This commit is contained in:
@@ -234,6 +234,18 @@ class TestSchedule(unittest.TestCase):
|
||||
d = Tensor.empty(1).assign(c)
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_detach_assign(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
buf1, buf2 = Tensor.empty(4, 4).contiguous(), Tensor.empty(4, 4).contiguous()
|
||||
r = buf2.assign(buf1.assign(a + 1.0) * 2.0)
|
||||
check_schedule(r.detach().contiguous(), 2)
|
||||
|
||||
def test_contiguous_backward_assign(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
buf1, buf2 = Tensor.empty(4, 4).contiguous(), Tensor.empty(4, 4).contiguous()
|
||||
r = buf2.assign(buf1.assign(a + 1.0) * 2.0)
|
||||
check_schedule(r.contiguous_backward().contiguous(), 2)
|
||||
|
||||
def test_mulacc_relu_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
|
||||
@@ -70,6 +70,14 @@ class TestFunction(unittest.TestCase):
|
||||
b = Tensor([4,5,6])
|
||||
np.testing.assert_equal(f(a, b).numpy(), [5,7,9])
|
||||
|
||||
def test_contiguous_backward(self):
|
||||
@function
|
||||
def f(a:Tensor, b:Tensor) -> Tensor: return (a + b).contiguous_backward()
|
||||
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
np.testing.assert_equal(f(a, b).numpy(), [5,7,9])
|
||||
|
||||
def test_method(self):
|
||||
class Foo:
|
||||
def __init__(self): self.w = Tensor([10,20,30])
|
||||
|
||||
Reference in New Issue
Block a user