push copy to disk (#12348)

This commit is contained in:
wozeparrot
2025-09-29 21:55:05 -07:00
committed by GitHub
parent 881709cd33
commit 2a0caa09c2
3 changed files with 19 additions and 4 deletions

View File

@@ -134,8 +134,7 @@ class GPT2:
transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
for k in weights:
if k.endswith(transposed):
# TODO: it should not silently break without that .to(None)
weights[k] = weights[k].to(None).T
weights[k] = weights[k].T
# lm head and wte are tied
weights['lm_head.weight'] = weights['wte.weight']

View File

@@ -43,6 +43,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
return sched
def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn)
def expect_nonrangeify_fails(fxn): return (unittest.expectedFailure if not RANGEIFY else (lambda f:f))(fxn)
def _realize_weights(m):
for p in nn.state.get_parameters(m): p.realize()
@@ -2280,7 +2281,6 @@ class TestCopyFolding(unittest.TestCase):
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
@expect_rangeify_fails
def test_permute_on_disk(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
@@ -2288,6 +2288,14 @@ class TestCopyFolding(unittest.TestCase):
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
@expect_nonrangeify_fails
def test_permute_on_disk_contiguous(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_after_shrink(self):
a = Tensor.arange(5)
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
@@ -2296,7 +2304,7 @@ class TestCopyFolding(unittest.TestCase):
# NOTE: disk permute must come after COPY
# TODO: this is wrong because of the permute
@unittest.expectedFailure
@expect_nonrangeify_fails
def test_permute_after_shrink_on_disk(self):
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")

View File

@@ -35,6 +35,14 @@ earliest_rewrites = double_reshape+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# remove contiguous on movement ops before a copy on disk
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
# push copy past movement ops to disk
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.COPY, allow_any_len=True, name="copy"),
lambda x,copy: x.replace(src=(copy.replace(src=(x.src[0],)+copy.src[1:], tag=None),)+x.src[1:], tag=copy.tag) \
if isinstance(x.device, str) and x.device.startswith("DISK") else None),
# COPY and source size need to match
# TODO: expand after copy creates issues with tagging
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),