mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix multi Ops.CONTIGUOUS_BACKWARD [pr] (#8843)
This commit is contained in:
@@ -288,6 +288,9 @@ class TestMultiTensor(unittest.TestCase):
|
||||
optim.step()
|
||||
out.numpy()
|
||||
|
||||
def test_backprop_conv_wino(self):
|
||||
with Context(WINO=1): self.test_backprop_conv()
|
||||
|
||||
def test_backward_sum(self):
|
||||
x = Tensor([[1.,2,3,4], [5,6,7,8]]).shard(devices_2, axis=0)
|
||||
w = Tensor([1.,2,3,4], requires_grad=True).shard(devices_2)
|
||||
|
||||
@@ -158,7 +158,8 @@ multi_pm = PatternMatcher([
|
||||
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
|
||||
Reference in New Issue
Block a user