mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -37,7 +37,7 @@ add_tags = PatternMatcher([
|
||||
lambda a,c,dest: a.replace(src=(a.src[0], a.src[1].replace(src=(dest, c.rtag(())))), tag=a.tag+c.tag) if a.tag and c.tag else None),
|
||||
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), name="x"), tag_uop),
|
||||
(UPat(Ops.AFTER, name="u"), apply_after),
|
||||
(UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),
|
||||
(UPat(Ops.CONTIGUOUS, name="x"), tag_uop),
|
||||
(UPat(GroupOp.All, name="x"), lambda ctx,x: tag_uop(ctx,x) if x in ctx.bases else None),
|
||||
])
|
||||
|
||||
@@ -63,7 +63,7 @@ def replace_contig_with_store_after(u:UOp):
|
||||
|
||||
def replace_store_after_with_contig(u:UOp, src:UOp):
|
||||
assigned_to = u
|
||||
while assigned_to.op in {Ops.ASSIGN, Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
|
||||
while assigned_to.op in {Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
|
||||
if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)
|
||||
|
||||
def contiguous_mops_to_view(c:UOp):
|
||||
@@ -124,10 +124,10 @@ pm_early_transform_tensor_graph = PatternMatcher([
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
|
||||
|
||||
# add CONTIGUOUS to tagged UOps
|
||||
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER, Ops.STORE}, name="x"),
|
||||
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.AFTER, Ops.STORE}, name="x"),
|
||||
lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
|
||||
# remove extra CONTIGUOUS on ASSIGN/AFTER (only when target is contiguous)
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="a"),), name="c"),
|
||||
# remove extra CONTIGUOUS on AFTER (only when target is contiguous)
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.AFTER, name="a"),), name="c"),
|
||||
lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None),
|
||||
# replace AFTER+STORE with CONTIGUOUS when target is not a buffer
|
||||
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat(name="src")))), name="u"), replace_store_after_with_contig),
|
||||
@@ -143,7 +143,7 @@ def untag_and_append(ctx:AllocCtx, x:UOp):
|
||||
for t in x.tag:
|
||||
original_uop: UOp = ctx.uop_list[t]
|
||||
replace_uop = ret
|
||||
while replace_uop.op in {Ops.ASSIGN, Ops.AFTER}: replace_uop = replace_uop.src[0]
|
||||
while replace_uop.op is Ops.AFTER: replace_uop = replace_uop.src[0]
|
||||
ctx.buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
|
||||
if ret.op is not Ops.AFTER: ctx.assigns.append(ret) # AFTER gets appended by append_after
|
||||
return ret
|
||||
@@ -157,7 +157,7 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp):
|
||||
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None)
|
||||
|
||||
pm_finalize_call = PatternMatcher([
|
||||
(UPat({Ops.ASSIGN, Ops.AFTER}, name="x"), untag_and_append),
|
||||
(UPat(Ops.AFTER, name="x"), untag_and_append),
|
||||
(UPat(Ops.AFTER, name="x"), append_after),
|
||||
(UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None),
|
||||
# remove unique from const. TODO: this is copied in function.py
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL}
|
||||
|
||||
@@ -27,7 +27,7 @@ def realize_store_after_src(ctx:dict[UOp, None], dest:UOp, src:UOp):
|
||||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize
|
||||
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.ASSIGN}, name="tr"), realize),
|
||||
(UPat({Ops.COPY, Ops.CONTIGUOUS}, name="tr"), realize),
|
||||
# realize AFTER of STORE+AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), allow_any_len=True, name="tr"), realize),
|
||||
# realize srcs of these
|
||||
|
||||
@@ -107,10 +107,6 @@ def copy_multi(multi:UOp, device:str | tuple[str, ...] | UOp):
|
||||
assert multi.axis is not None, "all multi ops have axis"
|
||||
return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)
|
||||
|
||||
def assign_multi(dest:UOp, src:UOp):
|
||||
if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
|
||||
return dest.src[0].assign(src.src[0]).multi(src.axis)
|
||||
|
||||
def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0])).multi(src.axis)
|
||||
|
||||
def passthrough_multi(root:UOp, multi:UOp):
|
||||
@@ -137,7 +133,6 @@ multi_pm = PatternMatcher([
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), shrink_multi),
|
||||
(UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
|
||||
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI), UPat(Ops.STORE, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))))), store_after_multi),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
|
||||
|
||||
@@ -191,7 +191,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# *****************
|
||||
# 3.5 cleanups
|
||||
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.NOOP}
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.NOOP}
|
||||
|
||||
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
|
||||
def cleanup_dead_axes(b:UOp):
|
||||
@@ -383,20 +383,6 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
||||
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
|
||||
ended_stores.append(store_target.replace(dtype=sdtype).store(stores[0].src[1]).end(*end_rngs))
|
||||
return buf.after(*ended_stores)
|
||||
if (assign := x.src[0]).op is Ops.ASSIGN:
|
||||
assign_target, assign_src = assign.src[0], assign.src[1]
|
||||
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
|
||||
while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
|
||||
|
||||
store_target = assign_target
|
||||
if assign_target.src[0].op is Ops.BUFFERIZE and assign_target.src[0].src[0].op is Ops.INDEX:
|
||||
# BUFFERIZE(INDEX(...)); store through the underlying global index instead.
|
||||
store_target = assign_target.src[0].src[0]
|
||||
|
||||
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
|
||||
ret = store_target.buf_uop.base
|
||||
if assign_src is not assign_target: ret = ret.after(store_target.replace(dtype=sdtype).store(assign_src).end(*end_rngs))
|
||||
return ret
|
||||
|
||||
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
||||
if sdtype.addrspace == AddrSpace.GLOBAL:
|
||||
|
||||
@@ -85,7 +85,7 @@ class Ops(FastEnum):
|
||||
# ** 6 -- ops that don't exist in programs **
|
||||
|
||||
# tensor graph ops
|
||||
UNIQUE = auto(); DEVICE = auto(); ASSIGN = auto()
|
||||
UNIQUE = auto(); DEVICE = auto()
|
||||
|
||||
# local unique
|
||||
LUNIQUE = auto()
|
||||
|
||||
@@ -307,8 +307,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
raise ValueError(f"invalid type for axis: {axis_arg}")
|
||||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||
|
||||
if self.op is Ops.ASSIGN: return self.src[1]._shape
|
||||
|
||||
# elementwise ops keep the shape the same. all inputs with shape must match
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE}):
|
||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||
|
||||
@@ -97,9 +97,6 @@ _tensor_spec = PatternMatcher([
|
||||
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
(UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
|
||||
# ASSIGN is used internally by allreduce for precompiled function bodies
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2),
|
||||
|
||||
# MSELECT chooses one of the multi buffers
|
||||
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
|
||||
Reference in New Issue
Block a user