mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
delete ReduceOps, only use REDUCE_AXIS (#7667)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from typing import Tuple
|
||||
from tinygrad import Variable
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
|
||||
from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, MetaOps
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
from tinygrad import Tensor, Device, GlobalCounters, TinyJit
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.ops import ReduceOps
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.multi import MultiLazyBuffer, all_reduce
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
@@ -14,7 +14,7 @@ def realize(x: Union[LazyBuffer, List[LazyBuffer]]):
|
||||
for lb in x: Device[lb.device].synchronize()
|
||||
|
||||
def test(devs: List[str], N: int, iters:int = 10):
|
||||
def _wrapped(op: ReduceOps, t: Tensor) -> Tensor:
|
||||
def _wrapped(op: Ops, t: Tensor) -> Tensor:
|
||||
return Tensor(MultiLazyBuffer(all_reduce(op, t.lazydata.lbs), 0), device=devs)
|
||||
_jitted = TinyJit(_wrapped) if getenv("USEJIT", 1) == 1 else _wrapped
|
||||
|
||||
@@ -24,7 +24,7 @@ def test(devs: List[str], N: int, iters:int = 10):
|
||||
realize(lbs)
|
||||
GlobalCounters.reset()
|
||||
start = time.time()
|
||||
realize(_jitted(ReduceOps.SUM, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
|
||||
realize(_jitted(Ops.ADD, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
|
||||
if i < 0: continue # warm up jit
|
||||
i_secs = time.time() - start
|
||||
|
||||
@@ -68,4 +68,4 @@ def main():
|
||||
print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, functools, random
|
||||
from typing import List
|
||||
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
|
||||
from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, Ops
|
||||
from tinygrad.ops import MetaOps, BinaryOps, Ops
|
||||
from tinygrad.helpers import CI, getenv, prod, Context
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
@@ -31,7 +31,7 @@ N = 128
|
||||
def _test_allreduce(t:Tensor):
|
||||
aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
|
||||
ts = t.shard(devices_4, 0).realize()
|
||||
b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0))
|
||||
b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0))
|
||||
b.realize()
|
||||
return aa, b
|
||||
|
||||
@@ -145,14 +145,14 @@ class TestMultiTensor(unittest.TestCase):
|
||||
np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
|
||||
|
||||
@given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
|
||||
strat.sampled_from((ReduceOps.SUM, ReduceOps.PROD, ReduceOps.REDUCE_MAX)),
|
||||
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
|
||||
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
|
||||
def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
|
||||
X = Tensor.rand(N*N).reshape(N, N).mul(sign)
|
||||
n = X.numpy()
|
||||
X.shard_(devices, shard_axis)
|
||||
f = {ReduceOps.SUM: lambda x: x.sum(reduce_axis), ReduceOps.PROD: lambda x: x.prod(reduce_axis),
|
||||
ReduceOps.REDUCE_MAX: lambda x: x.max(reduce_axis)}[rop]
|
||||
f = {Ops.ADD: lambda x: x.sum(reduce_axis), Ops.MUL: lambda x: x.prod(reduce_axis),
|
||||
Ops.MAX: lambda x: x.max(reduce_axis)}[rop]
|
||||
fX = f(X)
|
||||
fn = f(n)
|
||||
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
|
||||
@@ -197,9 +197,9 @@ class TestMultiTensor(unittest.TestCase):
|
||||
shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
|
||||
t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
|
||||
with Context(RING=0):
|
||||
a = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
|
||||
a = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
|
||||
with Context(RING=2):
|
||||
b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
|
||||
b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
|
||||
diff = a - b
|
||||
mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
|
||||
max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.engine.schedule import create_schedule, to_si
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest, itertools
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps, GroupOp # noqa: F401
|
||||
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, UnaryOps, GroupOp # noqa: F401
|
||||
from tinygrad.ops import PatternMatcher, UPat
|
||||
|
||||
class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
@@ -4,7 +4,7 @@ import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import UOp, Ops, ReduceOps, print_uops
|
||||
from tinygrad.ops import UOp, Ops, print_uops
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad import dtypes
|
||||
@@ -48,7 +48,7 @@ class TestVerifyAST(unittest.TestCase):
|
||||
def test_no_implicit_broadcasting(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
|
||||
b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.REDUCE_MAX, (1,)))
|
||||
b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.MAX, (1,)))
|
||||
st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
||||
|
||||
@@ -62,14 +62,14 @@ class TestVerifyAST(unittest.TestCase):
|
||||
def test_reduce_store(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
|
||||
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
||||
|
||||
def test_reduce_add_store(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
|
||||
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
||||
|
||||
@@ -84,7 +84,7 @@ class TestVerifyAST(unittest.TestCase):
|
||||
def test_assert_swizzle(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
|
||||
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
|
||||
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections import defaultdict, deque
|
||||
from typing import Set, Tuple, List, Dict, DefaultDict
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.ops import GroupOp, UOp, Ops
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
@@ -19,7 +19,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
|
||||
return group.setdefault(tr)
|
||||
for tr_next in children[tr]:
|
||||
# max one reduceop per kernel
|
||||
if tr_next.op in GroupOp.Reduce: return group.setdefault(r)
|
||||
if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
|
||||
# can only fuse contiguous
|
||||
if len(st_childs:=dedup(s.st for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
|
||||
_recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
|
||||
@@ -31,7 +31,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], ch
|
||||
if (p:=rc_parents.pop()) in cache: continue
|
||||
cache.add(p)
|
||||
# max one reduceop per kernel
|
||||
if p.op in GroupOp.Reduce: return {}
|
||||
if p.op is Ops.REDUCE_AXIS: return {}
|
||||
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
descendants: Dict[LazyBuffer, None] = {}
|
||||
@@ -49,7 +49,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
reduce_for_op: Dict[LazyBuffer, UOp] = {}
|
||||
reduce_of_const: List[UOp] = []
|
||||
for r in allbufs:
|
||||
if r in realizes or r.op not in GroupOp.Reduce: continue
|
||||
if r in realizes or r.op is not Ops.REDUCE_AXIS: continue
|
||||
group: Dict[LazyBuffer, None] = {}
|
||||
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={})
|
||||
# max one reduceop per kernel
|
||||
@@ -77,7 +77,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
if len(st_childs) > 1: break
|
||||
if st.size != st_childs[0].size: break
|
||||
st = st + st_childs[0]
|
||||
if not st.contiguous or tr_next.op in GroupOp.Reduce: break
|
||||
if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
|
||||
tr = tr_next
|
||||
# don't cast to higher size before store (tr cannot be realized if forced_realize)
|
||||
if tr.op is Ops.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
|
||||
@@ -86,7 +86,7 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
realizes[tr] = None
|
||||
rbuf = buf_uops[r.buffer]
|
||||
reduce_for_op.update((tr, rbuf) for tr in group)
|
||||
if FUSE_ARANGE and r.op is Ops.SUM and r.srcs[0].base.op is Ops.CONST: reduce_of_const.append(rbuf)
|
||||
if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.srcs[0].base.op is Ops.CONST: reduce_of_const.append(rbuf)
|
||||
|
||||
# fuse double reduces with no other child
|
||||
if FUSE_CONV_BW:
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Union, Optional, Any, Tuple, List, get_args
|
||||
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu, python_alu, REDUCE_ALU
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, python_alu
|
||||
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer
|
||||
@@ -177,19 +177,19 @@ class LazyBuffer(MathTrait):
|
||||
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
|
||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||
if len(axis) == 0: return self
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))
|
||||
|
||||
def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
new_shape = self.st.reduce(axis)
|
||||
# TODO: this logic should move to the scheduler
|
||||
if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(REDUCE_ALU[op], self.dtype), new_shape)
|
||||
if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)
|
||||
|
||||
# const folding
|
||||
# TODO: fold this for symbolic?
|
||||
if self.is_unrealized_unmasked_const() and all_int(self.shape):
|
||||
if op is ReduceOps.SUM: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
|
||||
if op is ReduceOps.PROD: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
|
||||
if op is ReduceOps.REDUCE_MAX: return self.const_with_shape(self.base.arg, new_shape)
|
||||
if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
|
||||
if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
|
||||
if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)
|
||||
|
||||
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
|
||||
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
|
||||
|
||||
@@ -69,8 +69,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
|
||||
if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
|
||||
# everything else needs sources
|
||||
src = tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs)
|
||||
if buf.op in GroupOp.Reduce: ret = src[0].r(buf.op, buf.arg)
|
||||
elif buf.op is Ops.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
|
||||
if buf.op in {Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ret = UOp(buf.op, dtype, src, buf.arg)
|
||||
elif buf.op is Ops.ASSIGN:
|
||||
ctx.assigns.add(ubuf)
|
||||
ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
|
||||
@@ -82,7 +81,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
|
||||
ctx.lazybufs[b] = buf
|
||||
# things for fuse.py
|
||||
allbufs[buf] = None
|
||||
if buf.op in GroupOp.Reduce and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
|
||||
if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
|
||||
for x in buf.srcs:
|
||||
if x.base.realized is None: children[x.base][buf] = None
|
||||
return ret
|
||||
|
||||
@@ -3,7 +3,7 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
|
||||
from tinygrad.ops import ReduceOps, resolve, sint
|
||||
from tinygrad.ops import Ops, resolve, sint
|
||||
from tinygrad.tensor import Function
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
|
||||
@@ -142,13 +142,13 @@ class Where(Function):
|
||||
class Sum(Function):
|
||||
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.input_shape = x.shape
|
||||
return x.r(ReduceOps.SUM, axis)
|
||||
return x.r(Ops.ADD, axis)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
|
||||
|
||||
class Prod(Function):
|
||||
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.x, self.ret = x, x.r(ReduceOps.PROD, axis)
|
||||
self.x, self.ret = x, x.r(Ops.MUL, axis)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
@@ -156,13 +156,13 @@ class Prod(Function):
|
||||
|
||||
class Max(Function):
|
||||
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.x, self.ret, self.axis = x, x.r(ReduceOps.REDUCE_MAX, axis), axis
|
||||
self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
# 1s in locations where the max was chosen (can be two locations)
|
||||
max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
|
||||
div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
|
||||
div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
|
||||
return (max_is_1s/div) * grad_output.expand(self.x.shape)
|
||||
|
||||
# ************* movement ops *************
|
||||
@@ -174,7 +174,7 @@ class Expand(Function):
|
||||
return x.expand(shape)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
|
||||
return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
|
||||
|
||||
class Reshape(Function):
|
||||
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
|
||||
@@ -3,14 +3,13 @@ from typing import Optional, Tuple, List, Dict
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import REDUCE_ALU, Ops, MathTrait
|
||||
from tinygrad.ops import Ops, MathTrait
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.shape.shapetracker import sint
|
||||
|
||||
def all_reduce(op: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
|
||||
def all_reduce(bop: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
|
||||
assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
|
||||
assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
|
||||
bop = REDUCE_ALU[op]
|
||||
n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
|
||||
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
|
||||
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
|
||||
|
||||
@@ -122,9 +122,6 @@ class Ops(FastEnum):
|
||||
# reduce
|
||||
REDUCE_AXIS = auto()
|
||||
|
||||
# ReduceOps
|
||||
SUM = auto(); PROD = auto(); REDUCE_MAX = auto() # noqa: E702
|
||||
|
||||
# helper ops
|
||||
GEP = auto()
|
||||
VECTORIZE = auto()
|
||||
@@ -173,7 +170,6 @@ class GroupOp:
|
||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||
ALU = set.union(Unary, Binary, Ternary)
|
||||
|
||||
Reduce = {Ops.SUM, Ops.PROD, Ops.REDUCE_MAX}
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
||||
|
||||
# meta ops
|
||||
@@ -187,9 +183,7 @@ class GroupOp:
|
||||
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
|
||||
|
||||
# TODO: remove this?
|
||||
UnaryOps = BinaryOps = ReduceOps = MetaOps = TernaryOps = Ops
|
||||
|
||||
REDUCE_ALU: Dict[Ops, Ops] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.REDUCE_MAX:BinaryOps.MAX}
|
||||
UnaryOps = BinaryOps = MetaOps = TernaryOps = Ops
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
|
||||
@@ -343,7 +337,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
|
||||
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
|
||||
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
# *** uop movement ops ***
|
||||
@@ -815,7 +809,7 @@ spec = PatternMatcher([
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
|
||||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
|
||||
(UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
|
||||
|
||||
Reference in New Issue
Block a user