flatten ops (#7523)

* flatten ops

* fix mypy
This commit is contained in:
George Hotz
2024-11-04 18:07:23 +08:00
committed by GitHub
parent 9c3ee64a3e
commit c1585bcc9e
10 changed files with 71 additions and 72 deletions

View File

@@ -622,7 +622,7 @@ class Kernel:
return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
alu_op: BinaryOps = op.arg[0]
alu_op: Ops = op.arg[0]
axis = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i]))
if op in self.bufs_for_tensor_core and (tc := self.tensor_core):

View File

@@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import variable_to_uop
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten
@@ -100,7 +100,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
alu_op: BinaryOps = x.arg[0]
alu_op: Ops = x.arg[0]
ret = x.src[0]
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))

View File

@@ -1,7 +1,7 @@
import sys
from collections import defaultdict, deque
from typing import Tuple, List, Dict, DefaultDict
from tinygrad.ops import UNSAFE_PAD_OPS, MetaOps, ReduceOps, UOp, UnaryOps, resolve
from tinygrad.ops import UNSAFE_PAD_OPS, MetaOps, ReduceOps, UOp, UnaryOps, resolve, GroupOp
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
@@ -32,7 +32,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
elif any(v.mask is not None for v in buf.st.views): simple_pads[buf.base] = None
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)
if ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None
if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
if buf.op in GroupOp.Reduce and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
allbufs[buf] = None
if buf.op is MetaOps.ASSIGN:
assign_targets[(target:=buf.srcs[0])] = buf
@@ -63,7 +63,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
return group.setdefault(tr)
for tr_next in children[tr]:
# max one reduceop per kernel
if tr_next.op in ReduceOps: return group.setdefault(r)
if tr_next.op in GroupOp.Reduce: return group.setdefault(r)
# can only fuse contiguous
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
@@ -75,7 +75,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
if (p:=rc_parents.pop()) in cache: continue
cache.add(p)
# max one reduceop per kernel
if p.op in ReduceOps: return {}
if p.op in GroupOp.Reduce: return {}
rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
# search descendants of the reduceop that can cleanly group
descendants: Dict[LazyBuffer, None] = {}
@@ -101,7 +101,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
reduce_of_const: List[LazyBuffer] = []
for r in allbufs:
if r.op not in ReduceOps or r in realizes: continue
if r.op not in GroupOp.Reduce or r in realizes: continue
group: Dict[LazyBuffer, None] = {}
_recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={})
@@ -130,7 +130,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
if len(st_childs) > 1: break
if st.size != st_childs[0].st.size: break
st = st + st_childs[0].st
if not st.contiguous or tr_next.op in ReduceOps: break
if not st.contiguous or tr_next.op in GroupOp.Reduce: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:

View File

@@ -2,14 +2,14 @@ from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu, python_alu, REDUCE_ALU
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
from weakref import ref, ReferenceType, WeakValueDictionary
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
dtype = to_dtype(dtype)
@@ -25,7 +25,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
class LazyBuffer(MathTrait):
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
self._base: Optional[LazyBuffer] = None
@@ -111,8 +111,7 @@ class LazyBuffer(MathTrait):
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
# TODO: applying this makes gpt2 slower
return self.base.cast(dtype, bitcast)._view(self.st)
cast_op: Union[MetaOps, UnaryOps] = \
(MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
cast_op: Union[Ops, Ops] = (MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, UOp)
@@ -142,7 +141,7 @@ class LazyBuffer(MathTrait):
def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)
def alu(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer) -> LazyBuffer:
def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
srcs: List[LazyBuffer] = []
for s in (self,)+in_srcs:
if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
@@ -159,7 +158,7 @@ class LazyBuffer(MathTrait):
# const folding
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
if op in BinaryOps:
if op in GroupOp.Binary:
x, y = self, in_srcs[0]
if op is BinaryOps.ADD:
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
@@ -173,13 +172,13 @@ class LazyBuffer(MathTrait):
# *** reduce ops ***
def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
def _reduce_op(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
new_shape = self.st.reduce(axis)
# TODO: this logic should move to the scheduler
if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(REDUCE_ALU[op], self.dtype), new_shape)

View File

