From 431a86615d4afacb66f94864b237e97ffade3731 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 1 Feb 2025 09:21:31 +0800 Subject: [PATCH] fix multi Ops.CONTIGUOUS_BACKWARD [pr] (#8843) --- test/test_multitensor.py | 3 +++ tinygrad/engine/multi.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 08fd62505f..896d24b80b 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -288,6 +288,9 @@ class TestMultiTensor(unittest.TestCase): optim.step() out.numpy() + def test_backprop_conv_wino(self): + with Context(WINO=1): self.test_backprop_conv() + def test_backward_sum(self): x = Tensor([[1.,2,3,4], [5,6,7,8]]).shard(devices_2, axis=0) w = Tensor([1.,2,3,4], requires_grad=True).shard(devices_2) diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py index 9234247592..4803127690 100644 --- a/tinygrad/engine/multi.py +++ b/tinygrad/engine/multi.py @@ -158,7 +158,8 @@ multi_pm = PatternMatcher([ (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi), (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi), (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi), - (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), + (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), + src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), ]) @track_rewrites(named=True)