delete ReduceOps, only use REDUCE_AXIS (#7667)

Author: qazal
Date: 2024-11-13 13:04:27 +02:00
Committed by: GitHub
Parent: 217c006103
Commit: e84d089ef1
12 changed files with 44 additions and 52 deletions

View File

@@ -2,7 +2,7 @@
from typing import Tuple
from tinygrad import Variable
from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
+from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, MetaOps
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

View File

@@ -1,7 +1,7 @@
import time
from tinygrad import Tensor, Device, GlobalCounters, TinyJit
from tinygrad.engine.lazy import LazyBuffer
-from tinygrad.ops import ReduceOps
+from tinygrad.ops import Ops
from tinygrad.multi import MultiLazyBuffer, all_reduce
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
@@ -14,7 +14,7 @@ def realize(x: Union[LazyBuffer, List[LazyBuffer]]):
for lb in x: Device[lb.device].synchronize()
def test(devs: List[str], N: int, iters:int = 10):
-def _wrapped(op: ReduceOps, t: Tensor) -> Tensor:
+def _wrapped(op: Ops, t: Tensor) -> Tensor:
return Tensor(MultiLazyBuffer(all_reduce(op, t.lazydata.lbs), 0), device=devs)
_jitted = TinyJit(_wrapped) if getenv("USEJIT", 1) == 1 else _wrapped
@@ -24,7 +24,7 @@ def test(devs: List[str], N: int, iters:int = 10):
realize(lbs)
GlobalCounters.reset()
start = time.time()
-realize(_jitted(ReduceOps.SUM, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
+realize(_jitted(Ops.ADD, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
if i < 0: continue # warm up jit
i_secs = time.time() - start
@@ -68,4 +68,4 @@ def main():
print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
if __name__ == "__main__":
-main()
+main()
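
Note on the calling convention exercised above: all_reduce now takes the binary ALU op (Ops.ADD) directly instead of a ReduceOps member. A minimal sketch of the new call, assuming two shards of the default device are actually available:

    from tinygrad import Tensor, Device
    from tinygrad.ops import Ops
    from tinygrad.multi import MultiLazyBuffer, all_reduce

    devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))   # two device shards (assumed available)
    t = Tensor.rand(256).shard(devs, 0).realize()
    # the reduce kind is now the ALU op itself: Ops.ADD replaces ReduceOps.SUM
    summed = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0), device=devs)
    summed.realize()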

View File

@@ -1,7 +1,7 @@
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
-from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, Ops
+from tinygrad.ops import MetaOps, BinaryOps, Ops
from tinygrad.helpers import CI, getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.schedule import create_schedule
@@ -31,7 +31,7 @@ N = 128
def _test_allreduce(t:Tensor):
aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
ts = t.shard(devices_4, 0).realize()
-b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0))
+b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0))
b.realize()
return aa, b
@@ -145,14 +145,14 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
@given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
-strat.sampled_from((ReduceOps.SUM, ReduceOps.PROD, ReduceOps.REDUCE_MAX)),
+strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
X = Tensor.rand(N*N).reshape(N, N).mul(sign)
n = X.numpy()
X.shard_(devices, shard_axis)
-f = {ReduceOps.SUM: lambda x: x.sum(reduce_axis), ReduceOps.PROD: lambda x: x.prod(reduce_axis),
-ReduceOps.REDUCE_MAX: lambda x: x.max(reduce_axis)}[rop]
+f = {Ops.ADD: lambda x: x.sum(reduce_axis), Ops.MUL: lambda x: x.prod(reduce_axis),
+Ops.MAX: lambda x: x.max(reduce_axis)}[rop]
fX = f(X)
fn = f(n)
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
@@ -197,9 +197,9 @@ class TestMultiTensor(unittest.TestCase):
shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
with Context(RING=0):
-a = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
+a = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
with Context(RING=2):
-b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
+b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
diff = a - b
mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
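
The property test above drives the renamed ops through their Tensor front ends; the same mapping in isolation, as a small single-device sketch:

    import numpy as np
    from tinygrad import Tensor
    from tinygrad.ops import Ops

    # the three reduce kinds now reuse the binary ALU members
    reducers = {Ops.ADD: Tensor.sum, Ops.MUL: Tensor.prod, Ops.MAX: Tensor.max}
    x = Tensor.rand(8, 8)
    for op, fn in reducers.items():
        ref = {Ops.ADD: np.sum, Ops.MUL: np.prod, Ops.MAX: np.max}[op](x.numpy(), axis=1)
        np.testing.assert_allclose(fn(x, axis=1).numpy(), ref, rtol=1e-6, atol=1e-6)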

View File

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
-from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
+from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel

View File

@@ -1,6 +1,6 @@
import unittest, itertools
from tinygrad.dtype import dtypes
-from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps, GroupOp # noqa: F401
+from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, UnaryOps, GroupOp # noqa: F401
from tinygrad.ops import PatternMatcher, UPat
class TestPatternMatcher(unittest.TestCase):

View File

@@ -4,7 +4,7 @@ import unittest
from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import DEBUG
-from tinygrad.ops import UOp, Ops, ReduceOps, print_uops
+from tinygrad.ops import UOp, Ops, print_uops
from tinygrad.codegen.kernel import verify_ast
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad import dtypes
@@ -48,7 +48,7 @@ class TestVerifyAST(unittest.TestCase):
def test_no_implicit_broadcasting(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
-b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.REDUCE_MAX, (1,)))
+b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.MAX, (1,)))
st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
@@ -62,14 +62,14 @@ class TestVerifyAST(unittest.TestCase):
def test_reduce_store(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
-r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
+r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
def test_reduce_add_store(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
-r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
+r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
@@ -84,7 +84,7 @@ class TestVerifyAST(unittest.TestCase):
def test_assert_swizzle(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))
-r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
+r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
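
As these tests show, a reduce is now expressed purely as a REDUCE_AXIS UOp whose arg carries the ALU op and the axes; a minimal standalone construction, mirroring the lines above:

    from tinygrad import dtypes
    from tinygrad.ops import UOp, Ops
    from tinygrad.shape.shapetracker import ShapeTracker

    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))
    # arg is a 2-tuple: (binary ALU op, axes to reduce over)
    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
    assert r.arg[0] is Ops.ADD and r.arg[1] == (0,)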

View File

@@ -1,7 +1,7 @@
from collections import defaultdict, deque
from typing import Set, Tuple, List, Dict, DefaultDict
from tinygrad.device import Buffer
-from tinygrad.ops import GroupOp, UOp, Ops
+from tinygrad.ops import UOp, Ops
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.lazy import LazyBuffer
@@ -19,7 +19,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
return group.setdefault(tr)
for tr_next in children[tr]:
# max one reduceop per kernel
-if tr_next.op in GroupOp.Reduce: return group.setdefault(r)
+if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
# can only fuse contiguous
if len(st_childs:=dedup(s.st for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
_recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
@@ -31,7 +31,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], ch
if (p:=rc_parents.pop()) in cache: continue
cache.add(p)
# max one reduceop per kernel
-if p.op in GroupOp.Reduce: return {}
+if p.op is Ops.REDUCE_AXIS: return {}
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
# search descendants of the reduceop that can cleanly group
descendants: Dict[LazyBuffer, None] = {}
@@ -49,7 +49,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
reduce_for_op: Dict[LazyBuffer, UOp] = {}
reduce_of_const: List[UOp] = []
for r in allbufs:
-if r in realizes or r.op not in GroupOp.Reduce: continue
+if r in realizes or r.op is not Ops.REDUCE_AXIS: continue
group: Dict[LazyBuffer, None] = {}
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={})
# max one reduceop per kernel
@@ -77,7 +77,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
if len(st_childs) > 1: break
if st.size != st_childs[0].size: break
st = st + st_childs[0]
-if not st.contiguous or tr_next.op in GroupOp.Reduce: break
+if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
if tr.op is Ops.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
@@ -86,7 +86,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
realizes[tr] = None
rbuf = buf_uops[r.buffer]
reduce_for_op.update((tr, rbuf) for tr in group)
-if FUSE_ARANGE and r.op is Ops.SUM and r.srcs[0].base.op is Ops.CONST: reduce_of_const.append(rbuf)
+if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.srcs[0].base.op is Ops.CONST: reduce_of_const.append(rbuf)
# fuse double reduces with no other child
if FUSE_CONV_BW:
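
With GroupOp.Reduce gone, the grouper identifies reduces structurally (op is Ops.REDUCE_AXIS) and reads the reduce kind from arg[0]; a hypothetical helper combining both checks, for illustration only:

    from tinygrad.ops import Ops
    from tinygrad.engine.lazy import LazyBuffer

    def is_sum_reduce(lb: LazyBuffer) -> bool:
        # hypothetical helper: any REDUCE_AXIS buffer is a reduce; its kind lives in arg[0]
        return lb.op is Ops.REDUCE_AXIS and lb.arg[0] is Ops.ADD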

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
-from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu, python_alu, REDUCE_ALU
+from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, python_alu
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -177,19 +177,19 @@ class LazyBuffer(MathTrait):
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self
-return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))
+return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))
def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
new_shape = self.st.reduce(axis)
# TODO: this logic should move to the scheduler
-if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(REDUCE_ALU[op], self.dtype), new_shape)
+if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)
# const folding
# TODO: fold this for symbolic?
if self.is_unrealized_unmasked_const() and all_int(self.shape):
-if op is ReduceOps.SUM: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
-if op is ReduceOps.PROD: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
-if op is ReduceOps.REDUCE_MAX: return self.const_with_shape(self.base.arg, new_shape)
+if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
+if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
+if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
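
The const-folding branch above rewrites a reduce of an unrealized, unmasked constant as arithmetic on the constant itself; a small illustration of the three rules, using values that stay exact in float32:

    from tinygrad import Tensor

    x = Tensor.full((2, 3), 2.0)          # unrealized constant, 6 elements
    assert x.sum().item() == 2.0 * 6      # Ops.ADD folds to arg * prod(reduced dims)
    assert x.prod().item() == 2.0 ** 6    # Ops.MUL folds to arg ** prod(reduced dims)
    assert x.max().item() == 2.0          # Ops.MAX folds to arg unchanged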

View File

@@ -69,8 +69,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
# everything else needs sources
src = tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs)
-if buf.op in GroupOp.Reduce: ret = src[0].r(buf.op, buf.arg)
-elif buf.op is Ops.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
+if buf.op in {Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ret = UOp(buf.op, dtype, src, buf.arg)
elif buf.op is Ops.ASSIGN:
ctx.assigns.add(ubuf)
ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
@@ -82,7 +81,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
ctx.lazybufs[b] = buf
# things for fuse.py
allbufs[buf] = None
-if buf.op in GroupOp.Reduce and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
+if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
for x in buf.srcs:
if x.base.realized is None: children[x.base][buf] = None
return ret
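
Since the LazyBuffer arg already holds (op, axis) after the lazy.py change, to_uop can forward REDUCE_AXIS (and CONTIGUOUS) uniformly; a compressed sketch of that branch, with hypothetical local names:

    from tinygrad.ops import UOp, Ops

    def lower(buf, dtype, src):
        # hypothetical: buf is a LazyBuffer, src is the tuple of its already-lowered sources
        if buf.op in {Ops.REDUCE_AXIS, Ops.CONTIGUOUS}:
            # buf.arg is (Ops.ADD|Ops.MUL|Ops.MAX, axis) for reduces; presumably None for CONTIGUOUS
            return UOp(buf.op, dtype, src, buf.arg)
        raise NotImplementedError("other ops are handled as in to_uop above")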

View File

@@ -3,7 +3,7 @@ import math
from typing import Tuple, Optional
from tinygrad.helpers import argsort
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
-from tinygrad.ops import ReduceOps, resolve, sint
+from tinygrad.ops import Ops, resolve, sint
from tinygrad.tensor import Function
from tinygrad.engine.lazy import LazyBuffer
@@ -142,13 +142,13 @@ class Where(Function):
class Sum(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
-return x.r(ReduceOps.SUM, axis)
+return x.r(Ops.ADD, axis)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
class Prod(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-self.x, self.ret = x, x.r(ReduceOps.PROD, axis)
+self.x, self.ret = x, x.r(Ops.MUL, axis)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -156,13 +156,13 @@ class Prod(Function):
class Max(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-self.x, self.ret, self.axis = x, x.r(ReduceOps.REDUCE_MAX, axis), axis
+self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
-div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
+div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
return (max_is_1s/div) * grad_output.expand(self.x.shape)
# ************* movement ops *************
@@ -174,7 +174,7 @@ class Expand(Function):
return x.expand(shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
+return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
class Reshape(Function):
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
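
The Max backward above spreads the incoming gradient evenly over tied maxima (the r(Ops.ADD, ...) call counts them); a small numeric check of that behaviour:

    import numpy as np
    from tinygrad import Tensor

    x = Tensor([[1.0, 3.0, 3.0]], requires_grad=True)
    x.max(axis=1).sum().backward()
    # both tied maxima receive half of the gradient, the non-max entry gets none
    np.testing.assert_allclose(x.grad.numpy(), [[0.0, 0.5, 0.5]])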

View File

@@ -3,14 +3,13 @@ from typing import Optional, Tuple, List, Dict
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
from tinygrad.dtype import DType
-from tinygrad.ops import REDUCE_ALU, Ops, MathTrait
+from tinygrad.ops import Ops, MathTrait
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.shape.shapetracker import sint
-def all_reduce(op: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
+def all_reduce(bop: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
-bop = REDUCE_ALU[op]
n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
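
With the REDUCE_ALU lookup removed, the op handed to all_reduce is already the ALU op used to combine shards; a plain-Python stand-in for the naive path, using numbers in place of LazyBuffers:

    import functools
    from tinygrad.ops import Ops

    combine = {Ops.ADD: lambda a, b: a + b, Ops.MUL: lambda a, b: a * b, Ops.MAX: max}

    def naive_all_reduce(bop: Ops, shards: list) -> list:
        # every rank ends up holding the same fully-combined value
        total = functools.reduce(combine[bop], shards)
        return [total for _ in shards]

    assert naive_all_reduce(Ops.ADD, [1, 2, 3]) == [6, 6, 6]
    assert naive_all_reduce(Ops.MAX, [1, 2, 3]) == [3, 3, 3]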

View File

@@ -122,9 +122,6 @@ class Ops(FastEnum):
# reduce
REDUCE_AXIS = auto()
-# ReduceOps
-SUM = auto(); PROD = auto(); REDUCE_MAX = auto() # noqa: E702
# helper ops
GEP = auto()
VECTORIZE = auto()
@@ -173,7 +170,6 @@ class GroupOp:
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)
-Reduce = {Ops.SUM, Ops.PROD, Ops.REDUCE_MAX}
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
# meta ops
@@ -187,9 +183,7 @@ class GroupOp:
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
# TODO: remove this?
-UnaryOps = BinaryOps = ReduceOps = MetaOps = TernaryOps = Ops
-REDUCE_ALU: Dict[Ops, Ops] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.REDUCE_MAX:BinaryOps.MAX}
+UnaryOps = BinaryOps = MetaOps = TernaryOps = Ops
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
@@ -343,7 +337,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
-def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
+def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
# *** uop movement ops ***
@@ -815,7 +809,7 @@ spec = PatternMatcher([
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
-(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
+(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
(UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
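
identity_element now takes the ALU op directly, matching what REDUCE_AXIS stores in arg[0]; a quick check of the three identities defined by the mapping above:

    from tinygrad import dtypes
    from tinygrad.ops import Ops, identity_element

    assert identity_element(Ops.ADD, dtypes.float) == 0.0
    assert identity_element(Ops.MUL, dtypes.float) == 1.0
    assert identity_element(Ops.MAX, dtypes.float) == dtypes.min(dtypes.float)   # -inf for floats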