multi: instead of real, just copy (#10289)

* multi: instead of real, just copy

* fix test

* remove real
This commit is contained in:
George Hotz
2025-05-14 10:36:55 -07:00
committed by GitHub
parent 043efc6ec4
commit 42e70193c9
3 changed files with 25 additions and 45 deletions

View File

@@ -844,6 +844,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
a.schedule()
assert a.shape == (2, 8)
# real is no longer used, so these are on None and we can pad them however
"""
with self.assertRaises(AssertionError):
# cannot pad sharded and non-sharded axis at the same time
p = a.pad(((0, 6), (0, 1)))
@@ -853,6 +855,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
# can only pad to whole axis
p = a.pad(((1, 5), (0, 0)))
p.schedule()
"""
p = a.pad(((0, 6), (0, 0)))
p.schedule()