From e84d089ef1e1b3ba6e031c6982c0bc22dba1afd4 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Wed, 13 Nov 2024 13:04:27 +0200
Subject: [PATCH] delete ReduceOps, only use REDUCE_AXIS (#7667)

---
 extra/optimization/helpers.py                    |  2 +-
 .../external_benchmark_multitensor_allreduce.py  |  8 ++++----
 test/test_multitensor.py                         | 14 +++++++-------
 test/test_uops.py                                |  2 +-
 test/unit/test_pattern_matcher.py                |  2 +-
 test/unit/test_verify_ast.py                     | 10 +++++-----
 tinygrad/engine/fuse.py                          | 12 ++++++------
 tinygrad/engine/lazy.py                          | 12 ++++++------
 tinygrad/engine/schedule.py                      |  5 ++---
 tinygrad/function.py                             | 12 ++++++------
 tinygrad/multi.py                                |  5 ++---
 tinygrad/ops.py                                  | 12 +++---------
 12 files changed, 44 insertions(+), 52 deletions(-)

diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py
index c433359d6c..342c4c27b6 100644
--- a/extra/optimization/helpers.py
+++ b/extra/optimization/helpers.py
@@ -2,7 +2,7 @@
 from typing import Tuple
 from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
+from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, MetaOps
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
diff --git a/test/external/external_benchmark_multitensor_allreduce.py b/test/external/external_benchmark_multitensor_allreduce.py
index b20da756a3..c139a0fc3b 100644
--- a/test/external/external_benchmark_multitensor_allreduce.py
+++ b/test/external/external_benchmark_multitensor_allreduce.py
@@ -1,7 +1,7 @@
 import time
 from tinygrad import Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.engine.lazy import LazyBuffer
-from tinygrad.ops import ReduceOps
+from tinygrad.ops import Ops
 from tinygrad.multi import MultiLazyBuffer, all_reduce
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
@@ -14,7 +14,7 @@ def realize(x: Union[LazyBuffer, List[LazyBuffer]]):
   for lb in x: Device[lb.device].synchronize()
 
 def test(devs: List[str], N: int, iters:int = 10):
-  def _wrapped(op: ReduceOps, t: Tensor) -> Tensor:
+  def _wrapped(op: Ops, t: Tensor) -> Tensor:
     return Tensor(MultiLazyBuffer(all_reduce(op, t.lazydata.lbs), 0), device=devs)
   _jitted = TinyJit(_wrapped) if getenv("USEJIT", 1) == 1 else _wrapped
 
@@ -24,7 +24,7 @@ def test(devs: List[str], N: int, iters:int = 10):
     realize(lbs)
     GlobalCounters.reset()
     start = time.time()
-    realize(_jitted(ReduceOps.SUM, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
+    realize(_jitted(Ops.ADD, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
     if i < 0: continue # warm up jit
     i_secs = time.time() - start
 
@@ -68,4 +68,4 @@ def main():
   print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
 
 if __name__ == "__main__":
-  main()
\ No newline at end of file
+  main()
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index c7ab8f9de5..9e442188e5 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -1,7 +1,7 @@
 import unittest, functools, random
 from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
-from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, Ops
+from tinygrad.ops import MetaOps, BinaryOps, Ops
 from tinygrad.helpers import CI, getenv, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule
@@ -31,7 +31,7 @@ N = 128
 def _test_allreduce(t:Tensor):
   aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
   ts = t.shard(devices_4, 0).realize()
-  b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0))
+  b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0))
   b.realize()
   return aa, b
 
@@ -145,14 +145,14 @@ class TestMultiTensor(unittest.TestCase):
     np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
 
   @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
-         strat.sampled_from((ReduceOps.SUM, ReduceOps.PROD, ReduceOps.REDUCE_MAX)),
+         strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
          strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
   def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
     X = Tensor.rand(N*N).reshape(N, N).mul(sign)
     n = X.numpy()
     X.shard_(devices, shard_axis)
-    f = {ReduceOps.SUM: lambda x: x.sum(reduce_axis), ReduceOps.PROD: lambda x: x.prod(reduce_axis),
-         ReduceOps.REDUCE_MAX: lambda x: x.max(reduce_axis)}[rop]
+    f = {Ops.ADD: lambda x: x.sum(reduce_axis), Ops.MUL: lambda x: x.prod(reduce_axis),
+         Ops.MAX: lambda x: x.max(reduce_axis)}[rop]
     fX = f(X)
     fn = f(n)
     np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
@@ -197,9 +197,9 @@ class TestMultiTensor(unittest.TestCase):
         shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
         t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
         with Context(RING=0):
-          a = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
+          a = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
         with Context(RING=2):
-          b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
+          b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
         diff = a - b
         mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
         max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
diff --git a/test/test_uops.py b/test/test_uops.py
index b213eb55af..13dbdb7d2b 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.helpers import CI, DEBUG, getenv, Context
 from tinygrad.dtype import dtypes, DType
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
+from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, KernelInfo, exec_alu, spec # noqa F401
 from tinygrad.renderer import Program
 from tinygrad.engine.schedule import create_schedule, to_si
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py
index 45b2243de3..0aa46137a8 100644
--- a/test/unit/test_pattern_matcher.py
+++ b/test/unit/test_pattern_matcher.py
@@ -1,6 +1,6 @@
 import unittest, itertools
 from tinygrad.dtype import dtypes
-from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps, GroupOp # noqa: F401
+from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, UnaryOps, GroupOp # noqa: F401
 from tinygrad.ops import PatternMatcher, UPat
 
 class TestPatternMatcher(unittest.TestCase):
diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py
index 361056aec0..b114c0a773 100644
--- a/test/unit/test_verify_ast.py
+++ b/test/unit/test_verify_ast.py
@@ -4,7 +4,7 @@ import unittest
 from tinygrad import Tensor
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.helpers import DEBUG
-from tinygrad.ops import UOp, Ops, ReduceOps, print_uops
+from tinygrad.ops import UOp, Ops, print_uops
 from tinygrad.codegen.kernel import verify_ast
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad import dtypes
@@ -48,7 +48,7 @@ class TestVerifyAST(unittest.TestCase):
   def test_no_implicit_broadcasting(self):
     bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
-    b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.REDUCE_MAX, (1,)))
+    b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.MAX, (1,)))
     st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
     with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
 
@@ -62,14 +62,14 @@ class TestVerifyAST(unittest.TestCase):
   def test_reduce_store(self):
     bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
     with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
 
   def test_reduce_add_store(self):
     bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
     with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
 
@@ -84,7 +84,7 @@ class TestVerifyAST(unittest.TestCase):
   def test_assert_swizzle(self):
     buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
     a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
     st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
     with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
 
diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py
index 91c3ca078a..04dd366407 100644
--- a/tinygrad/engine/fuse.py
+++ b/tinygrad/engine/fuse.py
@@ -1,7 +1,7 @@
 from collections import defaultdict, deque
 from typing import Set, Tuple, List, Dict, DefaultDict
 from tinygrad.device import Buffer
-from tinygrad.ops import GroupOp, UOp, Ops
+from tinygrad.ops import UOp, Ops
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.engine.lazy import LazyBuffer
@@ -19,7 +19,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
     return group.setdefault(tr)
   for tr_next in children[tr]:
     # max one reduceop per kernel
-    if tr_next.op in GroupOp.Reduce: return group.setdefault(r)
+    if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
     # can only fuse contiguous
    if len(st_childs:=dedup(s.st for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
     _recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
@@ -31,7 +31,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], ch
     if (p:=rc_parents.pop()) in cache: continue
     cache.add(p)
     # max one reduceop per kernel
-    if p.op in GroupOp.Reduce: return {}
+    if p.op is Ops.REDUCE_AXIS: return {}
     rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
   # search descendants of the reduceop that can cleanly group
   descendants: Dict[LazyBuffer, None] = {}
@@ -49,7 +49,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
   reduce_for_op: Dict[LazyBuffer, UOp] = {}
   reduce_of_const: List[UOp] = []
   for r in allbufs:
-    if r in realizes or r.op not in GroupOp.Reduce: continue
+    if r in realizes or r.op is not Ops.REDUCE_AXIS: continue
     group: Dict[LazyBuffer, None] = {}
     _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={})
     # max one reduceop per kernel
@@ -77,7 +77,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
         if len(st_childs) > 1: break
         if st.size != st_childs[0].size: break
         st = st + st_childs[0]
-        if not st.contiguous or tr_next.op in GroupOp.Reduce: break
+        if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
         tr = tr_next
       # don't cast to higher size before store (tr cannot be realized if forced_realize)
       if tr.op is Ops.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
@@ -86,7 +86,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
       realizes[tr] = None
     rbuf = buf_uops[r.buffer]
     reduce_for_op.update((tr, rbuf) for tr in group)
-    if FUSE_ARANGE and r.op is Ops.SUM and r.srcs[0].base.op is Ops.CONST: reduce_of_const.append(rbuf)
+    if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.srcs[0].base.op is Ops.CONST: reduce_of_const.append(rbuf)
 
   # fuse double reduces with no other child
   if FUSE_CONV_BW:
diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py
index 438d71014b..21fa3f543d 100644
--- a/tinygrad/engine/lazy.py
+++ b/tinygrad/engine/lazy.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Union, Optional, Any, Tuple, List, get_args
 from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
 from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
-from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu, python_alu, REDUCE_ALU
+from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, python_alu
 from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
@@ -177,19 +177,19 @@ class LazyBuffer(MathTrait):
     assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
     if len(axis) == 0: return self
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))
 
   def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
     new_shape = self.st.reduce(axis)
     # TODO: this logic should move to the scheduler
-    if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(REDUCE_ALU[op], self.dtype), new_shape)
+    if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)
 
     # const folding
     # TODO: fold this for symbolic?
     if self.is_unrealized_unmasked_const() and all_int(self.shape):
-      if op is ReduceOps.SUM: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
-      if op is ReduceOps.PROD: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
-      if op is ReduceOps.REDUCE_MAX: return self.const_with_shape(self.base.arg, new_shape)
+      if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
+      if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
+      if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)
 
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
     if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 67a2243600..2df75af70d 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -69,8 +69,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
   if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
   # everything else needs sources
   src = tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs)
-  if buf.op in GroupOp.Reduce: ret = src[0].r(buf.op, buf.arg)
-  elif buf.op is Ops.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
+  if buf.op in {Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ret = UOp(buf.op, dtype, src, buf.arg)
   elif buf.op is Ops.ASSIGN:
     ctx.assigns.add(ubuf)
     ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
@@ -82,7 +81,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
   ctx.lazybufs[b] = buf
   # things for fuse.py
   allbufs[buf] = None
-  if buf.op in GroupOp.Reduce and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
+  if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
   for x in buf.srcs:
     if x.base.realized is None: children[x.base][buf] = None
   return ret
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 26cd54d1f4..01c597238c 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -3,7 +3,7 @@ import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort
 from tinygrad.dtype import dtypes, DType, sum_acc_dtype
-from tinygrad.ops import ReduceOps, resolve, sint
+from tinygrad.ops import Ops, resolve, sint
 from tinygrad.tensor import Function
 from tinygrad.engine.lazy import LazyBuffer
 
@@ -142,13 +142,13 @@ class Where(Function):
 class Sum(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, axis)
+    return x.r(Ops.ADD, axis)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
 
 class Prod(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret = x, x.r(ReduceOps.PROD, axis)
+    self.x, self.ret = x, x.r(Ops.MUL, axis)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -156,13 +156,13 @@ class Prod(Function):
 
 class Max(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret, self.axis = x, x.r(ReduceOps.REDUCE_MAX, axis), axis
+    self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
     max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
-    div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
+    div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
     return (max_is_1s/div) * grad_output.expand(self.x.shape)
 
 # ************* movement ops *************
@@ -174,7 +174,7 @@ class Expand(Function):
     return x.expand(shape)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
+    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
 
 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
diff --git a/tinygrad/multi.py b/tinygrad/multi.py
index 0e0381039b..975932a3ca 100644
--- a/tinygrad/multi.py
+++ b/tinygrad/multi.py
@@ -3,14 +3,13 @@ from typing import Optional, Tuple, List, Dict
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
 from tinygrad.dtype import DType
-from tinygrad.ops import REDUCE_ALU, Ops, MathTrait
+from tinygrad.ops import Ops, MathTrait
 from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import sint
 
-def all_reduce(op: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
+def all_reduce(bop: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
   assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
   assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
-  bop = REDUCE_ALU[op]
   n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
   # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
   # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index b599423b85..a000905ce3 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -122,9 +122,6 @@ class Ops(FastEnum):
   # reduce
   REDUCE_AXIS = auto()
 
-  # ReduceOps
-  SUM = auto(); PROD = auto(); REDUCE_MAX = auto() # noqa: E702
-
   # helper ops
   GEP = auto()
   VECTORIZE = auto()
@@ -173,7 +170,6 @@ class GroupOp:
   Ternary = {Ops.WHERE, Ops.MULACC}
   ALU = set.union(Unary, Binary, Ternary)
 
-  Reduce = {Ops.SUM, Ops.PROD, Ops.REDUCE_MAX}
   Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
 
   # meta ops
@@ -187,9 +183,7 @@ class GroupOp:
   UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
 
 # TODO: remove this?
-UnaryOps = BinaryOps = ReduceOps = MetaOps = TernaryOps = Ops
-
-REDUCE_ALU: Dict[Ops, Ops] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.REDUCE_MAX:BinaryOps.MAX}
+UnaryOps = BinaryOps = MetaOps = TernaryOps = Ops
 
 # https://en.wikipedia.org/wiki/Identity_element
 def identity_element(op:Ops, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
@@ -343,7 +337,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
     return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
                                             UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
-  def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
+  def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
   def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
 
   # *** uop movement ops ***
@@ -815,7 +809,7 @@ spec = PatternMatcher([
   (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
   (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
 
-  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
+  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
   (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
   (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
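
Usage sketch (not part of the patch itself; lifted from the test_reduce_store hunk above and assuming the same imports as test/unit/test_verify_ast.py): after this change a reduce is a plain REDUCE_AXIS UOp whose arg packs the ALU op together with the reduced axes, e.g. (Ops.ADD, (0,)) where the old code wrote (ReduceOps.SUM, (0,)).

  from tinygrad import dtypes
  from tinygrad.ops import UOp, Ops
  from tinygrad.shape.shapetracker import ShapeTracker

  # build a 32x1 load and sum it over axis 0, exactly as the updated test does
  buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
  a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))
  r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))  # was (ReduceOps.SUM, (0,))
  assert r.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}  # mirrors the updated spec check in tinygrad/ops.py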