mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
push copy to disk (#12348)
This commit is contained in:
@@ -134,8 +134,7 @@ class GPT2:
|
||||
transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
|
||||
for k in weights:
|
||||
if k.endswith(transposed):
|
||||
# TODO: it should not silently break without that .to(None)
|
||||
weights[k] = weights[k].to(None).T
|
||||
weights[k] = weights[k].T
|
||||
# lm head and wte are tied
|
||||
weights['lm_head.weight'] = weights['wte.weight']
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
|
||||
return sched
|
||||
|
||||
def expect_rangeify_fails(fxn):
  """Mark `fxn` with unittest.expectedFailure when RANGEIFY is truthy; otherwise return `fxn` unchanged."""
  if RANGEIFY: return unittest.expectedFailure(fxn)
  return fxn
|
||||
def expect_nonrangeify_fails(fxn):
  """Return `fxn` unchanged when RANGEIFY is truthy; otherwise mark it with unittest.expectedFailure."""
  return fxn if RANGEIFY else unittest.expectedFailure(fxn)
|
||||
|
||||
def _realize_weights(m):
|
||||
for p in nn.state.get_parameters(m): p.realize()
|
||||
@@ -2280,7 +2281,6 @@ class TestCopyFolding(unittest.TestCase):
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
@expect_rangeify_fails
|
||||
def test_permute_on_disk(self):
|
||||
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
|
||||
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
|
||||
@@ -2288,6 +2288,14 @@ class TestCopyFolding(unittest.TestCase):
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
@expect_nonrangeify_fails
def test_permute_on_disk_contiguous(self):
  # dump the realized arange(4) buffer into a temp file that backs the disk tensor below
  with open(temp('dt_arange_4_permute'), "wb") as fh:
    fh.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
  on_disk = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
  # reshape + permute + contiguous are applied on the disk tensor before the copy to CPU
  on_cpu = on_disk.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
  on_cpu.realize()
  # the 2x2 view of [0, 1, 2, 3] transposed
  self.assertListEqual(on_cpu.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
def test_permute_after_shrink(self):
|
||||
a = Tensor.arange(5)
|
||||
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
|
||||
@@ -2296,7 +2304,7 @@ class TestCopyFolding(unittest.TestCase):
|
||||
|
||||
# NOTE: disk permute must come after COPY
|
||||
# TODO: this is wrong because of the permute
|
||||
@unittest.expectedFailure
|
||||
@expect_nonrangeify_fails
|
||||
def test_permute_after_shrink_on_disk(self):
|
||||
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer())
|
||||
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
|
||||
|
||||
@@ -35,6 +35,14 @@ earliest_rewrites = double_reshape+PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
|
||||
# remove contiguous on movement ops before a copy on disk
|
||||
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
|
||||
lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
|
||||
# push copy past movement ops to disk
|
||||
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.COPY, allow_any_len=True, name="copy"),
|
||||
lambda x,copy: x.replace(src=(copy.replace(src=(x.src[0],)+copy.src[1:], tag=None),)+x.src[1:], tag=copy.tag) \
|
||||
if isinstance(x.device, str) and x.device.startswith("DISK") else None),
|
||||
|
||||
# COPY and source size need to match
|
||||
# TODO: expand after copy creates issues with tagging
|
||||
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
|
||||
|
||||
Reference in New Issue
Block a user