make contiguous allocations store+after (#15280)

This commit is contained in:
chenyu
2026-03-15 11:58:40 -04:00
committed by GitHub
parent 7b6211fdd7
commit cd14e8e64b
5 changed files with 49 additions and 31 deletions

View File

@@ -280,6 +280,11 @@ class TestAssign(unittest.TestCase):
t.uop = t.uop.after(t[:5].uop.assign(Tensor.ones(5).uop))
np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,0.,0.,0.,0.,0.])
def test_assign_after_target_chain(self):
t = Tensor.arange(16).reshape(4, 4).permute(1, 0).contiguous()
t.assign(t + 100)
np.testing.assert_equal(t.numpy(), [[100, 104, 108, 112], [101, 105, 109, 113], [102, 106, 110, 114], [103, 107, 111, 115]])
def test_assign_contiguous(self):
b = Tensor.arange(16).reshape(4,4).contiguous().realize()
a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)

View File

@@ -46,18 +46,19 @@ def _buffer_like(u:UOp) -> UOp:
if prod(dtype.shape) != prod(u.max_shard_shape) or ([x for x in u.max_shard_shape if x != 1] or [1])[-1] % 4 != 0:
if DEBUG >= 1: print(f"demoting Image {dtype} with shape {u.max_shard_shape}")
dtype = dtype.base
buffer = UOp.new_buffer(u.device, u.shard_size, dtype).reshape(u.max_shard_shape)
buffer = UOp.new_buffer(u.device, u.shard_size, dtype).reshape(u.max_shard_shape).shrink_to(u.shard_shape)
if isinstance(u.device, tuple) and u.axis is not None: buffer = buffer.multi(u.axis)
return buffer
def replace_contig_with_assign(u:UOp):
def replace_contig_with_store_after(u:UOp):
# can't allocate a buffer without a device (e.g., inside a CALL function body with only PARAMs)
if u._device is None: return None
# if size is 0, remove the contig
if u.size == 0: return u.src[0]
# no real contig for DISK/TINYFS tensors, they are left alone
if isinstance(u._device, str) and u._device.startswith(("DISK", "TINYFS")): return u.rtag(None)
return _buffer_like(u).assign(u.src[0]).rtag(u.tag)
buf = _buffer_like(u)
return buf.after(buf.store(u.src[0])).rtag(u.tag)
def replace_assign_with_contig(u:UOp):
assigned_to = u
@@ -111,14 +112,15 @@ pm_early_transform_tensor_graph = PatternMatcher([
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
# add CONTIGUOUS to tagged UOps
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
# remove extra CONTIGUOUS on ASSIGN (only when assign target is contiguous)
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"),
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER, Ops.STORE}, name="x"),
lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
# remove extra CONTIGUOUS on ASSIGN/AFTER (only when target is contiguous)
(UPat(Ops.CONTIGUOUS, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="a"),), name="c"),
lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None),
# replace ASSIGN with CONTIGUOUS
(UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
# replace CONTIGUOUS with ASSIGNs
(UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_assign),
# replace CONTIGUOUS with STORE+AFTER
(UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_store_after),
# remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal)
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
])
@@ -129,9 +131,9 @@ def untag_and_append(ctx:AllocCtx, x:UOp):
for t in x.tag:
original_uop: UOp = ctx.uop_list[t]
replace_uop = ret
while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0]
while replace_uop.op in {Ops.ASSIGN, Ops.AFTER}: replace_uop = replace_uop.src[0]
ctx.buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
ctx.assigns.append(ret)
if ret.op is not Ops.AFTER: ctx.assigns.append(ret) # AFTER gets appended by append_after
return ret
def append_after(ctx:AllocCtx, x:UOp):
@@ -143,7 +145,7 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp):
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None)
pm_finalize_call = PatternMatcher([
(UPat(Ops.ASSIGN, name="x"), untag_and_append),
(UPat({Ops.ASSIGN, Ops.AFTER}, name="x"), untag_and_append),
(UPat(Ops.AFTER, name="x"), append_after),
(UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None),
# remove unique from const. TODO: this is copied in function.py

View File

@@ -7,7 +7,7 @@ from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL}
@@ -25,6 +25,10 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
if buf.base in x.backward_slice_with_self: ctx[x] = None
def unrealize_store_src(ctx:dict[UOp, None], x:UOp):
"""Don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER — bufferize_to_store handles them."""
if x in ctx: del ctx[x]
pm_generate_realize_map = PatternMatcher([
# always realize
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.ASSIGN}, name="tr"), realize),
@@ -34,6 +38,8 @@ pm_generate_realize_map = PatternMatcher([
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# sometimes realize src of assign
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
# don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER (like realize_assign_src for ASSIGN)
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat({Ops.COPY, Ops.BUFFER_VIEW}, name="x"))))), unrealize_store_src),
])
@dataclass(frozen=True)
@@ -60,13 +66,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
new_srcs = []
for s in x.src:
new_src = s
# TODO: this STORE+AFTER is very explicit, AFTER is the one being realized, and STORE needs to end ranges
if x.op is Ops.AFTER and s.op is Ops.STORE and x in ctx.realize_map:
realized_ranges = ctx.realize_map[x]
assert isinstance(realized_ranges, list), "realize map must contain range list"
closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[x][1]) if i in realized_ranges])
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
elif s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
if s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or \
(s.op is Ops.AFTER and not any(c.op in {Ops.STORE, Ops.END} for c in s.src[1:])):
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map:
realized_ranges = ctx.realize_map[s]

