From d64af3c884cabdf2473a28d7c3ef2e137477fb86 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 2 Feb 2025 10:19:52 -0500 Subject: [PATCH 01/11] reorder simplifier and grouper logic in scheduler [pr] (#8861) --- tinygrad/engine/schedule.py | 164 +++++++++++++++++------------------- 1 file changed, 79 insertions(+), 85 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 1083240ea7..f5bc8612ee 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -203,7 +203,57 @@ if CAPTURE_PROCESS_REPLAY: def save_process_replay() -> None: for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) -# **** Schedule grouping +# **** UOp realization + +class UPatScheduled(UPat): + def __init__(self, *args, **kwargs): + super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs}))) + +def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store + +def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None: + st = unwrap(view.st) + # fold simple pads + if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])): + return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src) + # early realize before expand + if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src) + # otherwise safety check pads + return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src) + +def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None: + if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None + del ctx.realizes[b] + return x.view(unwrap(view.st)) + +def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp): + if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None + buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) + return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW))) + +do_realize = PatternMatcher([ + # always realize SINK parents + (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)), + # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW + (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize), + # realize before expand or unsafe pad ops + (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view), + # don't realize image to image casts + (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)), + fold_img_cast), + # realize before COPY or BUFFER_VIEW + (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize), + (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize), + # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK + (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer), +]) + +def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None: + ctx.allbufs[buf_uop] = view + if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop) + for x in op.base.src: + if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None +create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)]) def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER def uval(u:UOp) -> UOp: @@ -228,8 +278,9 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r) recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache) -def group_realizes(ctx:ScheduleContext) -> None: - """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop""" +def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]: + # start by adding uops that always realize + sink = graph_rewrite(sink, do_realize+create_ctx, ctx) # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) reduce_for_op: dict[UOp, UOp] = {} double_reduces: list[UOp] = [] @@ -280,10 +331,28 @@ def group_realizes(ctx:ScheduleContext) -> None: for reduceop in double_reduces: top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce] + graph_rewrite(sink, break_sched, ctx) + return ctx.realizes -# **** Schedule creation and BFS toposort +# break the SINK into stores -# ** this is schedule level const folding +def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): + # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN + return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop())) + +def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp): + if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m + if b not in ctx.realizes: return x # collapse BUFFER + ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x) + return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop())) + +break_sched = PatternMatcher([ + # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it + (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized), + (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse), +]) + +# **** schedule simplifier def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: if not all_int(x.shape): return None @@ -338,80 +407,6 @@ sym = symbolic_simple+PatternMatcher([ if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), ]) -# ** this decides which ops get realized - -class UPatScheduled(UPat): - def __init__(self, *args, **kwargs): - super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs}))) - -def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store - -def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None: - st = unwrap(view.st) - # fold simple pads - if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])): - return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src) - # early realize before expand - if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src) - # otherwise safety check pads - return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src) - -def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None: - if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None - del ctx.realizes[b] - return x.view(unwrap(view.st)) - -def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp): - if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None - buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) - return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW))) - -do_realize = PatternMatcher([ - # always realize SINK parents - (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)), - # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW - (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize), - # realize before expand or unsafe pad ops - (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view), - # don't realize image to image casts - (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)), - fold_img_cast), - # realize before COPY or BUFFER_VIEW - (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize), - (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize), - # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK - (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer), -]) - -# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp - -def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): - # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN - return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop())) - -def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp): - if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m - if b not in ctx.realizes: return x # collapse BUFFER - ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x) - return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop())) - -break_sched = PatternMatcher([ - # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it - (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized), - (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse), -]) - -# **** Schedule context builder - -def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None: - ctx.allbufs[buf_uop] = view - if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop) - for x in op.base.src: - if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None -create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)]) - -# **** movement ops - remove_movement_ops = merge_views+PatternMatcher([ # NOTE: movement ops are always applied to base (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))), @@ -420,6 +415,8 @@ remove_movement_ops = merge_views+PatternMatcher([ lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), ]) +# **** schedule creation and toposort + @track_rewrites(named=True) def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={}) @@ -438,11 +435,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) # add BUFFER uops sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={}) - # add realizes - sink = graph_rewrite(sink, do_realize+create_ctx, ctx) - # group realizes into kernels - group_realizes(ctx) - graph_rewrite(sink, break_sched, ctx) + # get realizes + realize_map = group_realizes(sink, ctx) # TODO: this should be the break between the "grouper" and the "linearizer" # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign) @@ -451,7 +445,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va # create schedule items + map buffers to realized tensors prescheduled: list[ScheduleItem] = [] var_vals: dict[Variable, int] = {} - for buf_uop,store in ctx.realizes.items(): + for buf_uop,store in realize_map.items(): assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}" prescheduled.append(schedule_uop(store.sink(), ctx, var_vals)) # can only schedule once From 565c37c681f19f26f1f6e47598fcd251bb0bbf2e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 2 Feb 2025 11:11:36 -0500 Subject: [PATCH 02/11] start simplifying the scheduler context [pr] (#8830) --- tinygrad/engine/schedule.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index f5bc8612ee..e0aebefb35 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -35,7 +35,7 @@ class ScheduleItem: @dataclass(frozen=True) class ScheduleContext: - tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the tensor uop + tensor_uops: dict[UOp, list[UOp]] # this maps BUFFER uops of this schedule to the tensor uop assigns: set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op @@ -45,16 +45,16 @@ class ScheduleContext: # wrap tensor uops around a VIEW(BUFFER, ) # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it. -def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: +def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r # SINK is passthrough - if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) + if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src)) # skip creating buffers for CONST/BIND/DEVICE/BUFFER if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st)) # VIEW is passthrough if buf is not buf.base: - cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st)) + cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st)) return ret # make things that can't be images not images dtype = buf.dtype @@ -64,9 +64,9 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c # ASSIGN already has a target buffer, otherwise we create a new one assert isinstance(buf.device, str), f"buf device is str, not {buf.device}" buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) - op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) - # track the underlying tensor uop for this buffer - ctx.tensor_uops[buf_uop] = tensor_map[buf] + op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src)) + # track the buffer uop for the simplified uop + buffer_map[buf] = buf_uop # (early) bufferize cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret @@ -431,12 +431,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v # we group the rest of UOps into ScheduleItems - rev_tensor_map: dict[UOp, list[UOp]] = {} - for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) - # add BUFFER uops - sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={}) + buffer_map: dict[UOp, UOp] = {} + sink = add_buffers(tensor_map[big_sink], buffer_map, cache={}) # get realizes - realize_map = group_realizes(sink, ctx) + buf_tensors: dict[UOp, list[UOp]] = {} + for k,v in tensor_map.items(): + if (b:=buffer_map.get(v)) is not None: buf_tensors.setdefault(b, []).append(k) + realize_map = group_realizes(sink, ctx:=ScheduleContext(buf_tensors)) # TODO: this should be the break between the "grouper" and the "linearizer" # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign) @@ -449,7 +450,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}" prescheduled.append(schedule_uop(store.sink(), ctx, var_vals)) # can only schedule once - for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) + for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) # increment refcount for this buffer buf_uop.buffer.ref(1) From af2c2837f642b6f68c06050c28b755c85815b711 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 3 Feb 2025 14:02:55 +0800 Subject: [PATCH 03/11] hotfix: skip broken test, add KERNEL Op --- test/unit/test_disk_tensor.py | 3 ++- tinygrad/ops.py | 2 +- tinygrad/viz/serve.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 6bb099aa76..6291216d1d 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -3,7 +3,7 @@ import numpy as np from tinygrad import Tensor, Device, dtypes from tinygrad.dtype import DType from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load -from tinygrad.helpers import Timing, fetch, temp, CI +from tinygrad.helpers import Timing, fetch, temp, CI, OSX from tinygrad.device import is_dtype_supported def compare_weights_both(url): @@ -298,6 +298,7 @@ class TestDiskTensor(unittest.TestCase): ret = t.bitcast(dtypes.uint16).to("CLANG") + 1 assert ret.tolist() == [2827, 3341, 3855, 4369] + @unittest.skipIf(OSX, "new LLVM has an issue on OSX") def test_bf16_disk_write_read(self): t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32) t.to(f"disk:{temp('dt_bf16_disk_write_read_f32')}").realize() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 1327537e19..f193bd9bfd 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait): # the order of these Ops controls the order of the toposort class Ops(FastEnum): # uops that aren't rendered - SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702 + SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto(); KERNEL = auto() # noqa: E702 # TODO: empty continues to exist because of tensor EMPTY = auto() diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 1e75338adc..a7163213d1 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -12,7 +12,7 @@ from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"} From f484db0e6344d56f212e24974681f227027979e5 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 3 Feb 2025 15:18:53 +0800 Subject: [PATCH 04/11] dsp cleanups [pr] (#8866) --- examples/benchmark_onnx.py | 11 +++++++--- examples/test_onnx_imagenet.py | 6 +++++- extra/dsp/compile.py | 2 +- extra/dsp/opt.py | 27 ++++++++++++++++++++++++ tinygrad/runtime/ops_clang.py | 20 ++---------------- tinygrad/runtime/ops_dsp.py | 38 ++++++++++++++++++++++++---------- 6 files changed, 70 insertions(+), 34 deletions(-) create mode 100644 extra/dsp/opt.py diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py index 498e626aa6..333ac9fbba 100644 --- a/examples/benchmark_onnx.py +++ b/examples/benchmark_onnx.py @@ -1,6 +1,5 @@ import sys, onnx, time from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch -from tinygrad.tensor import _from_np_dtype from extra.onnx import OnnxRunner def load_onnx_model(fn): @@ -18,19 +17,25 @@ def load_onnx_model(fn): run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True) return run_onnx_jit, input_shapes, input_types +def get_new_inputs(input_shapes): + #from tinygrad.tensor import _from_np_dtype + #return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())} + import numpy as np + return {k:Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize() for k,shp in sorted(input_shapes.items())} + if __name__ == "__main__": run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1]) print("loaded model") for i in range(3): - new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())} + new_inputs = get_new_inputs(input_shapes) GlobalCounters.reset() print(f"run {i}") run_onnx_jit(**new_inputs) # run 20 times for _ in range(20): - new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())} + new_inputs = get_new_inputs(input_shapes) GlobalCounters.reset() st = time.perf_counter() out = run_onnx_jit(**new_inputs) diff --git a/examples/test_onnx_imagenet.py b/examples/test_onnx_imagenet.py index a8e5a8c56a..1f27e23e0b 100644 --- a/examples/test_onnx_imagenet.py +++ b/examples/test_onnx_imagenet.py @@ -17,6 +17,10 @@ from tinygrad.helpers import fetch, getenv # https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx # ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx +# QUANT=1 python3 examples/test_onnx_imagenet.py +# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx +# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx + def imagenet_dataloader(cnt=0): input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) @@ -61,7 +65,7 @@ if __name__ == "__main__": assert shape[1:] == (3,224,224), f"shape is {shape}" hit = 0 - for i,(img,y) in enumerate(imagenet_dataloader()): + for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)): p = run_onnx_jit(**{t_name:img}) assert p.shape == (1,1000) t = p.argmax().item() diff --git a/extra/dsp/compile.py b/extra/dsp/compile.py index 93ee4cf9c0..cb3c18a880 100755 --- a/extra/dsp/compile.py +++ b/extra/dsp/compile.py @@ -37,7 +37,7 @@ if __name__ == "__main__": print("mmapped", hex(res)) to_mv(res, 0x10)[1] = 0xaa - from tinygrad.runtime.ops_clang import ClangCompiler + from tinygrad.runtime.ops_dsp import ClangCompiler cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"]) obj = cc.compile(""" diff --git a/extra/dsp/opt.py b/extra/dsp/opt.py new file mode 100644 index 0000000000..fbe35e7ccb --- /dev/null +++ b/extra/dsp/opt.py @@ -0,0 +1,27 @@ +from tinygrad.runtime.ops_dsp import DSPCompiler + +# PATH=/opt/homebrew/opt/llvm/bin:$PATH python3 extra/dsp/opt.py + +if __name__ == "__main__": + compiler = DSPCompiler() + + lib = compiler.compile(""" +typedef long HVX_Vector __attribute__((__vector_size__(128))) __attribute__ ((aligned(128))); +typedef long HVX_VectorPair __attribute__((__vector_size__(256))) __attribute__ ((aligned(256))); + +void test(unsigned char *c, unsigned char *a, unsigned char *b) { + HVX_Vector t0 = *(HVX_Vector*)a; + //HVX_VectorPair t1 = *((HVX_VectorPair*)b); + HVX_Vector acc = __builtin_HEXAGON_V6_vd0_128B(); + for (int i = 0; i < 128; i++) { + //__builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) + //acc += __builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) * t1; + //acc += t0[i] * t1; + unsigned int t1 = ((unsigned int *)b)[i]; + //acc = __builtin_HEXAGON_V6_vrmpyub_acc_128B(acc, t0, t1); + acc = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc, t0, t1); + } + *((HVX_Vector*)c) = acc; +}""") + + compiler.disassemble(lib) diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 2baf572382..463799f305 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,25 +1,9 @@ -import platform, tempfile, pathlib, subprocess, sys -from tinygrad.helpers import cpu_objdump, capstone_flatdump +import platform, subprocess, sys +from tinygrad.helpers import capstone_flatdump from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram from tinygrad.runtime.support.elf import jit_loader from tinygrad.renderer.cstyle import ClangRenderer -# Used by ops_dsp.py -class ClangCompiler(Compiler): - def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'): - self.args = ['-march=native'] if args is None else args - self.objdump_tool = objdump_tool - super().__init__(cachekey) - - def compile(self, src:str) -> bytes: - # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here - with tempfile.NamedTemporaryFile(delete=True) as output_file: - subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib', - '-', '-o', str(output_file.name)], input=src.encode('utf-8')) - return pathlib.Path(output_file.name).read_bytes() - - def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool) - class ClangJITCompiler(Compiler): def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey) diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index 7cac17c6c1..8813bafa45 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -1,12 +1,11 @@ from __future__ import annotations from typing import Tuple, Any, List -import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys +import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess assert sys.platform != 'win32' -from tinygrad.device import BufferSpec, Compiled, Allocator +from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.ops import Ops, UOp -from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv -from tinygrad.runtime.ops_clang import ClangCompiler +from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump from tinygrad.renderer.cstyle import ClangRenderer from tinygrad.runtime.autogen import libc, qcom_dsp if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import @@ -91,10 +90,23 @@ class DSPAllocator(Allocator): def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes) def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset) -class DSPDevice(Compiled): - def __init__(self, device:str=""): - self.ion_fd = os.open('/dev/ion', os.O_RDONLY) +class ClangCompiler(Compiler): + def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'): + self.args = ['-march=native'] if args is None else args + self.objdump_tool = objdump_tool + super().__init__(cachekey) + def compile(self, src:str) -> bytes: + # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here + with tempfile.NamedTemporaryFile(delete=True) as output_file: + subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib', + '-', '-o', str(output_file.name)], input=src.encode('utf-8')) + return pathlib.Path(output_file.name).read_bytes() + + def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool) + +class DSPCompiler(ClangCompiler): + def __init__(self): # Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem. sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss'] sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections]) @@ -103,15 +115,19 @@ class DSPDevice(Compiled): self.link_ld.flush() compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"] - super().__init__(device, DSPAllocator(self), DSPRenderer(), - ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self)) + return super().__init__("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump') + +class DSPDevice(Compiled): + def __init__(self, device:str=""): + self.ion_fd = os.open('/dev/ion', os.O_RDONLY) + super().__init__(device, DSPAllocator(self), DSPRenderer(), DSPCompiler(), functools.partial(DSPProgram, self)) fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes())) self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True)) ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes) self.init_dsp() - RPCListner(self).start() + RPCListener(self).start() def open_lib(self, lib): self.binded_lib, self.binded_lib_off = lib, 0 @@ -149,7 +165,7 @@ class DSPDevice(Compiled): qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd) qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0)) -class RPCListner(threading.Thread): +class RPCListener(threading.Thread): def __init__(self, device:DSPDevice): super().__init__() self.device, self.daemon = device, True From a5753095dc176a2bfdd8393a6028cbcd150d3b34 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 3 Feb 2025 15:32:41 +0800 Subject: [PATCH 05/11] llvm cleanups [pr] (#8867) --- tinygrad/runtime/ops_llvm.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 0dff0e4de3..cd22d83dff 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -11,12 +11,15 @@ def expect(x, err, ret=None): if x: raise RuntimeError(llvm.string_cast(err.contents) if not isinstance(err, str) else err) return ret -HOST_ARCH = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()] -HOST_TRIPLE = {'AArch64': 'aarch64', 'X86': 'x86_64'}[HOST_ARCH] -REQUIRED_COMPONENTS = ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter'] - class LLVMCompiler(Compiler): - def __init__(self, target_machine, opt): + def __init__(self, host_arch:str, opt:bool): + for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{host_arch}{component}')() + triple = ({'AArch64': 'aarch64', 'X86': 'x86_64'}[host_arch]+'-none-unknown-elf').encode() + + target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt) + target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', b'+reserve-x18' if host_arch == 'arm64' else b'', + llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, llvm.LLVMCodeModelDefault) + self.pbo = llvm.LLVMCreatePassBuilderOptions() if opt: self.passes = b'default' @@ -48,14 +51,5 @@ class LLVMCompiler(Compiler): class LLVMDevice(Compiled): def __init__(self, device:str): - for component in REQUIRED_COMPONENTS: - getattr(llvm, f'LLVMInitialize{HOST_ARCH}{component}')() - - triple = f'{HOST_TRIPLE}-none-unknown-elf'.encode() - target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt) - features = b'+reserve-x18' if platform.machine() == 'arm64' else b'' - target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', features, llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, - llvm.LLVMCodeModelDefault) - - super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None), - LLVMCompiler(target_machine, getenv("LLVMOPT")), CPUProgram) + compiler = LLVMCompiler({'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()], bool(getenv("LLVMOPT"))) + super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None), compiler, CPUProgram) From b075aefc12c5deba4ff07ec5c58e7311fa30353c Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 3 Feb 2025 16:46:19 +0800 Subject: [PATCH 06/11] hotfix: revert llvm host_arch --- tinygrad/runtime/ops_llvm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index cd22d83dff..64583a5404 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -17,7 +17,7 @@ class LLVMCompiler(Compiler): triple = ({'AArch64': 'aarch64', 'X86': 'x86_64'}[host_arch]+'-none-unknown-elf').encode() target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt) - target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', b'+reserve-x18' if host_arch == 'arm64' else b'', + target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', b'+reserve-x18' if platform.machine() == 'arm64' else b'', llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, llvm.LLVMCodeModelDefault) self.pbo = llvm.LLVMCreatePassBuilderOptions() From b6c617272aba46bd055207faaa49099e85589ed7 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 3 Feb 2025 07:59:11 -0500 Subject: [PATCH 07/11] New schedule.py Order [pr] (#8874) --- tinygrad/engine/schedule.py | 406 ++++++++++++++++++------------------ 1 file changed, 199 insertions(+), 207 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e0aebefb35..98d527b97a 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -13,25 +13,70 @@ from tinygrad.device import Buffer # creation can recurse a lot sys.setrecursionlimit(10000) -# **** ScheduleItem return type +# **** schedule simplifier -@dataclass(frozen=True) -class ScheduleItem: - ast: UOp - bufs: tuple[Buffer, ...] - metadata: tuple[Metadata, ...] - @property - def outputs(self) -> tuple[Buffer, ...]: - """Read/write or write only buffers in the schedule.""" - return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs) - @property - def inputs(self) -> tuple[Buffer, ...]: - """Read only buffers in the schedule.""" - return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs) - @functools.cached_property - def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,) +def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: + if not all_int(x.shape): return None + # remove reduce on unmasked const + prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1]) + ret = x.const_arg + match reduce.arg[0]: + case Ops.ADD: ret *= prshape + case Ops.MUL: ret **= prshape + case Ops.MAX: pass # NOTE: Ops.MAX is passthrough + case _: return None + return reduce.const_like(ret) -# **** Schedule context and big graph +def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): + if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti) +def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp): + new_src = list(alu.src) + for i,s in enumerate(alu.src): + if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src + if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src)) + +sym = symbolic_simple+PatternMatcher([ + # UOp with size 0 is zero + (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ + and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), + # DETACH and CONTIGUOUS_BACKWARD are NOOPs here + (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), + # reduce of size 0 is the identity element + (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), + lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), + # reduce of const is collapsed (TODO: make this a generic rule for stride0) + (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop), + # COPY(CONST) creates a new CONST on the destination device + (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)), + # no COPY to same device, except clone (arg is True) + (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), + lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), + # remove cast to image when it's already a contiguous image + (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)), + lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), + # remove contiguous if we can just view the buffer + (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)), + lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None), + # contiguous/buffer is already contiguous + (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]), + # support for using a contiguous permuted view instead of the parent view if one exists + (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), + (UPat(GroupOp.ALU, name="alu"), replace_contiguous), + # remove CONST/BIND/BUFFER from SINK + (UPat(Ops.SINK, name="root"), + lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg) + if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), +]) + +remove_movement_ops = merge_views+PatternMatcher([ + # NOTE: movement ops are always applied to base + (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))), + # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW) + (UPat(Ops.VIEW, name="view"), + lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), +]) + +# **** UOp realization @dataclass(frozen=True) class ScheduleContext: @@ -71,140 +116,6 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret -# **** AST graph rewrite - -# ** movement ops - -def apply_swizzle(u:UOp) -> UOp: - with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left) - -def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp: - input_st = ShapeTracker.from_shape(unwrap(src.st).shape) - tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg) - prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):]) - strides = strides_for_shape(rshape) - nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides, - v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views] - # update input_st and axis - new_input_st = tmp + ShapeTracker(tuple(nv)) - new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg))) - return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) - -def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp: - if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}") - output_shape = swizzle_st.reduce(r.axis_arg) - return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape)) - -def elementwise_view_right(root:UOp) -> UOp|None: - if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None - assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}" - assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}" - # push the swizzle from src to root - output_swizzle = swizzles[0] - new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape) - ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src)) - return ret.view(ShapeTracker.from_shape(output_swizzle.shape)) - -def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: - assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" - assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" - return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)) - -# push VIEW to children -view_right = merge_views+PatternMatcher([ - # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) - (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), - lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), - # STORE is the last child, so we just merge the ShapeTrackers and store the base - (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)), - # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() - (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), - # REDUCE(src.view()) -> REDUCE(src).view() - (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right), - # ALU(src.view()) -> ALU(src).view() - (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right), - # double reduce op collapses to a single reduce op - (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), -]) - -# ** ScheduleItem context builder - -@dataclass(frozen=True) -class ScheduleItemContext: - var_vals: dict[Variable, int] - sts: set[ShapeTracker] = field(default_factory=set) - bufs: list[UOp] = field(default_factory=list) - -def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None: - if (st:=unwrap(x.st)) in ctx.sts: return None - st, var_vals = st.simplify().unbind() - ctx.var_vals.update(var_vals) - ctx.sts.add(st) - return st.to_uop() if st != x.st else None - -def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp: - ctx.bufs.append(x) - return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1) - -to_si = PatternMatcher([ - # BUFFER -> DEFINE_GLOBAL - (UPat(Ops.BUFFER, name="x"), _append_buf), - # simplify and unbind the final VIEWs - (UPat(Ops.VIEW, name="x"), _append_st_vars), - # don't need SINK on COPY or BUFFER_VIEW - (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))), - # don't need contiguous or assign anymore - (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), - (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x), - # don't need DEVICE anymore - (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())), - # PRELOAD becomes LOAD - (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)), - # once images are loaded they become the base dtype - (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), -]) - -def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp): - ctx[var.replace(src=())] = val.arg - return var -unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),]) - -def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem: - # unbind_vars + push views to edges - sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right) - # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL - ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals)) - # deal with ASSIGN - if len(ctx.assigns) != 0: - assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer] - for x in list(sink.toposort)[::-1]: - # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER - if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph") - # PRELOAD tells the toposort this kernel should run before ASSIGN - if x.op is Ops.PRELOAD: - assign_preloads[x.buf_uop] = None - # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD - if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous: - # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine - if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass - # if it has a single view and it's equal when you shrink a contig, it's fine - elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass - # otherwise, it's not fine - else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" - +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) - # capture process replay - if CAPTURE_PROCESS_REPLAY: - with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast)) - return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None))) - -PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {} -if CAPTURE_PROCESS_REPLAY: - @atexit.register - def save_process_replay() -> None: - for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) - -# **** UOp realization - class UPatScheduled(UPat): def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs}))) @@ -352,69 +263,150 @@ break_sched = PatternMatcher([ (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse), ]) -# **** schedule simplifier +# **** ScheduleItem creation -def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: - if not all_int(x.shape): return None - # remove reduce on unmasked const - prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1]) - ret = x.const_arg - match reduce.arg[0]: - case Ops.ADD: ret *= prshape - case Ops.MUL: ret **= prshape - case Ops.MAX: pass # NOTE: Ops.MAX is passthrough - case _: return None - return reduce.const_like(ret) +@dataclass(frozen=True) +class ScheduleItem: + ast: UOp + bufs: tuple[Buffer, ...] + metadata: tuple[Metadata, ...] + @property + def outputs(self) -> tuple[Buffer, ...]: + """Read/write or write only buffers in the schedule.""" + return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs) + @property + def inputs(self) -> tuple[Buffer, ...]: + """Read only buffers in the schedule.""" + return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs) + @functools.cached_property + def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,) -def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): - if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti) -def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp): - new_src = list(alu.src) - for i,s in enumerate(alu.src): - if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src - if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src)) +@dataclass(frozen=True) +class ScheduleItemContext: + var_vals: dict[Variable, int] + sts: set[ShapeTracker] = field(default_factory=set) + bufs: list[UOp] = field(default_factory=list) -sym = symbolic_simple+PatternMatcher([ - # UOp with size 0 is zero - (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ - and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), - # DETACH and CONTIGUOUS_BACKWARD are NOOPs here - (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), - # reduce of size 0 is the identity element - (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), - lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), - # reduce of const is collapsed (TODO: make this a generic rule for stride0) - (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop), - # COPY(CONST) creates a new CONST on the destination device - (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)), - # no COPY to same device, except clone (arg is True) - (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), - lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), - # remove cast to image when it's already a contiguous image - (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)), - lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), - # remove contiguous if we can just view the buffer - (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)), - lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None), - # contiguous/buffer is already contiguous - (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]), - # support for using a contiguous permuted view instead of the parent view if one exists - (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), - (UPat(GroupOp.ALU, name="alu"), replace_contiguous), - # remove CONST/BIND/BUFFER from SINK - (UPat(Ops.SINK, name="root"), - lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg) - if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), +def apply_swizzle(u:UOp) -> UOp: + with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left) + +def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp: + input_st = ShapeTracker.from_shape(unwrap(src.st).shape) + tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg) + prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):]) + strides = strides_for_shape(rshape) + nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides, + v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views] + # update input_st and axis + new_input_st = tmp + ShapeTracker(tuple(nv)) + new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg))) + return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) + +def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp: + if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}") + output_shape = swizzle_st.reduce(r.axis_arg) + return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape)) + +def elementwise_view_right(root:UOp) -> UOp|None: + if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None + assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}" + assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}" + # push the swizzle from src to root + output_swizzle = swizzles[0] + new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape) + ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src)) + return ret.view(ShapeTracker.from_shape(output_swizzle.shape)) + +def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: + assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" + assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" + return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)) + +# push VIEW to children +view_right = merge_views+PatternMatcher([ + # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), + lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), + # STORE is the last child, so we just merge the ShapeTrackers and store the base + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)), + # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() + (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), + # REDUCE(src.view()) -> REDUCE(src).view() + (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right), + # ALU(src.view()) -> ALU(src).view() + (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right), + # double reduce op collapses to a single reduce op + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) -remove_movement_ops = merge_views+PatternMatcher([ - # NOTE: movement ops are always applied to base - (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))), - # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW) - (UPat(Ops.VIEW, name="view"), - lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), +def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None: + if (st:=unwrap(x.st)) in ctx.sts: return None + st, var_vals = st.simplify().unbind() + ctx.var_vals.update(var_vals) + ctx.sts.add(st) + return st.to_uop() if st != x.st else None + +def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp: + ctx.bufs.append(x) + return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1) + +to_si = PatternMatcher([ + # BUFFER -> DEFINE_GLOBAL + (UPat(Ops.BUFFER, name="x"), _append_buf), + # simplify and unbind the final VIEWs + (UPat(Ops.VIEW, name="x"), _append_st_vars), + # don't need SINK on COPY or BUFFER_VIEW + (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))), + # don't need contiguous or assign anymore + (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), + (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x), + # don't need DEVICE anymore + (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())), + # PRELOAD becomes LOAD + (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)), + # once images are loaded they become the base dtype + (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), ]) +def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp): + ctx[var.replace(src=())] = val.arg + return var +unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),]) + +def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem: + # unbind_vars + push views to edges + sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right) + # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL + ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals)) + # deal with ASSIGN + if len(ctx.assigns) != 0: + assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer] + for x in list(sink.toposort)[::-1]: + # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER + if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph") + # PRELOAD tells the toposort this kernel should run before ASSIGN + if x.op is Ops.PRELOAD: + assign_preloads[x.buf_uop] = None + # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD + if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous: + # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine + if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass + # if it has a single view and it's equal when you shrink a contig, it's fine + elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass + # otherwise, it's not fine + else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" + +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) + # capture process replay + if CAPTURE_PROCESS_REPLAY: + with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast)) + return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None))) + +PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {} +if CAPTURE_PROCESS_REPLAY: + @atexit.register + def save_process_replay() -> None: + for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) + # **** schedule creation and toposort @track_rewrites(named=True) From 73c75d6ee119bfed6987abedde94e9ddf062f398 Mon Sep 17 00:00:00 2001 From: Ali Ladjevardi <71323580+KhanerX@users.noreply.github.com> Date: Mon, 3 Feb 2025 18:20:38 +0330 Subject: [PATCH 08/11] DEFINE_LOCAL variable names start from temp0, not temp1 (#8870) --- tinygrad/codegen/kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 7c6592b912..2fc83e7711 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -620,7 +620,7 @@ class Kernel: if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape)) st = store_st = ShapeTracker.from_shape(local_shape) - local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i + 1}") + local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}") if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle) local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i]) srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store)) @@ -648,7 +648,7 @@ class Kernel: (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)]) st_uop = ShapeTracker.from_shape(local_shape).to_uop() local_size = st_uop.arg.real_size() - local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)+1}") + local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}") local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret))) grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes)) if op is self.reduceops[-1]: return grouped_reduce From d1aa9f30bc98115beac61e23a47820bae6ee0cc0 Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Tue, 4 Feb 2025 01:15:07 +0800 Subject: [PATCH 09/11] copy onnx_ops into onnx (#8876) * just copy it over * make OnnxOps a global var * some small style stuff * rerun CI but also some small clean up * some comments --- extra/onnx.py | 651 ++++++++++++++++++++++++++++++++++++++++++++-- extra/onnx_ops.py | 606 ------------------------------------------ 2 files changed, 631 insertions(+), 626 deletions(-) delete mode 100644 extra/onnx_ops.py diff --git a/extra/onnx.py b/extra/onnx.py index fbf8e69904..81bb57199a 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -1,8 +1,8 @@ -from typing import Callable, Any, Sequence -import importlib, functools, dataclasses -from tinygrad.tensor import Tensor -from tinygrad.helpers import getenv, DEBUG, all_same -from tinygrad.dtype import DType, ConstType, dtypes +from typing import Any, Sequence, cast, Literal, Callable +import dataclasses, functools, io, math, types +from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr +from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple +from tinygrad.dtype import DType, ConstType, dtypes, ImageDType from tinygrad.device import is_dtype_supported # ***** protobuf parsing ****** @@ -111,11 +111,11 @@ limit = int(getenv("ONNXLIMIT", "-1")) class OnnxRunner: def __init__(self, model: ModelProto): # parse model protobuf - self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node) + self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node) self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad Tensor.training = True if self.is_training else False Tensor.no_grad = False if self.is_training else True - self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer} + self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}} self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values} self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output} self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute}) @@ -123,14 +123,7 @@ class OnnxRunner: self.opset_version = model.opset_import[0].version self.variable_dims: dict[str, int] = {} - # TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import - # TODO: clean up opset stuff after moving extra.onnx_ops here - self.onnx_ops_module = importlib.import_module('extra.onnx_ops') - self.onnx_ops = { - **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", - "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", - "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")}, - } + self.onnx_ops = onnx_ops def _parse_input(self, name: str, value: Any, spec: OnnxValue): if spec.is_optional and value is None: return None @@ -148,9 +141,8 @@ class OnnxRunner: return tensor def _dispatch_op(self, op, inps, opts): - if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts) - if hasattr(self.onnx_ops_module, op): - fxn = getattr(self.onnx_ops_module, op) + if op in self.onnx_ops: + fxn = self.onnx_ops[op] if isinstance(fxn, dict): for k in sorted(fxn.keys()): if k <= self.opset_version: @@ -165,7 +157,7 @@ class OnnxRunner: self.graph_values[name] = self._parse_input(name, inputs[name], input_spec) for node in self.graph_nodes: - inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)] + inps = [to_python_const(self.graph_values[name], node.op, i) for i,name in enumerate(node.inputs)] opts = node.opts # provide additional opts @@ -184,4 +176,623 @@ class OnnxRunner: Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad return {name:self.graph_values[name] for name in node.outputs} Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad - return {name:self.graph_values[name] for name in self.graph_outputs} \ No newline at end of file + return {name:self.graph_values[name] for name in self.graph_outputs} + +#################### +##### ONNX OPS ##### +#################### +def get_onnx_ops(): + # ***** helper functions ***** + def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None) + + # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) + def _onnx_pads_to_tiny_pads(pads): return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:]))))) + + AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] + # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) + def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): + if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] + return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] + + def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): + i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_)) + if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2) + o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)] + return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad)) + + def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype) + + def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0): + if axis < 0: axis += x.ndim + if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape) + if block_size == 0: + shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim)) + return scale.reshape(shape), zero_point.reshape(shape) + return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis) + + def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts): + adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)] + return op(*adjusted_inputs, **opts) + + def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in quantized int + out = _op_integer(op, inputs, zero_points, **opts) + assert dtypes.is_int(out.dtype), "quantized op should've done math in int" + out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + + def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in float32 + dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)] + out = op(*dequantized_inputs, **opts) + assert dtypes.is_float(out.dtype), "op should've done math in float" + out_quantized = (out / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + + def _onnx_training(input_group_size): + def __decorator(func): + def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): + R = R.detach() + groups = len(inputs) // input_group_size + ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] + return tuple(flatten(zip(*ret))) + return ___wrapper + return __decorator + + # ***** Property/Graph Ops ***** + def Identity(x:Tensor): return x + def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None, + value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None): + if value is not None: return value + if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) + if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) + if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False) + if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False) + if value_string is not None or value_strings is not None and sparse_value is not None: + raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value') + + def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta) + + def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): + try: import PIL.Image + except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e + img = PIL.Image.open(io.BytesIO(encoded_stream)) + if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1] + if pixel_format == "RGB": return Tensor(np.array(img)) + if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1) + raise ValueError(f"pixel_format={pixel_format!r} is not supported.") + + def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): + ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) + return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) + + def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) + def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) + def ConstantOfShape(shape:list[int], value:Tensor|None=None): + if value is None: value = Tensor(0, dtype=dtypes.float32) + return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1) + + def Size(data:Tensor): return data.numel() + def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) + + # ***** Unary Ops (math) ***** + def Not(x:Tensor): return x.logical_not() + def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): + return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype) + + # ***** Unary Ops (activation) ***** + def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) + def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) + Softmax = {1:Softmax_1, 13:Softmax_13} + def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1) + def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf()) + def FastGelu(x:Tensor, bias:Tensor|None=None): + # this is tanh approximated + return (x + bias).gelu() if bias is not None else x.gelu() + # TODO: fix this + def PRelu(X:Tensor, slope:Tensor): + slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope + return (X > 0).where(X, X * slope) + def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha) + def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0) + def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis) + def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() + + # ***** Unary Ops (broadcasted) ***** + def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype) + def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int + def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype) + def Less(x:Tensor,y:Tensor): return x < y + def LessOrEqual(x:Tensor,y:Tensor): return x <= y + def Greater(x:Tensor,y:Tensor): return x > y + def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y + def Equal(x:Tensor,y:Tensor): return x == y + def And(x:Tensor,y:Tensor): return (x==y).where(x, False) + def Or(x:Tensor,y:Tensor): return (x==y).where(x, True) + def BitwiseAnd(x:Tensor,y:Tensor): return x & y + def BitwiseOr(x:Tensor,y:Tensor): return x | y + def BitwiseXor(x:Tensor,y:Tensor): return x ^ y + def BitwiseNot(x:Tensor): return ~x + + # ***** Casting Ops ***** + # TODO: saturate + def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) + def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) + + # ***** Reduce Ops ***** + def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) + def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) + def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) + def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) + def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes) + def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes) + def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt() + def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log() + def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): + return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() + def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): + if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) + return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) + def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): + return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) + + # ***** Movement Ops ***** + def Reshape(data:Tensor, shape:list[int], allowzero:int=0): + return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)]) + def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) + def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape))) + def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias) + def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) + + # TODO: add test for when axes is None + def Squeeze(data:Tensor, axes:list[int]|None=None): + return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) + def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) + + def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) + def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) + def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None): + axes = axes or list(range(data.ndim)) + steps = steps or [1]*data.ndim + slices = [slice(0,x,1) for x in data.shape] + for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i]) + return data[tuple(slices)] + + def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0): + sz = data.shape[axis] + if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)] + return data.split(split, axis) + + def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, + mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0): + value = constant_value or value + axes = axes or list(range(x.ndim)) + real_pads = [0] * (x.ndim*2) + for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)] + return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value) + + def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): + shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim + pad_arg:list[None|tuple[int,int]] = [None] * t.ndim + for s, x in zip(shape, axes or range(t.ndim)): + tx = t.shape[x] + if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) + elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) + return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) + + # ***** Processing Ops ***** + def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0, + dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1): + return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), + ceil_mode=ceil_mode, count_include_pad=count_include_pad) + + def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, + storage_order:int=0, strides:list[int]|int=1): + ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode) + # tests expect indices with int64 dtype + # TODO: if there are repeated values, this is wrong + indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape) + return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices + + def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, + kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): + return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, + padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)) + + def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, + kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0, + strides:list[int]|int=1): + input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:]) + strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding)) + if output_shape is not None: # we pad according to output_shape + pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in + zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad) + if pads is None: # we generate pads + output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))] + pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))] + pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2 + pads = _onnx_pads_to_tiny_pads(pads) + return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding) + + def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]=None, pads:list[int]|int=0, strides:list[int]|int=1): + pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides)) + out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)] + ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh) + if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER") + return ret.pad(_onnx_pads_to_tiny_pads(pads)) + + def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) + def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) + + def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): + ret = alpha * (A.transpose(transA) @ B.transpose(transB)) + if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) + return ret + + def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) + + def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0): + axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis) + if reverse: X = X.flip(axis) + if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ + .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) + return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) + + def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k) + + def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0, + axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, + extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): + def _apply_nearest_mode(index: Tensor, input_dim, mode: str): + if mode == "round_prefer_floor": index = (index - 0.5).ceil() + elif mode == "round_prefer_ceil": index = (index + 0.5).floor() + elif mode in ["floor", "ceil"]: index = getattr(index, mode)() + else: raise ValueError(f"invalid {nearest_mode=}") + return index.cast(dtypes.int32).clip(0, input_dim-1) + def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, mode): + # TODO: needs more testing, not confident in this + # NOTE: their reference implementation differ from the implementation in their reference docs + # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py + # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize + output_dim = scale_dim * input_dim + if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5 + elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0]) + elif mode == "asymmetric": index = index / scale_dim + elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5]) + elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5 + elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1)) + else: raise ValueError(f"invalid {coordinate_transformation_mode=}") + return index.clip(0, input_dim-1) + + scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):]) + # we pre permute the axes and permute back after resize + axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]), + perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes) + X = X.permute(*perm) + + if sizes is not None: + if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]: + scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max + scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2 + sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)] + else: + scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)] + else: + sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)] + regions = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) and roi else [[0.0, 0.0]] * (X.ndim-2) + + # NOTE: this transformation makes it so that we can't just call Tensor.interpolate + # in Tensor.interpolate, we use indexes without any transformation + indexes = [] + for shape, size, scale, region in zip(input_shape, sizes, scales, regions): + indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, coordinate_transformation_mode)) + + if mode == "nearest": + indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)] + X = X[(..., *Tensor.meshgrid(*indexes))] + if mode == "linear": + expand = list(X.shape) + for i in range(-len(sizes), 0): + reshape, index = [1] * X.ndim, indexes[i] + reshape[i] = expand[i] = sizes[i] + low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())] + X = X.gather(i, low).lerp(X.gather(i, high), perc) + if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented") + return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X + def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated + + # ***** Neural Network Ops ***** + # TODO: try to factor out common implementations for these ops + # https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0 + def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, + training_mode:int=0, spatial=1, is_test=0): + if training_mode: + x_detached = X.detach() + current_mean = x_detached.mean(axis=(0,2,3)) + y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) + current_var = (y*y).mean(axis=(0,2,3)) + current_invstd = current_var.add(epsilon).rsqrt() + + running_mean = input_mean * momentum + current_mean * (1 - momentum) + running_var = input_var * momentum + current_var * (1 - momentum) + + return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var + invstd = (input_var + epsilon).rsqrt() + return X.batchnorm(scale, B, input_mean, invstd) + def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): + axis = tuple(range(2, x.ndim)) + mean = x.mean(axis=axis, keepdim=True) + invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt() + return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) + def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): + assert stash_type == 1, "only float32 is supported" + axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) + mean = x.mean(axis=axes, keepdim=True) + return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() + def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): + return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) + def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): + return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) + def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12): + x = x + skip + bias + return x.layernorm(eps=epsilon) * gamma + beta, None, None, x + def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, + segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None, + position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0): + # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization + assert (segment_ids is None) is (segment_embedding is None) + assert mask is None and not mask_index_type, "functionality not supported yet" # TODO + input_shape = input_ids.shape + seq_length = input_shape[1] + compute_seg_emb = (segment_embedding is not None and segment_ids is not None) + vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0] + type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None) + + def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor: + return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight + + # bert embedding layer + if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape) + wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding) + pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding) + seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None + + embedding_sum = wrd_embedding_res + pos_embedding_res + if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res + out = embedding_sum.layernorm(eps=epsilon) * gamma + beta + return out, None, embedding_sum + + def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): + # Scalar or Rank 1 tensor containing exactly one element + depth = int(depth[0] if isinstance(depth, list) else depth) + indices = (indices < 0).where(indices+depth, indices) + return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) + + def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): + return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) + def SpaceToDepth(X:Tensor, blocksize:int): + return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) + + # Reimplemented here because you need legacy RNG for passing ONNX tests. + def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): + if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. + mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device) + return data * mask * (1/(1.0 - ratio)), mask + # 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx + def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test) + Dropout = {6:Dropout_6, 7:Dropout_7} + + def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): + pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) + return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) + + def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + return x.nll_loss(target, weight, ignore_index, reduction) + def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + log_probs = scores.log_softmax(1) + return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs + + def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): + N, _, *spatial_dims = size + def generate_grid(steps): + return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) + grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) + base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) + base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) + return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) + + def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None, + relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None, + mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None, + qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None): + # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention + assert num_heads is not None # required + assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None) + assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \ + "functionality not supported yet" # TODO strange params + hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,) + + if unidirectional: # gpt-style + assert hidden_size == v_hidden_size + xqkv = x.linear(weights, bias) + xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)] + else: # bert-style + wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:] + bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None + xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))] + xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)] + + if past is not None: + xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2) + present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0)) + + def attn(query, key, value, attn_mask): + query_length, key_length = query.shape[-2], key.shape[-2] + cdim = max(query_length, key_length) + 1 + attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1]) + # This is where Tensor.scaled_dot_product_attention differs: + causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length] + masked = Tensor.where(causal_mask, attn_weights, -math.inf) + if attn_mask is not None: masked = masked + attn_mask + return masked.softmax(-1) @ value + + bsz, _, seq_len, _ = xq.shape + out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1) + return out, present if past is not None else out + + # ***** Indexing Ops ***** + def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] + + def Gather(x:Tensor, indices:Tensor, axis:int=0): + if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices + x_sh = list(x.shape) + ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] + if indices.ndim > 1: indices = indices.flatten() + indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] + args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore + return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) + # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot + return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] + def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated + + def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): + if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] + x_shape, i_shape = x.shape, indices.shape + b = math.prod(x.shape[dim] for dim in range(batch_dims)) + # NOTE: each batched dim of both input and indices are equal + x = x.reshape(b, *x.shape[batch_dims:]) + indices = indices.reshape(b, *indices.shape[batch_dims:]) + b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) + ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] + return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) + def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): + assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] + x = x.contiguous() + for index, u in zip(indices.split(1, 0), updates.split(1, 0)): + i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) + u = u.squeeze(0) + if reduction == "none": x[i] = u + elif reduction == "add": x[i] += u + elif reduction == "mul": x[i] *= u + else: raise NotImplementedError("reduction doesn't support max or min") + return x + + def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction)) + def GatherElements(x:Tensor, indices:Tensor, axis:int): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.gather(axis, indices) + + def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): + if axis is None: + inp = inp.flatten() + axis = 0 + if axis < 0: axis += inp.ndim + con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor + return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] + + # ***** Quantization Ops ***** + def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): + out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 + y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) + return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous() + + def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): + x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) + return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) + + def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point: Tensor|int, B:Tensor|None=None, **opts): + return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts}) + + def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point:Tensor|int) -> Tensor: + return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point) + + def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): + return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) + + def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int): + assert channels_last == 0, "unsure what this does" + return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point) + + def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor: + return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts}) + + def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: + return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point]) + + # ***** Training Ops ***** + # NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested + # NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code + @_onnx_training(3) + def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0): + X, G, H = (i.detach() for i in inputs) + grad = norm_coefficient * X + G + H.assign(H + grad.square()) + up = grad / (H.sqrt() + epsilon) + r = R / (1 + T * decay_factor) + X.assign(X.detach() - r * up) + return [X, H] + + @_onnx_training(4) + def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0, + norm_coefficient_post:float=0.0): + from tinygrad.nn.optim import Adam as TinyAdam + X, G, V, H = inputs + G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches + X.grad = norm_coefficient * X.detach() + G + opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon) + opt.m, opt.v, opt.lr = [V], [H], R + # need no-op for m_hat and v_hat if T == 0 + if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like() + else: + # `T-1` since it's applied again at the start of `_step` + opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False) + opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False) + opt.step() + X = (1 - norm_coefficient_post) * X + return [X, V, H] + + @_onnx_training(3) + def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float): + from tinygrad.nn.optim import SGD + X, G, V = inputs + G, V = G.detach(), V.detach() + X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1) + opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov")) + opt.b, opt.lr = [V], R + opt.step() + return [X, V] + + def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_): + intermediate_tensors[y].backward() + return tuple([t.grad for t in inputs]) + + return { + # Tensor ops + **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", + "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", + "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")}, + # Implemented ops + **{name:obj for name,obj in locals().items() if isinstance(obj, types.FunctionType) and not name.startswith("_") and name[0].isupper()}, + # Version ops + **{name:obj for name,obj in locals().items() if isinstance(obj, dict)}, + } + +onnx_ops = get_onnx_ops() diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py deleted file mode 100644 index 165de3154b..0000000000 --- a/extra/onnx_ops.py +++ /dev/null @@ -1,606 +0,0 @@ -import functools, io, math -from typing import cast, Literal -from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr -from tinygrad.dtype import ImageDType, dtypes, DType -from tinygrad.helpers import prod, flatten, make_tuple -from extra.onnx import dtype_parse, _cached_to_python_const -import numpy as np - -# ***** Property/Graph Ops ***** -def Identity(x:Tensor): return x -def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None, - value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None): - if value is not None: return value - if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) - if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) - if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False) - if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False) - if value_string is not None or value_strings is not None and sparse_value is not None: - raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value') - -def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta) - -def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): - try: import PIL.Image - except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e - img = PIL.Image.open(io.BytesIO(encoded_stream)) - if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1] - if pixel_format == "RGB": return Tensor(np.array(img)) - if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1) - raise ValueError(f"pixel_format={pixel_format!r} is not supported.") - -def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): - ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) - return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) - -def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) -def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) -def ConstantOfShape(shape:list[int], value:Tensor|None=None): - if value is None: value = Tensor(0, dtype=dtypes.float32) - return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1) - -def Size(data:Tensor): return data.numel() -def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) - -# ***** Unary Ops (math) ***** -def Not(x:Tensor): return x.logical_not() -def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): - return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype) - -# ***** Unary Ops (activation) ***** -def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) -def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) -Softmax = {1:Softmax_1, 13:Softmax_13} -def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1) -def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf()) -def FastGelu(x:Tensor, bias:Tensor|None=None): - # this is tanh approximated - return (x + bias).gelu() if bias is not None else x.gelu() -# TODO: fix this -def PRelu(X:Tensor, slope:Tensor): - slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope - return (X > 0).where(X, X * slope) -def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha) -def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0) -def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis) -def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() - -# ***** Unary Ops (broadcasted) ***** -def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype) -def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int -def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype) -def Less(x:Tensor,y:Tensor): return x < y -def LessOrEqual(x:Tensor,y:Tensor): return x <= y -def Greater(x:Tensor,y:Tensor): return x > y -def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y -def Equal(x:Tensor,y:Tensor): return x == y -def And(x:Tensor,y:Tensor): return (x==y).where(x, False) -def Or(x:Tensor,y:Tensor): return (x==y).where(x, True) -def BitwiseAnd(x:Tensor,y:Tensor): return x & y -def BitwiseOr(x:Tensor,y:Tensor): return x | y -def BitwiseXor(x:Tensor,y:Tensor): return x ^ y -def BitwiseNot(x:Tensor): return ~x - -# ***** Casting Ops ***** -# TODO: saturate -def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) -def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) - -# ***** Reduce Ops ***** -def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) -def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) -def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) -def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) -def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None) -def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes) -def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes) -def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt() -def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log() -def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): - return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() -def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): - if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) - return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) -def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): - return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) - -# ***** Movement Ops ***** -def Reshape(data:Tensor, shape:list[int], allowzero:int=0): - return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)]) -def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) -def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape))) -def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias) -def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) - -# TODO: add test for when axes is None -def Squeeze(data:Tensor, axes:list[int]|None=None): - return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) -def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) - -def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) -def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) -def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None): - axes = axes or list(range(data.ndim)) - steps = steps or [1]*data.ndim - slices = [slice(0,x,1) for x in data.shape] - for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i]) - return data[tuple(slices)] - -def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0): - sz = data.shape[axis] - if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)] - return data.split(split, axis) - -def _onnx_pads_to_tiny_pads(pads): - # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) - return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:]))))) -def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, - mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0): - value = constant_value or value - axes = axes or list(range(x.ndim)) - real_pads = [0] * (x.ndim*2) - for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)] - return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value) - -def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): - shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim - pad_arg:list[None|tuple[int,int]] = [None] * t.ndim - for s, x in zip(shape, axes or range(t.ndim)): - tx = t.shape[x] - if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) - elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) - return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) - -# ***** Processing Ops ***** -AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] -def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): - # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) - if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] - return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] -def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): - i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_)) - if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2) - o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)] - return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad)) - -def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0, - dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1): - return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), - ceil_mode=ceil_mode, count_include_pad=count_include_pad) - -def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, - storage_order:int=0, strides:list[int]|int=1): - ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode) - # tests expect indices with int64 dtype - # TODO: if there are repeated values, this is wrong - indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape) - return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices - -def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, - kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): - return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, - padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)) - -def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, - kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0, - strides:list[int]|int=1): - input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:]) - strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding)) - if output_shape is not None: # we pad according to output_shape - pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in - zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad) - if pads is None: # we generate pads - output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))] - pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))] - pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2 - pads = _onnx_pads_to_tiny_pads(pads) - return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding) - -def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]=None, pads:list[int]|int=0, strides:list[int]|int=1): - pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides)) - out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)] - ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh) - if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER") - return ret.pad(_onnx_pads_to_tiny_pads(pads)) - -def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) -def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) - -def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): - ret = alpha * (A.transpose(transA) @ B.transpose(transB)) - if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) - return ret - -def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) - -def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0): - axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis) - if reverse: X = X.flip(axis) - if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ - .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) - return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) - -def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k) - -def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0, - axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, - extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): - def _apply_nearest_mode(index: Tensor, input_dim, mode: str): - if mode == "round_prefer_floor": index = (index - 0.5).ceil() - elif mode == "round_prefer_ceil": index = (index + 0.5).floor() - elif mode in ["floor", "ceil"]: index = getattr(index, mode)() - else: raise ValueError(f"invalid {nearest_mode=}") - return index.cast(dtypes.int32).clip(0, input_dim-1) - def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, mode): - # TODO: needs more testing, not confident in this - # NOTE: their reference implementation differ from the implementation in their reference docs - # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py - # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize - output_dim = scale_dim * input_dim - if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5 - elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0]) - elif mode == "asymmetric": index = index / scale_dim - elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5]) - elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5 - elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1)) - else: raise ValueError(f"invalid {coordinate_transformation_mode=}") - return index.clip(0, input_dim-1) - - scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):]) - # we pre permute the axes and permute back after resize - axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]), - perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes) - X = X.permute(*perm) - - if sizes is not None: - if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]: - scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max - scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2 - sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)] - else: - scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)] - else: - sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)] - regions = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) and roi else [[0.0, 0.0]] * (X.ndim-2) - - # NOTE: this transformation makes it so that we can't just call Tensor.interpolate - # in Tensor.interpolate, we use indexes without any transformation - indexes = [] - for shape, size, scale, region in zip(input_shape, sizes, scales, regions): - indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, coordinate_transformation_mode)) - - if mode == "nearest": - indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)] - X = X[(..., *Tensor.meshgrid(*indexes))] - if mode == "linear": - expand = list(X.shape) - for i in range(-len(sizes), 0): - reshape, index = [1] * X.ndim, indexes[i] - reshape[i] = expand[i] = sizes[i] - low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())] - X = X.gather(i, low).lerp(X.gather(i, high), perc) - if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented") - return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X -def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated - -# ***** Neural Network Ops ***** -# TODO: try to factor out common implementations for these ops -# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0 -def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, - training_mode:int=0, spatial=1, is_test=0): - if training_mode: - x_detached = X.detach() - current_mean = x_detached.mean(axis=(0,2,3)) - y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) - current_var = (y*y).mean(axis=(0,2,3)) - current_invstd = current_var.add(epsilon).rsqrt() - - running_mean = input_mean * momentum + current_mean * (1 - momentum) - running_var = input_var * momentum + current_var * (1 - momentum) - - return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var - invstd = (input_var + epsilon).rsqrt() - return X.batchnorm(scale, B, input_mean, invstd) -def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): - axis = tuple(range(2, x.ndim)) - mean = x.mean(axis=axis, keepdim=True) - invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt() - return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) -def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): - assert stash_type == 1, "only float32 is supported" - axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) - mean = x.mean(axis=axes, keepdim=True) - return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() -def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): - return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) -def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): - return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) -def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12): - x = x + skip + bias - return x.layernorm(eps=epsilon) * gamma + beta, None, None, x -def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, - segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None, - position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0): - # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization - assert (segment_ids is None) is (segment_embedding is None) - assert mask is None and not mask_index_type, "functionality not supported yet" # TODO - input_shape = input_ids.shape - seq_length = input_shape[1] - compute_seg_emb = (segment_embedding is not None and segment_ids is not None) - vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0] - type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None) - - def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor: - return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight - - # bert embedding layer - if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape) - wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding) - pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding) - seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None - - embedding_sum = wrd_embedding_res + pos_embedding_res - if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res - out = embedding_sum.layernorm(eps=epsilon) * gamma + beta - return out, None, embedding_sum - -def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): - # Scalar or Rank 1 tensor containing exactly one element - depth = int(depth[0] if isinstance(depth, list) else depth) - indices = (indices < 0).where(indices+depth, indices) - return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) - -def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): - return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) -def SpaceToDepth(X:Tensor, blocksize:int): - return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) - -# Reimplemented here because you need legacy RNG for passing ONNX tests. -def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): - if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. - mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device) - return data * mask * (1/(1.0 - ratio)), mask -# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx -def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test) -Dropout = {6:Dropout_6, 7:Dropout_7} - -def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): - pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) - return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) - -def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - return x.nll_loss(target, weight, ignore_index, reduction) -def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - log_probs = scores.log_softmax(1) - return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs - -def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): - N, _, *spatial_dims = size - def generate_grid(steps): - return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) - grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) - base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) - base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) - return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) - -def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None, - relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None, - mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None, - qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None): - # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention - assert num_heads is not None # required - assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None) - assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \ - "functionality not supported yet" # TODO strange params - hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,) - - if unidirectional: # gpt-style - assert hidden_size == v_hidden_size - xqkv = x.linear(weights, bias) - xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)] - else: # bert-style - wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:] - bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None - xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))] - xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)] - - if past is not None: - xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2) - present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0)) - - def attn(query, key, value, attn_mask): - query_length, key_length = query.shape[-2], key.shape[-2] - cdim = max(query_length, key_length) + 1 - attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1]) - # This is where Tensor.scaled_dot_product_attention differs: - causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length] - masked = Tensor.where(causal_mask, attn_weights, -math.inf) - if attn_mask is not None: masked = masked + attn_mask - return masked.softmax(-1) @ value - - bsz, _, seq_len, _ = xq.shape - out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1) - return out, present if past is not None else out - -# ***** Indexing Ops ***** -def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] - -def Gather(x:Tensor, indices:Tensor, axis:int=0): - if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices - x_sh = list(x.shape) - ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] - if indices.ndim > 1: indices = indices.flatten() - indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] - args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore - return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) - # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot - return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] -def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated - -def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): - if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] - x_shape, i_shape = x.shape, indices.shape - b = math.prod(x.shape[dim] for dim in range(batch_dims)) - # NOTE: each batched dim of both input and indices are equal - x = x.reshape(b, *x.shape[batch_dims:]) - indices = indices.reshape(b, *indices.shape[batch_dims:]) - b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) - ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] - return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) -def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): - assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] - x = x.contiguous() - for index, u in zip(indices.split(1, 0), updates.split(1, 0)): - i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) - u = u.squeeze(0) - if reduction == "none": x[i] = u - elif reduction == "add": x[i] += u - elif reduction == "mul": x[i] *= u - else: raise NotImplementedError("reduction doesn't support max or min") - return x - -def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction)) -def GatherElements(x:Tensor, indices:Tensor, axis:int): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.gather(axis, indices) - -def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): - if axis is None: - inp = inp.flatten() - axis = 0 - if axis < 0: axis += inp.ndim - con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor - return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] - -# ***** Quantization Ops ***** -def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype) - -def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0): - if axis < 0: axis += x.ndim - if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape) - if block_size == 0: - shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim)) - return scale.reshape(shape), zero_point.reshape(shape) - return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis) - -def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): - out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 - y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) - return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous() - -def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): - x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) - return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) - -def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts): - adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)] - return op(*adjusted_inputs, **opts) - -def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): - # op execution is done in quantized int - out = _op_integer(op, inputs, zero_points, **opts) - assert dtypes.is_int(out.dtype), "quantized op should've done math in int" - out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point - return _clamp_cast(out_quantized, out_zero_point.dtype) - -def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): - # op execution is done in float32 - dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)] - out = op(*dequantized_inputs, **opts) - assert dtypes.is_float(out.dtype), "op should've done math in float" - out_quantized = (out / out_scale).round() + out_zero_point - return _clamp_cast(out_quantized, out_zero_point.dtype) - -def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, - y_zero_point: Tensor|int, B:Tensor|None=None, **opts): - return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts}) - -def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, - y_zero_point:Tensor|int) -> Tensor: - return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point) - -def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): - return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) - -def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int): - assert channels_last == 0, "unsure what this does" - return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point) - -def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor: - return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts}) - -def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: - return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point]) - -# ***** Training Ops ***** -# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested -# NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code -def _onnx_training(input_group_size): - def __decorator(func): - def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): - R = R.detach() - groups = len(inputs) // input_group_size - ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] - return tuple(flatten(zip(*ret))) - return ___wrapper - return __decorator - -@_onnx_training(3) -def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0): - X, G, H = (i.detach() for i in inputs) - grad = norm_coefficient * X + G - H.assign(H + grad.square()) - up = grad / (H.sqrt() + epsilon) - r = R / (1 + T * decay_factor) - X.assign(X.detach() - r * up) - return [X, H] - -@_onnx_training(4) -def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0, - norm_coefficient_post:float=0.0): - from tinygrad.nn.optim import Adam as TinyAdam - X, G, V, H = inputs - G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches - X.grad = norm_coefficient * X.detach() + G - opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon) - opt.m, opt.v, opt.lr = [V], [H], R - # need no-op for m_hat and v_hat if T == 0 - if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like() - else: - # `T-1` since it's applied again at the start of `_step` - opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False) - opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False) - opt.step() - X = (1 - norm_coefficient_post) * X - return [X, V, H] - -@_onnx_training(3) -def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float): - from tinygrad.nn.optim import SGD - X, G, V = inputs - G, V = G.detach(), V.detach() - X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1) - opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov")) - opt.b, opt.lr = [V], R - opt.step() - return [X, V] - -def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_): - intermediate_tensors[y].backward() - return tuple([t.grad for t in inputs]) From cce26009f0b324d2e17baefc4afbb730b2df78e2 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 3 Feb 2025 12:54:18 -0500 Subject: [PATCH 10/11] simplify pow to not call cos (#8877) use %2 instead of cos to detect even numbers --- test/test_const_folding.py | 5 ++--- tinygrad/tensor.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index dfffca8989..b78faee145 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -94,10 +94,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase): _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4)) def test_literal_one_pow(self): _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4])) - # this fails because of DETACH, it shouldn't - # update: passes after CONST(VIEW(DEVICE)) in tensor + # TODO: pow simplification def test_tensor_one_pow(self): - _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4])) + _check_ast_count(1, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4])) # folds advance indexing into basic indexing class TestIndexingConstFolding(unittest.TestCase): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8d048eb431..de142aa06f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3314,10 +3314,10 @@ class Tensor(SimpleMathTrait): if not base.is_floating_point(): raise RuntimeError("base needs to be float") # start with b ** e = exp(e * log(b)) ret = base.abs().log().mul(exponent).exp() - # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent) + # correct sign of negative base with odd exponent negative_base = (base < 0).detach().where(1, 0) # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent - correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1) + correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base) # inject nan for negative base and non-integer exponent inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1) # apply correct_sign inject_nan, and fix 0 ** 0 = 1 From ec447a31e7609f17ba160d789336544d288e725b Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 3 Feb 2025 14:39:08 -0500 Subject: [PATCH 11/11] factor out get_axis in multi [pr] (#8878) ALU/REDUCE_AXIS/RESHAPE/PERMUTE can change axis. prereq to move this logic to ops.py --- tinygrad/engine/multi.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py index 4803127690..17a3259b04 100644 --- a/tinygrad/engine/multi.py +++ b/tinygrad/engine/multi.py @@ -43,13 +43,28 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites +def get_axis(root:UOp): + if root.op is Ops.MULTI: return root.arg[0] + # NOTE: they all have to share an axis, we always choose [-1] + if root.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in root.src if x.axis is not None])) else None + src_axis = get_axis(root.src[0]) + if root.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in root.arg[1] else src_axis + if root.op is Ops.RESHAPE: + if src_axis is None: return None + arg_acc:list[sint] = list(itertools.accumulate(root.arg, operator.mul, initial=1)) + # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards + # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? + return len(arg_acc) - arg_acc[::-1].index(prod(root.src[0].shape[:src_axis])) - 1 + if root.op is Ops.PERMUTE: return root.arg.index(src_axis) if src_axis is not None else None + raise NotImplementedError("rest should be passthrough") + def alu_multi(root:UOp): msrcs = root.src assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}" assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}" - # NOTE: they all have to share an axis, we always choose [-1] - axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None) + axis = get_axis(root) + bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None srcs:list[list[UOp]] = [] not_all_real = not all(all(mlb.real) for mlb in msrcs) new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real @@ -64,28 +79,24 @@ def alu_multi(root:UOp): return UOp.multi(*new_lbs, axis=axis, real=new_real) def reduce_multi(root:UOp, multi:UOp): - op, axis = root.arg + (op, axis), new_axis = root.arg, get_axis(root) if multi.axis is not None and multi.axis in axis: # all-reduce on sharded axes reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)] # if all partitions are real, do all_reduce - if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None) + if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=new_axis) # only one partition is real, keep it - return UOp.multi(*reduced_parts, axis=None, real=multi.real) + return UOp.multi(*reduced_parts, axis=new_axis, real=multi.real) # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct - return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real) + return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=new_axis, real=multi.real) def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]: return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape)) def reshape_multi(root:UOp, multi:UOp): - arg = root.arg - if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real) + arg, new_axis = root.arg, get_axis(root) + if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real) assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)" - arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1)) - # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards - # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? - new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1 assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \ f"reshape cannot move items between shards {multi.shape} -> {root.arg=}" lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src] @@ -109,7 +120,7 @@ def pad_multi(root:UOp, multi:UOp): def permute_multi(root:UOp, multi:UOp): # all permutes supported! - return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real) + return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=get_axis(root), real=multi.real) def shrink_multi(root:UOp, multi:UOp): assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \