allocations contiguous is store+after (#15280)
@@ -280,6 +280,11 @@ class TestAssign(unittest.TestCase):
     t.uop = t.uop.after(t[:5].uop.assign(Tensor.ones(5).uop))
     np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,0.,0.,0.,0.,0.])
 
+  def test_assign_after_target_chain(self):
+    t = Tensor.arange(16).reshape(4, 4).permute(1, 0).contiguous()
+    t.assign(t + 100)
+    np.testing.assert_equal(t.numpy(), [[100, 104, 108, 112], [101, 105, 109, 113], [102, 106, 110, 114], [103, 107, 111, 115]])
+
   def test_assign_contiguous(self):
     b = Tensor.arange(16).reshape(4,4).contiguous().realize()
     a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)
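
For reference, a numpy-only sketch of the values the two assertions above expect (plain numpy arrays stand in for the realized tinygrad tensors; this is an illustration, not part of the commit):

import numpy as np

# slice assign: the first five elements become ones, the rest stay zero
t = np.zeros(10, dtype=np.float32)
t[:5] = 1.0
assert np.allclose(t, [1.,1.,1.,1.,1.,0.,0.,0.,0.,0.])

# assign through a permuted view: arange(16) reshaped to 4x4, transposed, then +100
u = np.arange(16).reshape(4, 4).T + 100
assert (u == [[100, 104, 108, 112], [101, 105, 109, 113], [102, 106, 110, 114], [103, 107, 111, 115]]).all()
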
@@ -46,18 +46,19 @@ def _buffer_like(u:UOp) -> UOp:
   if prod(dtype.shape) != prod(u.max_shard_shape) or ([x for x in u.max_shard_shape if x != 1] or [1])[-1] % 4 != 0:
     if DEBUG >= 1: print(f"demoting Image {dtype} with shape {u.max_shard_shape}")
     dtype = dtype.base
-  buffer = UOp.new_buffer(u.device, u.shard_size, dtype).reshape(u.max_shard_shape)
+  buffer = UOp.new_buffer(u.device, u.shard_size, dtype).reshape(u.max_shard_shape).shrink_to(u.shard_shape)
   if isinstance(u.device, tuple) and u.axis is not None: buffer = buffer.multi(u.axis)
   return buffer
 
-def replace_contig_with_assign(u:UOp):
+def replace_contig_with_store_after(u:UOp):
   # can't allocate a buffer without a device (e.g., inside a CALL function body with only PARAMs)
   if u._device is None: return None
   # if size is 0, remove the contig
   if u.size == 0: return u.src[0]
   # no real contig for DISK/TINYFS tensors, they are left alone
   if isinstance(u._device, str) and u._device.startswith(("DISK", "TINYFS")): return u.rtag(None)
-  return _buffer_like(u).assign(u.src[0]).rtag(u.tag)
+  buf = _buffer_like(u)
+  return buf.after(buf.store(u.src[0])).rtag(u.tag)
 
 def replace_assign_with_contig(u:UOp):
   assigned_to = u
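
To make the shape of this rewrite concrete, here is a toy sketch in plain Python (dataclass nodes standing in for tinygrad UOps; names and structure are illustrative only, not tinygrad code). The old rule produced ASSIGN(buffer, src); the new rule returns the buffer behind an AFTER whose dependency is an explicit STORE of src into that buffer:

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()

def old_style(contig: Node, buf: Node) -> Node:
  # before this commit: the realized CONTIGUOUS became ASSIGN(buffer, src)
  return Node("ASSIGN", (buf, contig.src[0]))

def new_style(contig: Node, buf: Node) -> Node:
  # after this commit: the buffer AFTER its own STORE of src
  return Node("AFTER", (buf, Node("STORE", (buf, contig.src[0]))))

x = Node("CONTIGUOUS", (Node("ADD"),))
b = Node("BUFFER")
print(old_style(x, b))
print(new_style(x, b))
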
@@ -111,14 +112,15 @@ pm_early_transform_tensor_graph = PatternMatcher([
   (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
 
   # add CONTIGUOUS to tagged UOps
-  (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
-  # remove extra CONTIGUOUS on ASSIGN (only when assign target is contiguous)
-  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"),
+  (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER, Ops.STORE}, name="x"),
+   lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
+  # remove extra CONTIGUOUS on ASSIGN/AFTER (only when target is contiguous)
+  (UPat(Ops.CONTIGUOUS, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="a"),), name="c"),
    lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None),
   # replace ASSIGN with CONTIGUOUS
   (UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
-  # replace CONTIGUOUS with ASSIGNs
-  (UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_assign),
+  # replace CONTIGUOUS with STORE+AFTER
+  (UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_store_after),
   # remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal)
   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
 ])
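
As a rough mental model of how such a rule list is applied (plain Python, not tinygrad's UPat/PatternMatcher machinery; purely illustrative): each rule pairs a match test with a rewrite, and the first rewrite that returns something non-None replaces the node.

from typing import Callable, Optional

Rule = tuple[Callable[[object], bool], Callable[[object], Optional[object]]]

def rewrite_once(node, rules: list[Rule]):
  # first matching rule whose rewrite returns non-None wins
  for matches, rewrite in rules:
    if matches(node) and (new := rewrite(node)) is not None:
      return new
  return node

rules: list[Rule] = [(lambda n: isinstance(n, int) and n < 0, lambda n: 0)]
print(rewrite_once(-5, rules), rewrite_once(3, rules))  # 0 3
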
@@ -129,9 +131,9 @@ def untag_and_append(ctx:AllocCtx, x:UOp):
   for t in x.tag:
     original_uop: UOp = ctx.uop_list[t]
     replace_uop = ret
-    while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0]
+    while replace_uop.op in {Ops.ASSIGN, Ops.AFTER}: replace_uop = replace_uop.src[0]
     ctx.buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
-  ctx.assigns.append(ret)
+  if ret.op is not Ops.AFTER: ctx.assigns.append(ret)  # AFTER gets appended by append_after
   return ret
 
 def append_after(ctx:AllocCtx, x:UOp):
@@ -143,7 +145,7 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp):
                       b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None)
 
 pm_finalize_call = PatternMatcher([
-  (UPat(Ops.ASSIGN, name="x"), untag_and_append),
+  (UPat({Ops.ASSIGN, Ops.AFTER}, name="x"), untag_and_append),
   (UPat(Ops.AFTER, name="x"), append_after),
   (UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None),
   # remove unique from const. TODO: this is copied in function.py
@@ -7,7 +7,7 @@ from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
 from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
 from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
 
-ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
+ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
                                Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
                                Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL}
 
@@ -25,6 +25,10 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
   # you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
   if buf.base in x.backward_slice_with_self: ctx[x] = None
 
+def unrealize_store_src(ctx:dict[UOp, None], x:UOp):
+  """Don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER — bufferize_to_store handles them."""
+  if x in ctx: del ctx[x]
+
 pm_generate_realize_map = PatternMatcher([
   # always realize
   (UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.ASSIGN}, name="tr"), realize),
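
The realize map here is a dict[UOp, None] used as an ordered set, and unrealize_store_src simply drops an entry that an earlier rule added. A trivial stand-in illustration (strings in place of UOps, not tinygrad code):

realize_map: dict[str, None] = {}
realize_map["copy_node"] = None        # a realize rule adds the op
if "copy_node" in realize_map:
  del realize_map["copy_node"]         # unrealize_store_src removes it again
print(list(realize_map))               # []
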
@@ -34,6 +38,8 @@ pm_generate_realize_map = PatternMatcher([
   (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
   # sometimes realize src of assign
   (UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
+  # don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER (like realize_assign_src for ASSIGN)
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat({Ops.COPY, Ops.BUFFER_VIEW}, name="x"))))), unrealize_store_src),
 ])
 
 @dataclass(frozen=True)
@@ -60,13 +66,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
   new_srcs = []
   for s in x.src:
     new_src = s
+    # TODO: this STORE+AFTER is very explicit, AFTER is the one being realized, and STORE needs to end ranges
+    if x.op is Ops.AFTER and s.op is Ops.STORE and x in ctx.realize_map:
+      realized_ranges = ctx.realize_map[x]
+      assert isinstance(realized_ranges, list), "realize map must contain range list"
+      closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[x][1]) if i in realized_ranges])
+      new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
-    elif s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
+    if s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or \
+        (s.op is Ops.AFTER and not any(c.op in {Ops.STORE, Ops.END} for c in s.src[1:])):
       if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
     elif s in ctx.realize_map:
       realized_ranges = ctx.realize_map[s]
@@ -45,9 +45,10 @@ def found_assign(ctx:dict[UOp, UOp], assign:UOp, src:UOp):
     else: break
   ctx[x] = assign
 
-# *** fold moved ASSIGNs (hack for openpilot) ***
+# *** fold moved ASSIGNs/AFTERs (hack for openpilot) ***
 pm_fold_moved_assign = PatternMatcher([
   (UPat(Ops.ASSIGN, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")), name="assign"), found_assign),
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")))), name="assign"), found_assign),
   # replace ALU sources with assign versions found above
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])
@@ -57,9 +58,10 @@ pm_mops = PatternMatcher([
   (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
    if len(idx.src[1:]) == len(r.shape) else None),
-  # move movement ops after AFTER
+  # move movement ops after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
   (UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
-   lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
+   lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)
+   if not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
   (UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
 ])
 
@@ -69,12 +71,12 @@ pm_mops = PatternMatcher([
 def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
   # PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
   unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
-  if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)):
+  if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
     return assign.replace(src=(target, src.contiguous()))
 
 def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
   root_target = target
-  while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
+  while root_target.op in {Ops.ASSIGN, Ops.AFTER}: root_target = root_target.src[0]
   # when RHS depends on the previous assign result, break with contiguous
   if target in src.toposort(): src = src.contiguous()
   return assign.replace(src=(root_target, src))
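
Both normalize_assign_target_chain (above) and untag_and_append (earlier) now chase src[0] through ASSIGN and AFTER wrappers to reach the underlying buffer. A self-contained toy version of that walk, using the same kind of dataclass stand-in for UOps as the earlier sketch (illustrative only):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()

def root_buffer(target: Node) -> Node:
  # ASSIGN and AFTER both keep the underlying buffer as src[0]
  while target.op in {"ASSIGN", "AFTER"}:
    target = target.src[0]
  return target

buf = Node("BUFFER")
wrapped = Node("ASSIGN", (Node("AFTER", (buf, Node("STORE", (buf,)))), Node("ADD")))
assert root_buffer(wrapped) is buf
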
@@ -170,8 +172,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
   (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
    lambda target, src: target.assign(src.bitcast(target.dtype))),
 
-  # if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
-  (UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
+  # if assign target is itself an ASSIGN/AFTER chain, canonicalize to the original buffer target
+  (UPat(Ops.ASSIGN, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="target"), UPat(name="src")), allow_any_len=True, name="assign"),
+   normalize_assign_target_chain),
 
   # make source contiguous if it has hazardous movement ops on the dest buffer
   (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
@@ -192,8 +195,8 @@ ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.NOOP}
 
 # you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
 def cleanup_dead_axes(b:UOp):
-  # don't optimize ALWAYS_RUN_OPS
-  if b.src[0].op in ALWAYS_RUN_OPS: return None
+  # don't optimize ALWAYS_RUN_OPS or AFTER (AFTER is a buffer identity — ranges define consumer access, not computation)
+  if b.src[0].op in ALWAYS_RUN_OPS or b.src[0].op is Ops.AFTER: return None
 
   new_rng = []
   hit = False
@@ -367,6 +370,11 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
   assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {size}"
 
   sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
+  # AFTER: add END to the existing STORE, return buffer with kernel dependency
+  if x.src[0].op is Ops.AFTER:
+    buf = x.src[0].src[0].buf_uop.base
+    stores = [s for s in x.src[0].src[1:] if s.op is Ops.STORE]
+    return buf.after(*[s.end(*rngs) for s in stores]) if stores else buf
   if (assign := x.src[0]).op is Ops.ASSIGN:
     assign_target, assign_src = assign.src[0], assign.src[1]
     assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
@@ -529,6 +537,8 @@ pm_add_range_tags = PatternMatcher([
 def split_store(x:UOp) -> UOp|None:
   # if we have any open ranges here, we don't split
   if x.ranges: return None
+  # raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently
+  if x.op is Ops.STORE and x.src[0]._shape is not None: return None
 
   # local kernel rewrite
   lctx = LocalAddBufferContext()
@@ -78,7 +78,7 @@ movement_ops = PatternMatcher([
   (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
 
   # AFTER on Movement Op or ASSIGN
-  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN})),), allow_any_len=True), lambda: True),
+  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.BUFFER})),), allow_any_len=True), lambda: True),
 ])
 
 _tensor_spec = PatternMatcher([
@@ -233,8 +233,8 @@ program_spec = PatternMatcher([
   # END closes ranges
   (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
 
-  # make sure all index dtypes have been lowered
-  (UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
+  # make sure all index dtypes have been lowered (except CONST/RANGE/DEFINE_VAR which are valid index-typed)
+  (UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.index), lambda: False),
   (UPat(Ops.CONST, arg=Invalid), lambda: False),
   (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
                                 type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),