View File

@@ -45,9 +45,10 @@ def found_assign(ctx:dict[UOp, UOp], assign:UOp, src:UOp):
else: break
ctx[x] = assign
# *** fold moved ASSIGNs (hack for openpilot) ***
# *** fold moved ASSIGNs/AFTERs (hack for openpilot) ***
pm_fold_moved_assign = PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")), name="assign"), found_assign),
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")))), name="assign"), found_assign),
# replace ALU sources with assign versions found above
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
@@ -57,9 +58,10 @@ pm_mops = PatternMatcher([
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
if len(idx.src[1:]) == len(r.shape) else None),
# move movement ops after AFTER
# move movement ops after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
(UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)
if not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
(UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
])
@@ -69,12 +71,12 @@ pm_mops = PatternMatcher([
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)):
if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
return assign.replace(src=(target, src.contiguous()))
def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
root_target = target
while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
while root_target.op in {Ops.ASSIGN, Ops.AFTER}: root_target = root_target.src[0]
# when RHS depends on the previous assign result, break with contiguous
if target in src.toposort(): src = src.contiguous()
return assign.replace(src=(root_target, src))
@@ -170,8 +172,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
lambda target, src: target.assign(src.bitcast(target.dtype))),
# if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
# if assign target is itself an ASSIGN/AFTER chain, canonicalize to the original buffer target
(UPat(Ops.ASSIGN, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="target"), UPat(name="src")), allow_any_len=True, name="assign"),
normalize_assign_target_chain),
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
@@ -192,8 +195,8 @@ ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.NOOP}
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
def cleanup_dead_axes(b:UOp):
# don't optimize ALWAYS_RUN_OPS
if b.src[0].op in ALWAYS_RUN_OPS: return None
# don't optimize ALWAYS_RUN_OPS or AFTER (AFTER is a buffer identity — ranges define consumer access, not computation)
if b.src[0].op in ALWAYS_RUN_OPS or b.src[0].op is Ops.AFTER: return None
new_rng = []
hit = False
@@ -367,6 +370,11 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {size}"
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
# AFTER: add END to the existing STORE, return buffer with kernel dependency
if x.src[0].op is Ops.AFTER:
buf = x.src[0].src[0].buf_uop.base
stores = [s for s in x.src[0].src[1:] if s.op is Ops.STORE]
return buf.after(*[s.end(*rngs) for s in stores]) if stores else buf
if (assign := x.src[0]).op is Ops.ASSIGN:
assign_target, assign_src = assign.src[0], assign.src[1]
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
@@ -529,6 +537,8 @@ pm_add_range_tags = PatternMatcher([
def split_store(x:UOp) -> UOp|None:
# if we have any open ranges here, we don't split
if x.ranges: return None
# raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently
if x.op is Ops.STORE and x.src[0]._shape is not None: return None
# local kernel rewrite
lctx = LocalAddBufferContext()

View File

@@ -78,7 +78,7 @@ movement_ops = PatternMatcher([
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
# AFTER on Movement Op or ASSIGN
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN})),), allow_any_len=True), lambda: True),
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.BUFFER})),), allow_any_len=True), lambda: True),
])
_tensor_spec = PatternMatcher([
@@ -233,8 +233,8 @@ program_spec = PatternMatcher([
# END closes ranges
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
# make sure all index dtypes have been lowered (except CONST/RANGE/DEFINE_VAR which are valid index-typed)
(UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),