@@ -1,8 +1,8 @@
import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict
from tinygrad.ops import BUFOPS, MetaOps, GroupOp, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -69,10 +69,10 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
# everything else needs sources
src = tuple(to_uop(x, ctx, cache) for x in buf.srcs)
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
if buf.op in GroupOp.Reduce: ret = src[0].r(buf.op, buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op in METAOPS: ret = UOp(METAOPS[cast(MetaOps, buf.op)], buf.dtype, (ubuf, *src), buf.arg)
elif buf.op in METAOPS: ret = UOp(METAOPS[buf.op], buf.dtype, (ubuf, *src), buf.arg)
elif buf.op is UnaryOps.CAST: ret = UOp(Ops.CAST, dtype, src)
elif buf.op is UnaryOps.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
else: ret = UOp(Ops.ALU, dtype, src, buf.op)

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
from typing import Optional, Union, Tuple, List, Dict
from typing import Optional, Tuple, List, Dict
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
from tinygrad.dtype import DType
from tinygrad.ops import REDUCE_ALU, BinaryOps, MetaOps, UnaryOps, TernaryOps, ReduceOps, MathTrait
from tinygrad.ops import REDUCE_ALU, Ops, MathTrait
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.shape.shapetracker import sint
def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
def all_reduce(op: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
bop = REDUCE_ALU[op]
@@ -94,7 +94,7 @@ class MultiLazyBuffer(MathTrait):
def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
# elementwise is simple
def alu(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
msrcs = (self,)+in_srcs
assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
@@ -114,7 +114,7 @@ class MultiLazyBuffer(MathTrait):
new_dtype = next(iter(new_real_lbs.values())).dtype
return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
def r(self, op:Ops, axis:Tuple[int, ...]) -> MultiLazyBuffer:
if self.axis is not None and self.axis in axis:
# all-reduce on sharded axes
reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]

View File

@@ -16,29 +16,9 @@ class FastEnum(IntEnum):
@staticmethod
def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
class UnaryOps(FastEnum):
"""A -> A (elementwise)"""
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
class BinaryOps(FastEnum):
"""A + A -> A (elementwise)"""
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
class TernaryOps(FastEnum):
"""A + A + A -> A (elementwise)"""
WHERE = auto(); MULACC = auto() # noqa: E702
class ReduceOps(FastEnum):
"""A -> B (reduce)"""
SUM = auto(); PROD = auto(); REDUCE_MAX = auto() # noqa: E702
class MetaOps(FastEnum):
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); ASSIGN = auto(); BUFFER_VIEW = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
class SimpleMathTrait:
# required to implement
def alu(self:T, arg:Union[UnaryOps, BinaryOps, TernaryOps], *src) -> T: raise NotImplementedError
def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError
# great functions you get!
@@ -115,14 +95,6 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
def log2(self): return self.alu(UnaryOps.LOG2)
def exp2(self): return self.alu(UnaryOps.EXP2)
# do not preserve f(0) = 0
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.REDUCE_MAX:BinaryOps.MAX}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
@@ -130,7 +102,7 @@ class Ops(FastEnum):
CONTIGUOUS = auto()
PRELOAD = auto()
# metaops
# MetaOps
COPY = auto()
EMPTY = auto()
BUFFER_VIEW = auto()
@@ -146,14 +118,20 @@ class Ops(FastEnum):
VALID = auto()
SPECIAL = auto()
NOOP = auto()
# reduce
REDUCE = auto()
REDUCE_AXIS = auto()
# ReduceOps
SUM = auto(); PROD = auto(); REDUCE_MAX = auto() # noqa: E702
# helper ops
GEP = auto()
VECTORIZE = auto()
CAST = auto()
BITCAST = auto()
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
# loads before math
LOAD = auto()
@@ -162,6 +140,13 @@ class Ops(FastEnum):
ALU = auto()
WMMA = auto()
# BinaryOps
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
# TernaryOps
WHERE = auto(); MULACC = auto() # noqa: E702
# assignment ops
STORE = auto()
ASSIGN = auto()
@@ -183,6 +168,21 @@ class Ops(FastEnum):
VCONST = auto()
CONST = auto()
class GroupOp:
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB}
Reduce = {Ops.SUM, Ops.PROD, Ops.REDUCE_MAX}
# TODO: remove this
Op = UnaryOps = BinaryOps = ReduceOps = MetaOps = TernaryOps = Ops
# do not preserve f(0) = 0
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
REDUCE_ALU: Dict[Ops, Ops] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.REDUCE_MAX:BinaryOps.MAX}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
BUFOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
@@ -339,8 +339,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(Ops.REDUCE, self.dtype, (self,) + rng, op)
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
def reduce(self, op:Ops, *rng:UOp): return UOp(Ops.REDUCE, self.dtype, (self,) + rng, op)
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
# *** uop Variable stuff ***
@@ -456,7 +456,7 @@ def hook_overflow(dv, fxn):
except OverflowError: return dv
return wfxn
python_alu: Dict[Op, Callable] = {
python_alu: Dict[Ops, Callable] = {
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
@@ -466,7 +466,7 @@ python_alu: Dict[Op, Callable] = {
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_, BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
def exec_alu(op:Op, dtype:DType, operands, truncate_output=True):
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1:
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
alu = python_alu[op](*operands)
@@ -842,7 +842,7 @@ def cast_float_to_bf16(x: UOp) -> UOp:
# *** most of symbolic lives here now ***
def split_uop(x:UOp, sep:BinaryOps):
def split_uop(x:UOp, sep:Ops):
if x.op is Ops.ALU and x.arg is sep:
for s in x.src: yield from split_uop(s, sep)
else: yield x

View File

@@ -2,7 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod
from tinygrad.ops import Op, Ops, UOp, flops_mem, sym_infer, sint, Variable
from tinygrad.ops import Ops, UOp, flops_mem, sym_infer, sint, Variable
from tinygrad.dtype import DType
@dataclass(frozen=True)
@@ -83,7 +83,7 @@ class Renderer:
shared_max: int = 32768
tensor_cores: List[TensorCore] = []
extra_matcher: Any = None
code_for_op: Dict[Op, Callable] = {}
code_for_op: Dict[Ops, Callable] = {}
def __reduce__(self): return self.__class__, ()
def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")

View File

@@ -1,7 +1,7 @@
from typing import Dict, Callable, List, Optional
from llvmlite import ir
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, Ops, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp
from tinygrad.renderer import Renderer
MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc
@@ -52,7 +52,7 @@ class LLVMRenderer(Renderer):
has_local = False
has_shared = False
global_max = None
code_for_op: Dict[Op, Callable] = {
code_for_op: Dict[Ops, Callable] = {
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501

View File

@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, UPat
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -14,7 +14,7 @@ def render_val(x, dtype):
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
asm_for_op: Dict[Op, Callable] = {
asm_for_op: Dict[Ops, Callable] = {
UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
@@ -32,7 +32,7 @@ asm_for_op: Dict[Op, Callable] = {
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),