diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index 3e1228c6d2..eeabfadeb3 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -280,6 +280,11 @@ class TestAssign(unittest.TestCase): t.uop = t.uop.after(t[:5].uop.assign(Tensor.ones(5).uop)) np.testing.assert_allclose(t.numpy(), [1.,1.,1.,1.,1.,0.,0.,0.,0.,0.]) + def test_assign_after_target_chain(self): + t = Tensor.arange(16).reshape(4, 4).permute(1, 0).contiguous() + t.assign(t + 100) + np.testing.assert_equal(t.numpy(), [[100, 104, 108, 112], [101, 105, 109, 113], [102, 106, 110, 114], [103, 107, 111, 115]]) + def test_assign_contiguous(self): b = Tensor.arange(16).reshape(4,4).contiguous().realize() a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1) diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index 0ce15dfb60..7ddabf79f5 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -46,18 +46,19 @@ def _buffer_like(u:UOp) -> UOp: if prod(dtype.shape) != prod(u.max_shard_shape) or ([x for x in u.max_shard_shape if x != 1] or [1])[-1] % 4 != 0: if DEBUG >= 1: print(f"demoting Image {dtype} with shape {u.max_shard_shape}") dtype = dtype.base - buffer = UOp.new_buffer(u.device, u.shard_size, dtype).reshape(u.max_shard_shape) + buffer = UOp.new_buffer(u.device, u.shard_size, dtype).reshape(u.max_shard_shape).shrink_to(u.shard_shape) if isinstance(u.device, tuple) and u.axis is not None: buffer = buffer.multi(u.axis) return buffer -def replace_contig_with_assign(u:UOp): +def replace_contig_with_store_after(u:UOp): # can't allocate a buffer without a device (e.g., inside a CALL function body with only PARAMs) if u._device is None: return None # if size is 0, remove the contig if u.size == 0: return u.src[0] # no real contig for DISK/TINYFS tensors, they are left alone if isinstance(u._device, str) and u._device.startswith(("DISK", "TINYFS")): return u.rtag(None) - return _buffer_like(u).assign(u.src[0]).rtag(u.tag) + buf = _buffer_like(u) + return buf.after(buf.store(u.src[0])).rtag(u.tag) def replace_assign_with_contig(u:UOp): assigned_to = u @@ -111,14 +112,15 @@ pm_early_transform_tensor_graph = PatternMatcher([ (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view), # add CONTIGUOUS to tagged UOps - (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)), - # remove extra CONTIGUOUS on ASSIGN (only when assign target is contiguous) - (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"), + (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER, Ops.STORE}, name="x"), + lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)), + # remove extra CONTIGUOUS on ASSIGN/AFTER (only when target is contiguous) + (UPat(Ops.CONTIGUOUS, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="a"),), name="c"), lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None), # replace ASSIGN with CONTIGUOUS (UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig), - # replace CONTIGUOUS with ASSIGNs - (UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_assign), + # replace CONTIGUOUS with STORE+AFTER + (UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_store_after), # remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal) (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), ]) @@ -129,9 +131,9 @@ def untag_and_append(ctx:AllocCtx, x:UOp): for t in x.tag: original_uop: UOp = ctx.uop_list[t] replace_uop = ret - while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0] + while replace_uop.op in {Ops.ASSIGN, Ops.AFTER}: replace_uop = replace_uop.src[0] ctx.buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape) - ctx.assigns.append(ret) + if ret.op is not Ops.AFTER: ctx.assigns.append(ret) # AFTER gets appended by append_after return ret def append_after(ctx:AllocCtx, x:UOp): @@ -143,7 +145,7 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp): b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None) pm_finalize_call = PatternMatcher([ - (UPat(Ops.ASSIGN, name="x"), untag_and_append), + (UPat({Ops.ASSIGN, Ops.AFTER}, name="x"), untag_and_append), (UPat(Ops.AFTER, name="x"), append_after), (UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None), # remove unique from const. TODO: this is copied in function.py diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 7d9438c5af..8e2fcb96bb 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -7,7 +7,7 @@ from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored -ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, +ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL} @@ -25,6 +25,10 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp): # you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce if buf.base in x.backward_slice_with_self: ctx[x] = None +def unrealize_store_src(ctx:dict[UOp, None], x:UOp): + """Don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER — bufferize_to_store handles them.""" + if x in ctx: del ctx[x] + pm_generate_realize_map = PatternMatcher([ # always realize (UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.ASSIGN}, name="tr"), realize), @@ -34,6 +38,8 @@ pm_generate_realize_map = PatternMatcher([ (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs), # sometimes realize src of assign (UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src), + # don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER (like realize_assign_src for ASSIGN) + (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat({Ops.COPY, Ops.BUFFER_VIEW}, name="x"))))), unrealize_store_src), ]) @dataclass(frozen=True) @@ -60,13 +66,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): new_srcs = [] for s in x.src: new_src = s - # TODO: this STORE+AFTER is very explicit, AFTER is the one being realized, and STORE needs to end ranges - if x.op is Ops.AFTER and s.op is Ops.STORE and x in ctx.realize_map: - realized_ranges = ctx.realize_map[x] - assert isinstance(realized_ranges, list), "realize map must contain range list" - closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[x][1]) if i in realized_ranges]) - new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE]) - elif s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}: + if s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or \ + (s.op is Ops.AFTER and not any(c.op in {Ops.STORE, Ops.END} for c in s.src[1:])): if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) elif s in ctx.realize_map: realized_ranges = ctx.realize_map[s] diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 9fe47c72e0..39d54d97cf 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -45,9 +45,10 @@ def found_assign(ctx:dict[UOp, UOp], assign:UOp, src:UOp): else: break ctx[x] = assign -# *** fold moved ASSIGNs (hack for openpilot) *** +# *** fold moved ASSIGNs/AFTERs (hack for openpilot) *** pm_fold_moved_assign = PatternMatcher([ (UPat(Ops.ASSIGN, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")), name="assign"), found_assign), + (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")))), name="assign"), found_assign), # replace ALU sources with assign versions found above (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), ]) @@ -57,9 +58,10 @@ pm_mops = PatternMatcher([ (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg) if len(idx.src[1:]) == len(r.shape) else None), - # move movement ops after AFTER + # move movement ops after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after) (UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True), - lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)), + lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg) + if not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None), (UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])), ]) @@ -69,12 +71,12 @@ pm_mops = PatternMatcher([ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp): # PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set()) - if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)): + if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)): return assign.replace(src=(target, src.contiguous())) def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp): root_target = target - while root_target.op is Ops.ASSIGN: root_target = root_target.src[0] + while root_target.op in {Ops.ASSIGN, Ops.AFTER}: root_target = root_target.src[0] # when RHS depends on the previous assign result, break with contiguous if target in src.toposort(): src = src.contiguous() return assign.replace(src=(root_target, src)) @@ -170,8 +172,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))), lambda target, src: target.assign(src.bitcast(target.dtype))), - # if assign target is itself an ASSIGN chain, canonicalize to the original buffer target - (UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain), + # if assign target is itself an ASSIGN/AFTER chain, canonicalize to the original buffer target + (UPat(Ops.ASSIGN, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), + normalize_assign_target_chain), # make source contiguous if it has hazardous movement ops on the dest buffer (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard), @@ -192,8 +195,8 @@ ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.NOOP} # you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left def cleanup_dead_axes(b:UOp): - # don't optimize ALWAYS_RUN_OPS - if b.src[0].op in ALWAYS_RUN_OPS: return None + # don't optimize ALWAYS_RUN_OPS or AFTER (AFTER is a buffer identity — ranges define consumer access, not computation) + if b.src[0].op in ALWAYS_RUN_OPS or b.src[0].op is Ops.AFTER: return None new_rng = [] hit = False @@ -367,6 +370,11 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True): assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {size}" sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace) + # AFTER: add END to the existing STORE, return buffer with kernel dependency + if x.src[0].op is Ops.AFTER: + buf = x.src[0].src[0].buf_uop.base + stores = [s for s in x.src[0].src[1:] if s.op is Ops.STORE] + return buf.after(*[s.end(*rngs) for s in stores]) if stores else buf if (assign := x.src[0]).op is Ops.ASSIGN: assign_target, assign_src = assign.src[0], assign.src[1] assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index" @@ -529,6 +537,8 @@ pm_add_range_tags = PatternMatcher([ def split_store(x:UOp) -> UOp|None: # if we have any open ranges here, we don't split if x.ranges: return None + # raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently + if x.op is Ops.STORE and x.src[0]._shape is not None: return None # local kernel rewrite lctx = LocalAddBufferContext() diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index c3f7609a68..ccba8a1767 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -78,7 +78,7 @@ movement_ops = PatternMatcher([ (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True), # AFTER on Movement Op or ASSIGN - (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN})),), allow_any_len=True), lambda: True), + (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.BUFFER})),), allow_any_len=True), lambda: True), ]) _tensor_spec = PatternMatcher([ @@ -233,8 +233,8 @@ program_spec = PatternMatcher([ # END closes ranges (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True), - # make sure all index dtypes have been lowered - (UPat(GroupOp.All, dtype=dtypes.index), lambda: False), + # make sure all index dtypes have been lowered (except CONST/RANGE/DEFINE_VAR which are valid index-typed) + (UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.index), lambda: False), (UPat(Ops.CONST, arg=Invalid), lambda: False), (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),