From 2d72a4a90caa564bc03f07229ace8dff204b6479 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 4 Mar 2026 10:39:45 +0800
Subject: [PATCH] fix copying padded const (#15116)

* fix const padding cpu

* remove comment
---
 test/backend/test_const_folding.py |  3 +--
 tinygrad/engine/allocations.py     | 13 +++----------
 tinygrad/schedule/rangeify.py      | 14 +++++++++++++-
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/test/backend/test_const_folding.py b/test/backend/test_const_folding.py
index b9cffbb528..e6267d1b72 100644
--- a/test/backend/test_const_folding.py
+++ b/test/backend/test_const_folding.py
@@ -30,8 +30,7 @@ class TestMovedConstFolding(unittest.TestCase):
   def test_copy_padded_const(self):
     schedule = Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").schedule()
     assert not any(si.ast.op is Ops.COPY for si in schedule), "const copy should be folded"
-    # TODO: this is wrong, should be [0, 1, 1, 1, 1, 0]
-    np.testing.assert_equal(Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").numpy(), [1, 1, 1, 1, 1, 1])
+    np.testing.assert_equal(Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").numpy(), [0, 1, 1, 1, 1, 0])

   def test_cast_padded(self):
     # NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count
diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py
index 20f20f4f32..3125dc6e81 100644
--- a/tinygrad/engine/allocations.py
+++ b/tinygrad/engine/allocations.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
+from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
 from tinygrad.dtype import dtypes, ImageDType
 from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, FLOAT16

@@ -97,7 +97,6 @@ def contiguous_mops_to_view(c:UOp):
   # NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
   return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag)

-
 def transform_precompiled_call(c:UOp) -> UOp|None:
   if not c.arg.precompile: return None
   if c.src[0].op is Ops.SINK: return None
@@ -105,6 +104,7 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
   fxn = out.param_like(len(c.src)-1).assign(c.src[0]).sink()
   return out.after(c.replace(src=(fxn,)+tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in c.src[1:])+(out,), dtype=dtypes.void, tag=None))

+# NOTE: adding rules to here is bad. these all need to run before the schedule cache
 pm_early_transform_tensor_graph = PatternMatcher([
   # transform precompiled CALLs
   (UPat(Ops.CALL, name="c"), transform_precompiled_call),
@@ -126,15 +126,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
   (UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
   # replace CONTIGUOUS with ASSIGNs
   (UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_assign),
-  # remove DETACH/CONTIGUOUS_BACKWARD
+  # remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal)
   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
-  # reduce of size 0 is the identity element
-  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
-   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
-  # handle size 0
-  (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
-  # early fixup const copy (TODO: is this wrong if there's a pad?)
-  (UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
 ])

 def untag_and_append(ctx:AllocCtx, x:UOp):
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 348a78cf4c..5b839b14df 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace, Invalid
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
-from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call
+from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
 from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -96,6 +96,10 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
   return c.src[0].substitute(dict_map, walk=True)

 earliest_rewrites = mop_cleanup+PatternMatcher([
+  # early fixup const copy
+  (UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
+   lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
+
   # resolve calls
   (UPat(Ops.CALL, name="c"), resolve_call),

@@ -136,6 +140,14 @@ earliest_rewrites = mop_cleanup+PatternMatcher([

   # make source contiguous if it has hazardous movement ops on the dest buffer
   (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
+
+  # ** size 0 **
+
+  # reduce of size 0 is the identity element
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
+   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+  # handle size 0
+  (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
 ])

 # *****************
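
--
For reference, a minimal repro of the behavior this patch fixes, mirroring the updated test above (same Tensor API and device names as in test_copy_padded_const; the assertion on the old behavior is an illustration, not part of the patch):

    from tinygrad import Tensor
    import numpy as np

    # a padded const copied to another device: the COPY kernel is folded away,
    # and after this fix the zero padding survives the fold
    t = Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1")
    np.testing.assert_equal(t.numpy(), [0, 1, 1, 1, 1, 0])  # before #15116 this was [1, 1, 1, 1, 1, 1]

The old rule in allocations.py rewrote the whole COPY to a const and lost the PAD; the replacement rule in rangeify.py instead substitutes the destination DEVICE into the const's graph, so the movement ops above the CONST (including the pad) are preserved.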