fix multi Ops.CONTIGUOUS_BACKWARD [pr] (#8843)

This commit is contained in:
George Hotz
2025-02-01 09:21:31 +08:00
committed by GitHub
parent 07d3676019
commit 431a86615d
2 changed files with 5 additions and 1 deletions

View File

@@ -288,6 +288,9 @@ class TestMultiTensor(unittest.TestCase):
optim.step()
out.numpy()
def test_backprop_conv_wino(self):
with Context(WINO=1): self.test_backprop_conv()
def test_backward_sum(self):
x = Tensor([[1.,2,3,4], [5,6,7,8]]).shard(devices_2, axis=0)
w = Tensor([1.,2,3,4], requires_grad=True).shard(devices_2)

View File

@@ -158,7 +158,8 @@ multi_pm = PatternMatcher([
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
])
@track_rewrites(named=True)