Revert "big sink is on base (#14819)" (#14825)

This reverts commit 5fc3d8109f.
This commit is contained in:
George Hotz
2026-02-17 19:18:06 +08:00
committed by GitHub
parent f8e485ee9e
commit ff60dab622
2 changed files with 6 additions and 6 deletions

View File

@@ -752,7 +752,7 @@ class TestMultiTensor(unittest.TestCase):
# test no left join
with self.assertRaises((AssertionError, ValueError)):
t0.reshape((26*15,7)).contiguous().schedule()
t0.reshape((26*15,7)).schedule()
@unittest.skip("no longer supports uneven shard")
def test_reshape_on_axis_uneven(self):
@@ -982,18 +982,18 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
with self.assertRaises(AssertionError):
# sharded axis shrink on non-device boundry is not allowed
a = t.shrink(((0, 3), (0, 8)))
a.contiguous().schedule()
a.schedule()
with self.assertRaises(AssertionError):
# cannot shrink sharded and non-sharded axis at the same time
a = t.shrink(((0, 2), (2, 4)))
a.contiguous().schedule()
a.schedule()
a = t.shrink(((0, 2), (0, 8)))
a.contiguous().schedule()
a.schedule()
assert a.shape == (2, 8)
p = a.pad(((0, 6), (0, 0)))
p.contiguous().schedule()
p.schedule()
assert p.shape == (8, 8)
@given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))

View File

@@ -255,7 +255,7 @@ class Tensor(OpMixin):
NOTE: A Tensor can only be scheduled once.
"""
big_sink = UOp.sink(*[x.uop.base for x in (self,)+lst])
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
# this is where the schedule cache should go
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)