fix copying padded const (#15116)

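* fix const padding on CPU: fold a COPY of a padded const by substituting the device instead of replacing it with a plain const, so the pad is preserved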

* remove comment
This commit is contained in:
George Hotz
2026-03-04 10:39:45 +08:00
committed by GitHub
parent b5ebb4d06d
commit 2d72a4a90c
3 changed files with 17 additions and 13 deletions

View File

@@ -30,8 +30,7 @@ class TestMovedConstFolding(unittest.TestCase):
def test_copy_padded_const(self):
schedule = Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").schedule()
assert not any(si.ast.op is Ops.COPY for si in schedule), "const copy should be folded"
# TODO: this is wrong, should be [0, 1, 1, 1, 1, 0]
np.testing.assert_equal(Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").numpy(), [1, 1, 1, 1, 1, 1])
np.testing.assert_equal(Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").numpy(), [0, 1, 1, 1, 1, 0])
def test_cast_padded(self):
# NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, FLOAT16
@@ -97,7 +97,6 @@ def contiguous_mops_to_view(c:UOp):
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag)
def transform_precompiled_call(c:UOp) -> UOp|None:
if not c.arg.precompile: return None
if c.src[0].op is Ops.SINK: return None
@@ -105,6 +104,7 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
fxn = out.param_like(len(c.src)-1).assign(c.src[0]).sink()
return out.after(c.replace(src=(fxn,)+tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in c.src[1:])+(out,), dtype=dtypes.void, tag=None))
# NOTE: adding rules to here is bad. these all need to run before the schedule cache
pm_early_transform_tensor_graph = PatternMatcher([
# transform precompiled CALLs
(UPat(Ops.CALL, name="c"), transform_precompiled_call),
@@ -126,15 +126,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
(UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
# replace CONTIGUOUS with ASSIGNs
(UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_assign),
# remove DETACH/CONTIGUOUS_BACKWARD
# remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal)
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# handle size 0
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
# early fixup const copy (TODO: is this wrong if there's a pad?)
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
])
def untag_and_append(ctx:AllocCtx, x:UOp):

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace, Invalid
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -96,6 +96,10 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
return c.src[0].substitute(dict_map, walk=True)
earliest_rewrites = mop_cleanup+PatternMatcher([
# early fixup const copy
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
@@ -136,6 +140,14 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
# ** size 0 **
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# handle size 0
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
])
# *****************