mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix copying padded const (#15116)
* fix const padding cpu * remove comment
This commit is contained in:
@@ -30,8 +30,7 @@ class TestMovedConstFolding(unittest.TestCase):
|
||||
def test_copy_padded_const(self):
  """Copying a padded const to another device folds the COPY and keeps the padding zero."""
  schedule = Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").schedule()
  # no scheduled item should be a COPY: the const source is expected to be folded away
  assert not any(si.ast.op is Ops.COPY for si in schedule), "const copy should be folded"
  # the one-element pad on each side must remain 0; only the inner four elements are ones
  np.testing.assert_equal(Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").numpy(), [0, 1, 1, 1, 1, 0])
|
||||
|
||||
def test_cast_padded(self):
|
||||
# NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, FLOAT16
|
||||
|
||||
@@ -97,7 +97,6 @@ def contiguous_mops_to_view(c:UOp):
|
||||
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
|
||||
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag)
|
||||
|
||||
|
||||
def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
if not c.arg.precompile: return None
|
||||
if c.src[0].op is Ops.SINK: return None
|
||||
@@ -105,6 +104,7 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
fxn = out.param_like(len(c.src)-1).assign(c.src[0]).sink()
|
||||
return out.after(c.replace(src=(fxn,)+tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in c.src[1:])+(out,), dtype=dtypes.void, tag=None))
|
||||
|
||||
# NOTE: adding rules here is bad; they all need to run before the schedule cache
|
||||
pm_early_transform_tensor_graph = PatternMatcher([
|
||||
# transform precompiled CALLs
|
||||
(UPat(Ops.CALL, name="c"), transform_precompiled_call),
|
||||
@@ -126,15 +126,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
|
||||
(UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
|
||||
# replace CONTIGUOUS with ASSIGNs
|
||||
(UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_assign),
|
||||
# remove DETACH/CONTIGUOUS_BACKWARD
|
||||
# remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal)
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
# reduce of size 0 is the identity element
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
# handle size 0
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
|
||||
# early fixup const copy (TODO: is this wrong if there's a pad?)
|
||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
|
||||
])
|
||||
|
||||
def untag_and_append(ctx:AllocCtx, x:UOp):
|
||||
|
||||
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace, Invalid
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
@@ -96,6 +96,10 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||
return c.src[0].substitute(dict_map, walk=True)
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# early fixup const copy
|
||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
|
||||
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
|
||||
|
||||
# resolve calls
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
@@ -136,6 +140,14 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
|
||||
# make source contiguous if it has hazardous movement ops on the dest buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||
|
||||
# ** size 0 **
|
||||
|
||||
# reduce of size 0 is the identity element
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
# handle size 0
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
|
||||
])
|
||||
|
||||
# *****************
|
||||
|
||||
Reference in New Issue
Block a user