mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
@@ -622,7 +622,7 @@ class Kernel:
|
||||
return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
|
||||
if op.op is Ops.REDUCE_AXIS:
|
||||
reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
|
||||
alu_op: BinaryOps = op.arg[0]
|
||||
alu_op: Ops = op.arg[0]
|
||||
axis = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
|
||||
if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i]))
|
||||
if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import variable_to_uop
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten
|
||||
|
||||
@@ -100,7 +100,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
# NOTE: always using ridxs is fine here
|
||||
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
|
||||
assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
|
||||
alu_op: BinaryOps = x.arg[0]
|
||||
alu_op: Ops = x.arg[0]
|
||||
ret = x.src[0]
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
from collections import defaultdict, deque
|
||||
from typing import Tuple, List, Dict, DefaultDict
|
||||
from tinygrad.ops import UNSAFE_PAD_OPS, MetaOps, ReduceOps, UOp, UnaryOps, resolve
|
||||
from tinygrad.ops import UNSAFE_PAD_OPS, MetaOps, ReduceOps, UOp, UnaryOps, resolve, GroupOp
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -32,7 +32,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
||||
elif any(v.mask is not None for v in buf.st.views): simple_pads[buf.base] = None
|
||||
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)
|
||||
if ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None
|
||||
if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
|
||||
if buf.op in GroupOp.Reduce and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
|
||||
allbufs[buf] = None
|
||||
if buf.op is MetaOps.ASSIGN:
|
||||
assign_targets[(target:=buf.srcs[0])] = buf
|
||||
@@ -63,7 +63,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
|
||||
return group.setdefault(tr)
|
||||
for tr_next in children[tr]:
|
||||
# max one reduceop per kernel
|
||||
if tr_next.op in ReduceOps: return group.setdefault(r)
|
||||
if tr_next.op in GroupOp.Reduce: return group.setdefault(r)
|
||||
# can only fuse contiguous
|
||||
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
|
||||
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
|
||||
@@ -75,7 +75,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
||||
if (p:=rc_parents.pop()) in cache: continue
|
||||
cache.add(p)
|
||||
# max one reduceop per kernel
|
||||
if p.op in ReduceOps: return {}
|
||||
if p.op in GroupOp.Reduce: return {}
|
||||
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
|
||||
# search descendants of the reduceop that can cleanly group
|
||||
descendants: Dict[LazyBuffer, None] = {}
|
||||
@@ -101,7 +101,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
|
||||
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
reduce_of_const: List[LazyBuffer] = []
|
||||
for r in allbufs:
|
||||
if r.op not in ReduceOps or r in realizes: continue
|
||||
if r.op not in GroupOp.Reduce or r in realizes: continue
|
||||
|
||||
group: Dict[LazyBuffer, None] = {}
|
||||
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={})
|
||||
@@ -130,7 +130,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
|
||||
if len(st_childs) > 1: break
|
||||
if st.size != st_childs[0].st.size: break
|
||||
st = st + st_childs[0].st
|
||||
if not st.contiguous or tr_next.op in ReduceOps: break
|
||||
if not st.contiguous or tr_next.op in GroupOp.Reduce: break
|
||||
tr = tr_next
|
||||
# don't cast to higher size before store (tr cannot be realized if forced_realize)
|
||||
if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
|
||||
|
||||
@@ -2,14 +2,14 @@ from __future__ import annotations
|
||||
from typing import Union, Optional, Any, Tuple, List, get_args
|
||||
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
|
||||
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu, python_alu, REDUCE_ALU
|
||||
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer
|
||||
from weakref import ref, ReferenceType, WeakValueDictionary
|
||||
|
||||
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
|
||||
if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
|
||||
dtype = to_dtype(dtype)
|
||||
@@ -25,7 +25,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
|
||||
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
|
||||
class LazyBuffer(MathTrait):
|
||||
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
|
||||
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
|
||||
self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
|
||||
self._base: Optional[LazyBuffer] = None
|
||||
@@ -111,8 +111,7 @@ class LazyBuffer(MathTrait):
|
||||
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
|
||||
# TODO: applying this makes gpt2 slower
|
||||
return self.base.cast(dtype, bitcast)._view(self.st)
|
||||
cast_op: Union[MetaOps, UnaryOps] = \
|
||||
(MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
|
||||
cast_op: Union[Ops, Ops] = (MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
|
||||
|
||||
def is_unrealized_const(self):
  """Return True when the base is not realized, its op is MetaOps.CONST and its arg is not a UOp."""
  if self.base.realized is not None: return False
  if self.base.op is not MetaOps.CONST: return False
  # a CONST whose arg is a UOp does not count as a plain const here
  return not isinstance(self.base.arg, UOp)
|
||||
@@ -142,7 +141,7 @@ class LazyBuffer(MathTrait):
|
||||
|
||||
def clone(self) -> LazyBuffer:
  """Return the result of copying this buffer to its own device with clone=True."""
  return self.copy_to_device(self.device, clone=True)
|
||||
|
||||
def alu(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer) -> LazyBuffer:
|
||||
def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
|
||||
srcs: List[LazyBuffer] = []
|
||||
for s in (self,)+in_srcs:
|
||||
if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
|
||||
@@ -159,7 +158,7 @@ class LazyBuffer(MathTrait):
|
||||
# const folding
|
||||
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
|
||||
return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
|
||||
if op in BinaryOps:
|
||||
if op in GroupOp.Binary:
|
||||
x, y = self, in_srcs[0]
|
||||
if op is BinaryOps.ADD:
|
||||
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
|
||||
@@ -173,13 +172,13 @@ class LazyBuffer(MathTrait):
|
||||
|
||||
# *** reduce ops ***
|
||||
|
||||
def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
def _reduce_op(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
  """Create the lazy reduce of this buffer over `axis` using reduce op `op`.

  Axes whose size resolves to 1 are dropped; when no axis is left, `self` is returned unchanged.
  """
  ndim = len(self.shape)
  assert all(0 <= x < ndim for x in axis), f"axis args {axis} out of range for shape {self.shape}"
  # keep only the axes that actually reduce something, in sorted order
  kept_axis = tuple(sorted(x for x in axis if resolve(self.shape[x] != 1)))
  if not kept_axis: return self
  reduced_st = ShapeTracker.from_shape(self.st.reduce(kept_axis))
  return create_lazybuffer(self.device, reduced_st, self.dtype, op, kept_axis, (self,))
|
||||
|
||||
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
new_shape = self.st.reduce(axis)
|
||||
# TODO: this logic should move to the scheduler
|
||||
if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(REDUCE_ALU[op], self.dtype), new_shape)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import sys, atexit, functools, itertools
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
|
||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict
|
||||
from tinygrad.ops import BUFOPS, MetaOps, GroupOp, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
|
||||
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -69,10 +69,10 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
|
||||
if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
|
||||
# everything else needs sources
|
||||
src = tuple(to_uop(x, ctx, cache) for x in buf.srcs)
|
||||
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
|
||||
if buf.op in GroupOp.Reduce: ret = src[0].r(buf.op, buf.arg)
|
||||
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
|
||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
|
||||
elif buf.op in METAOPS: ret = UOp(METAOPS[cast(MetaOps, buf.op)], buf.dtype, (ubuf, *src), buf.arg)
|
||||
elif buf.op in METAOPS: ret = UOp(METAOPS[buf.op], buf.dtype, (ubuf, *src), buf.arg)
|
||||
elif buf.op is UnaryOps.CAST: ret = UOp(Ops.CAST, dtype, src)
|
||||
elif buf.op is UnaryOps.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
|
||||
else: ret = UOp(Ops.ALU, dtype, src, buf.op)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Union, Tuple, List, Dict
|
||||
from typing import Optional, Tuple, List, Dict
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import REDUCE_ALU, BinaryOps, MetaOps, UnaryOps, TernaryOps, ReduceOps, MathTrait
|
||||
from tinygrad.ops import REDUCE_ALU, Ops, MathTrait
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.shape.shapetracker import sint
|
||||
|
||||
def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
|
||||
def all_reduce(op: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
|
||||
assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
|
||||
assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
|
||||
bop = REDUCE_ALU[op]
|
||||
@@ -94,7 +94,7 @@ class MultiLazyBuffer(MathTrait):
|
||||
def clone(self) -> MultiLazyBuffer:
  """Return a new MultiLazyBuffer built from a clone of every entry in self.lbs, keeping axis and real."""
  cloned_lbs = [lb.clone() for lb in self.lbs]
  return MultiLazyBuffer(cloned_lbs, self.axis, self.real)
|
||||
|
||||
# elementwise is simple
|
||||
def alu(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
|
||||
def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
|
||||
msrcs = (self,)+in_srcs
|
||||
assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
|
||||
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
|
||||
@@ -114,7 +114,7 @@ class MultiLazyBuffer(MathTrait):
|
||||
new_dtype = next(iter(new_real_lbs.values())).dtype
|
||||
return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
|
||||
|
||||
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
|
||||
def r(self, op:Ops, axis:Tuple[int, ...]) -> MultiLazyBuffer:
|
||||
if self.axis is not None and self.axis in axis:
|
||||
# all-reduce on sharded axes
|
||||
reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
|
||||
|
||||
@@ -16,29 +16,9 @@ class FastEnum(IntEnum):
|
||||
@staticmethod
def _generate_next_value_(_, __, ___, last_values):
  """Return 1 + the largest of 0, `last_values` and the maximum member of every current FastEnum subclass."""
  # looking at every existing subclass keeps a new value above all values handed out so far
  subclass_maxima = [max(c) for c in FastEnum.__subclasses__()]
  return 1 + max([0, *last_values, *subclass_maxima])
|
||||
|
||||
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
|
||||
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
|
||||
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
|
||||
class UnaryOps(FastEnum):
|
||||
"""A -> A (elementwise)"""
|
||||
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
class BinaryOps(FastEnum):
|
||||
"""A + A -> A (elementwise)"""
|
||||
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
|
||||
class TernaryOps(FastEnum):
|
||||
"""A + A + A -> A (elementwise)"""
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
class ReduceOps(FastEnum):
|
||||
"""A -> B (reduce)"""
|
||||
SUM = auto(); PROD = auto(); REDUCE_MAX = auto() # noqa: E702
|
||||
class MetaOps(FastEnum):
|
||||
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); ASSIGN = auto(); BUFFER_VIEW = auto() # noqa: E702
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
|
||||
|
||||
class SimpleMathTrait:
|
||||
# required to implement
|
||||
def alu(self:T, arg:Union[UnaryOps, BinaryOps, TernaryOps], *src) -> T: raise NotImplementedError
|
||||
def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
|
||||
def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError
|
||||
|
||||
# great functions you get!
|
||||
@@ -115,14 +95,6 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
def log2(self): return self.alu(UnaryOps.LOG2)
|
||||
def exp2(self): return self.alu(UnaryOps.EXP2)
|
||||
|
||||
# do not preserve f(0) = 0
|
||||
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
|
||||
|
||||
REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.REDUCE_MAX:BinaryOps.MAX}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
||||
# the order of these Ops controls the order of the toposort
|
||||
class Ops(FastEnum):
|
||||
# uops that aren't rendered
|
||||
@@ -130,7 +102,7 @@ class Ops(FastEnum):
|
||||
CONTIGUOUS = auto()
|
||||
PRELOAD = auto()
|
||||
|
||||
# metaops
|
||||
# MetaOps
|
||||
COPY = auto()
|
||||
EMPTY = auto()
|
||||
BUFFER_VIEW = auto()
|
||||
@@ -146,14 +118,20 @@ class Ops(FastEnum):
|
||||
VALID = auto()
|
||||
SPECIAL = auto()
|
||||
NOOP = auto()
|
||||
|
||||
# reduce
|
||||
REDUCE = auto()
|
||||
REDUCE_AXIS = auto()
|
||||
|
||||
# ReduceOps
|
||||
SUM = auto(); PROD = auto(); REDUCE_MAX = auto() # noqa: E702
|
||||
|
||||
# helper ops
|
||||
GEP = auto()
|
||||
VECTORIZE = auto()
|
||||
CAST = auto()
|
||||
BITCAST = auto()
|
||||
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
|
||||
# loads before math
|
||||
LOAD = auto()
|
||||
@@ -162,6 +140,13 @@ class Ops(FastEnum):
|
||||
ALU = auto()
|
||||
WMMA = auto()
|
||||
|
||||
# BinaryOps
|
||||
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
|
||||
|
||||
# TernaryOps
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
|
||||
# assignment ops
|
||||
STORE = auto()
|
||||
ASSIGN = auto()
|
||||
@@ -183,6 +168,21 @@ class Ops(FastEnum):
|
||||
VCONST = auto()
|
||||
CONST = auto()
|
||||
|
||||
class GroupOp:
  """Named sets of Ops members, used for membership checks such as `op in GroupOp.Reduce`."""
  # the members that the former BinaryOps enum listed
  Binary = {
    Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR,
    Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB,
  }
  # the members that the former ReduceOps enum listed
  Reduce = {Ops.SUM, Ops.PROD, Ops.REDUCE_MAX}
|
||||
|
||||
# TODO: remove this
|
||||
Op = UnaryOps = BinaryOps = ReduceOps = MetaOps = TernaryOps = Ops
|
||||
|
||||
# do not preserve f(0) = 0
|
||||
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
|
||||
|
||||
REDUCE_ALU: Dict[Ops, Ops] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.REDUCE_MAX:BinaryOps.MAX}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType):
  """Return the identity value of `op` (ADD -> 0, MUL -> 1, MAX -> dtypes.min(dt)) as a constant of dtype `dt`.

  Any other `op` raises KeyError.
  """
  identities = {BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}
  return dtypes.as_const(identities[op], dt)
