# Patch summary. This free text sits ahead of the first "diff --git" header,
# where patch tools skip over it.
#
# examples/gpt2.py
#   Drops the `.to(None)` call, together with the TODO that explained it, from
#   the transpose applied to every weight whose key ends with one of the names
#   in `transposed`.
#
# test/test_schedule.py
#   * Adds expect_nonrangeify_fails: applies unittest.expectedFailure when
#     RANGEIFY is falsy and returns the test unchanged otherwise (the inverse
#     condition of the existing expect_rangeify_fails).
#   * test_permute_on_disk loses its @expect_rangeify_fails marker.
#   * Adds test_permute_on_disk_contiguous (reshape -> permute -> contiguous ->
#     to("CPU") on a disk tensor), marked @expect_nonrangeify_fails.
#   * test_permute_after_shrink_on_disk switches from an unconditional
#     @unittest.expectedFailure to @expect_nonrangeify_fails.
#
# tinygrad/schedule/rangeify.py
#   Adds two rules to earliest_rewrites. Both match a movement op other than
#   SHRINK/RESHAPE (named "x") and only fire when x.device is a str starting
#   with "DISK"; otherwise they return None.
#   1. x -> CONTIGUOUS -> COPY: the COPY is rebuilt with x as its first source,
#      so the CONTIGUOUS in between is dropped.
#   2. x -> COPY: the COPY is rebuilt to read x.src[0] with tag=None, and x is
#      rebuilt on top of that COPY carrying copy.tag, so the movement op is
#      applied after the copy.
#
# NOTE(review): leading whitespace inside the hunks below was reconstructed
#   assuming 2-space indentation; blank context lines were inferred from the
#   @@ line counts -- confirm against the target tree before applying.
# NOTE(review): test_permute_on_disk_contiguous writes the same temp file name
#   ('dt_arange_4_permute') as test_permute_on_disk -- confirm that sharing it
#   is intended.
# NOTE(review): the retained "TODO: this is wrong because of the permute"
#   comment above test_permute_after_shrink_on_disk may be stale now that the
#   test is only expected to fail without RANGEIFY -- confirm.
diff --git a/examples/gpt2.py b/examples/gpt2.py
index 8f1a2836af..7b508c1b3a 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -134,8 +134,7 @@ class GPT2:
     transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
     for k in weights:
       if k.endswith(transposed):
-        # TODO: it should not silently break without that .to(None)
-        weights[k] = weights[k].to(None).T
+        weights[k] = weights[k].T
     # lm head and wte are tied
     weights['lm_head.weight'] = weights['wte.weight']
 
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 2c29782105..ffcca8470c 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -43,6 +43,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
   return sched
 
 def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn)
+def expect_nonrangeify_fails(fxn): return (unittest.expectedFailure if not RANGEIFY else (lambda f:f))(fxn)
 
 def _realize_weights(m):
   for p in nn.state.get_parameters(m): p.realize()
@@ -2280,7 +2281,6 @@ class TestCopyFolding(unittest.TestCase):
     b.realize()
     self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
 
-  @expect_rangeify_fails
   def test_permute_on_disk(self):
     with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
     a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
@@ -2288,6 +2288,14 @@ class TestCopyFolding(unittest.TestCase):
     b.realize()
     self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
 
+  @expect_nonrangeify_fails
+  def test_permute_on_disk_contiguous(self):
+    with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
+    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
+    b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
+    b.realize()
+    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
+
   def test_permute_after_shrink(self):
     a = Tensor.arange(5)
     b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
@@ -2296,7 +2304,7 @@ class TestCopyFolding(unittest.TestCase):
 
   # NOTE: disk permute must come after COPY
   # TODO: this is wrong because of the permute
-  @unittest.expectedFailure
+  @expect_nonrangeify_fails
   def test_permute_after_shrink_on_disk(self):
     with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer())
     a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index c77fbf7d37..305d0fa8a8 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -35,6 +35,14 @@ earliest_rewrites = double_reshape+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
 
+  # remove contiguous on movement ops before a copy on disk
+  (UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
+   lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
+  # push copy past movement ops to disk
+  (UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.COPY, allow_any_len=True, name="copy"),
+   lambda x,copy: x.replace(src=(copy.replace(src=(x.src[0],)+copy.src[1:], tag=None),)+x.src[1:], tag=copy.tag) \
+     if isinstance(x.device, str) and x.device.startswith("DISK") else None),
+
   # COPY and source size need to match
   # TODO: expand after copy creates issues with tagging
   (UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),