|
||||
|
||||
BUFOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
|
||||
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
|
||||
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
|
||||
@@ -339,8 +339,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(Ops.REDUCE, self.dtype, (self,) + rng, op)
|
||||
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
|
||||
def reduce(self, op:Ops, *rng:UOp):
  """Build a REDUCE UOp with this UOp's dtype, sources (self, *rng) and `op` as the last positional argument."""
  reduce_srcs = (self, *rng)
  return UOp(Ops.REDUCE, self.dtype, reduce_srcs, op)
|
||||
def r(self, op, axis):
  """Build a REDUCE_AXIS UOp over `axis`; an `op` in GroupOp.Reduce is first mapped through REDUCE_ALU."""
  # callers may pass either a reduce op (SUM/PROD/REDUCE_MAX) or the op to store directly
  if op in GroupOp.Reduce: alu_op = REDUCE_ALU[op]
  else: alu_op = op
  return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (alu_op, axis))
|
||||
def assign(self, x:UOp):
  """Build an ASSIGN UOp with this UOp's dtype and sources (self, x)."""
  return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
|
||||
# *** uop Variable stuff ***
|
||||
@@ -456,7 +456,7 @@ def hook_overflow(dv, fxn):
|
||||
except OverflowError: return dv
|
||||
return wfxn
|
||||
|
||||
python_alu: Dict[Op, Callable] = {
|
||||
python_alu: Dict[Ops, Callable] = {
|
||||
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
||||
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
|
||||
@@ -466,7 +466,7 @@ python_alu: Dict[Op, Callable] = {
|
||||
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_, BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
|
||||
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
def exec_alu(op:Op, dtype:DType, operands, truncate_output=True):
|
||||
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
||||
if dtype.count > 1:
|
||||
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
|
||||
alu = python_alu[op](*operands)
|
||||
@@ -842,7 +842,7 @@ def cast_float_to_bf16(x: UOp) -> UOp:
|
||||
|
||||
# *** most of symbolic lives here now ***
|
||||
|
||||
def split_uop(x:UOp, sep:BinaryOps):
|
||||
def split_uop(x:UOp, sep:Ops):
  """Yield the operands of a nested chain of ALU UOps whose arg is `sep`, depth-first and left to right.

  A UOp that is not an ALU with arg `sep` is yielded as-is.
  """
  if x.op is not Ops.ALU or x.arg is not sep:
    yield x
    return
  # x is a `sep` node: flatten each source in order
  for child in x.src: yield from split_uop(child, sep)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.helpers import to_function_name, dedup, prod
|
||||
from tinygrad.ops import Op, Ops, UOp, flops_mem, sym_infer, sint, Variable
|
||||
from tinygrad.ops import Ops, UOp, flops_mem, sym_infer, sint, Variable
|
||||
from tinygrad.dtype import DType
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -83,7 +83,7 @@ class Renderer:
|
||||
shared_max: int = 32768
|
||||
tensor_cores: List[TensorCore] = []
|
||||
extra_matcher: Any = None
|
||||
code_for_op: Dict[Op, Callable] = {}
|
||||
code_for_op: Dict[Ops, Callable] = {}
|
||||
|
||||
def __reduce__(self): return self.__class__, ()
|
||||
def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, Callable, List, Optional
|
||||
from llvmlite import ir
|
||||
from tinygrad.dtype import DType, PtrDType, dtypes
|
||||
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, Ops, UOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc
|
||||
@@ -52,7 +52,7 @@ class LLVMRenderer(Renderer):
|
||||
has_local = False
|
||||
has_shared = False
|
||||
global_max = None
|
||||
code_for_op: Dict[Op, Callable] = {
|
||||
code_for_op: Dict[Ops, Callable] = {
|
||||
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
|
||||
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
|
||||
BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, UPat
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
@@ -14,7 +14,7 @@ def render_val(x, dtype):
|
||||
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
|
||||
|
||||
asm_for_op: Dict[Op, Callable] = {
|
||||
asm_for_op: Dict[Ops, Callable] = {
|
||||
UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
|
||||
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
@@ -32,7 +32,7 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
}
|
||||
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
ptx_matcher = PatternMatcher([
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
||||
|
||||
Reference in New Issue
Block a user