Mirror of https://github.com/tinygrad/tinygrad.git
Misc rewrite perf improvements (#6500)
* Make UOp a normal class and use __slots__
* Use __slots__ in UPat
* Cache dtypes.{min,max}
* Use faster iterables in ops.py
* extend is a lot faster than nested listcomp
Co-authored-by: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com>
---------
Co-authored-by: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com>
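The first two bullets swap dataclass-generated classes for plain classes that declare __slots__. A rough, self-contained sketch of the effect, using hypothetical NodeDC/NodeSlots stand-ins rather than tinygrad's real UOp: a frozen dataclass routes every field assignment through object.__setattr__ and carries a per-instance __dict__, while the slotted class assigns fields directly into fixed slots.

# Minimal sketch (not tinygrad code): frozen dataclass vs. plain __slots__ class.
import sys, timeit
from dataclasses import dataclass

@dataclass(frozen=True, eq=False)
class NodeDC:              # stand-in for the old dataclass-based UOp
  op: int
  arg: object = None

class NodeSlots:           # stand-in for the new __slots__ class
  __slots__ = ["op", "arg"]
  def __init__(self, op:int, arg=None): self.op, self.arg = op, arg

print("dataclass create :", timeit.timeit(lambda: NodeDC(1), number=1_000_000))
print("__slots__ create :", timeit.timeit(lambda: NodeSlots(1), number=1_000_000))
print("dataclass size   :", sys.getsizeof(NodeDC(1)) + sys.getsizeof(NodeDC(1).__dict__))
print("__slots__ size   :", sys.getsizeof(NodeSlots(1)))

Exact numbers vary by interpreter and machine, but the slotted class is typically both smaller and quicker to construct, which adds up when rewrites create large numbers of UOp instances.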
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import itertools, functools
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from collections import defaultdict
 from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
@@ -639,9 +639,9 @@ class Kernel:
 # for locals, we use the ShapeTracker that's in the srcs
 st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
 st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
-if op.op is UOps.CONST: return replace(op, src=(st_uop,))
-if op.op is UOps.STORE: return replace(op, src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
-return replace(op, src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
+if op.op is UOps.CONST: return op.replace(src=(st_uop,))
+if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
+return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
 if op.op is UOps.REDUCE_AXIS:
 reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
 alu_op: BinaryOps = op.arg[0]
@@ -713,7 +713,7 @@ class Kernel:
 else:
 ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
 new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes)
-return replace(op, src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
+return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
 if self.group_for_reduces:
 start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
 second_axis = tuple(i for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces) \
@@ -734,7 +734,7 @@ class Kernel:
 arg = (alu_op, axis)
 elif op.op is UOps.SINK:
 arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
-return replace(op, src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
+return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
 return fixup_ast(self.ast)

 # **** this is the lowerer ****
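Since UOp is no longer a dataclass (see the ops.py hunk further down), dataclasses.replace cannot be used on it, and the call sites above switch to a hand-written UOp.replace method. A minimal sketch of that pattern on a hypothetical Node class, including the `is None` checks needed for fields whose values may legitimately be falsy:

# Sketch of the hand-rolled replace() pattern (hypothetical Node class, not tinygrad's UOp).
from typing import Optional, Tuple, Any

class Node:
  __slots__ = ["op", "src", "arg"]
  def __init__(self, op:int, src:Tuple["Node",...]=(), arg:Any=None):
    self.op, self.src, self.arg = op, src, arg
  def replace(self, op:Optional[int]=None, src:Optional[Tuple["Node",...]]=None, arg:Any=None) -> "Node":
    # fields that can be falsy (empty tuple, 0) need an explicit `is None` check;
    # `or` would silently keep the old value. Note this pattern cannot reset a field back to None.
    return Node(self.op if op is None else op, self.src if src is None else src, self.arg if arg is None else arg)

n = Node(1, arg=5)
assert n.replace(arg=7).arg == 7 and n.replace(arg=7).op == 1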
@@ -62,10 +62,12 @@ class dtypes:
 return tuple(dtypes.as_const(x, dtype) for x in val)
 return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
 @staticmethod
+@functools.lru_cache(None)
 def min(dtype:DType):
 if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
 return -float("inf") if dtypes.is_float(dtype) else False
 @staticmethod
+@functools.lru_cache(None)
 def max(dtype:DType):
 if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
 return float("inf") if dtypes.is_float(dtype) else True
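The two added @functools.lru_cache(None) decorators memoize dtypes.min and dtypes.max per DType (DType is hashable, so it can serve as a cache key), so the bit-width arithmetic runs once per dtype rather than on every call. A small sketch of the decorator stacking on a hypothetical limits class, with cache_info() confirming the hits:

# Sketch: caching a pure staticmethod with functools.lru_cache (hypothetical class).
import functools

class limits:
  @staticmethod
  @functools.lru_cache(None)       # lru_cache must sit under @staticmethod so the function itself is wrapped
  def umax(bits:int) -> int: return 2**bits - 1

print(limits.umax(8), limits.umax(8), limits.umax(16))   # 255 255 65535
print(limits.umax.cache_info())                          # CacheInfo(hits=1, misses=2, maxsize=None, currsize=2)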
@@ -344,12 +344,12 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
 COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
 END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}

-@dataclass(frozen=True, eq=False)
 class UOp(MathTrait):
-op: UOps
-dtype: DType = dtypes.void
-src: Tuple[UOp, ...] = tuple()
-arg: Any = None
+__slots__ = ["op", "dtype", "src", "arg"]
+def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
+self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
+def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
+return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)
 @functools.cached_property
 def st(self) -> Optional[ShapeTracker]:
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -619,6 +619,7 @@ def get_location() -> Tuple[str, int]:
 def lines(fn) -> List[str]: return open(fn).readlines()

 class UPat(MathTrait):
+__slots__ = ["op", "dtype", "arg", "name", "src"]
 def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
 src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
 name:Optional[str]=None, allow_any_len:bool=False, location=None,
@@ -626,7 +627,6 @@ class UPat(MathTrait):
 self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
 self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
 self.arg, self.name = arg, name
-self.in_src = src
 self.src: Any = None

 # try all permutations if it's a list
@@ -641,7 +641,7 @@ class UPat(MathTrait):

 if custom_early_reject is not None: self.early_reject = custom_early_reject
 else:
-upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0])
+upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
 self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)

 @staticmethod
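A behavioral note on the UPat changes above: __slots__ only removes the per-instance __dict__, and with it the ability to attach arbitrary attributes, when every class in the MRO declares __slots__; if a base class such as MathTrait omits it, instances keep a __dict__ and the declaration mostly documents the intended fields. A small sketch of both cases with hypothetical classes:

# Sketch: __slots__ only blocks new attributes when all bases define it (hypothetical classes).
class Strict:
  __slots__ = ["x"]

class LooseBase: pass           # no __slots__, so subclass instances still get a __dict__
class Loose(LooseBase):
  __slots__ = ["x"]

s, l = Strict(), Loose()
s.x = 1; l.x = 1
try:
  s.y = 2                       # raises AttributeError: no __dict__ to hold 'y'
except AttributeError as e:
  print("Strict:", e)
l.y = 2                         # allowed: Loose instances still have a __dict__
print("Loose :", l.__dict__)    # {'y': 2}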
@@ -691,9 +691,11 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
 if pat.src is None: return [store]
 res: List[Dict[str, UOp]] = []
 for vp in pat.src:
-new_stores = [store.copy()]
-for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
-res.extend(new_stores)
+stores, new_stores = [store.copy()], []
+for uu, vv in zip(uop.src, vp):
+for s in stores: new_stores.extend(_match(uu, vv, s))
+stores, new_stores = new_stores, []
+res.extend(stores)
 return res

 class PatternMatcher:
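The _match rewrite above drops the nested list comprehension that rebuilt the candidate-store list for every source pair in favor of an explicit loop that extends a reusable list, matching the "extend is a lot faster than nested listcomp" bullet. A rough micro-benchmark sketch of the general effect, on made-up data rather than real matcher state:

# Sketch: flattening with list.extend vs. a nested list comprehension (illustrative only).
import timeit

chunks = [list(range(4)) for _ in range(250)]   # many small lists, as in pattern matching

def nested_comp(): return [x for chunk in chunks for x in chunk]

def with_extend():
  out = []
  for chunk in chunks: out.extend(chunk)
  return out

assert nested_comp() == with_extend()
print("nested listcomp:", timeit.timeit(nested_comp, number=20_000))
print("list.extend    :", timeit.timeit(with_extend, number=20_000))

On CPython the extend version is usually faster because list.extend copies each chunk in a single C-level call, while the comprehension appends element by element.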
@@ -709,8 +711,8 @@ class PatternMatcher:
 def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)

 def rewrite(self, uop:UOp) -> Optional[UOp]:
-ler = set([(u.op, u.arg) for u in uop.src] + [(u.op, None) for u in uop.src])
-for p,fxn,early_reject in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
+ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
+for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + self.pdict[(uop.op, None)]:
 if not early_reject.issubset(ler): continue
 if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
 return None
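Two small changes in rewrite above: ler is now built in a single pass over uop.src instead of concatenating two intermediate lists, and the two per-op pattern lists are joined with + rather than itertools.chain, trading a tiny list copy for the iterator setup that chain needs. A rough sketch of that second trade-off with illustrative values; for lists this short the concatenation is typically at least as fast, though results depend on the interpreter:

# Sketch: iterating two short lists via list concatenation vs. itertools.chain.
import itertools, timeit

a, b = [("ADD", None)] * 3, [("MUL", 1)] * 3      # short lists, like per-op pattern buckets

def via_concat():
  for item in a + b: pass

def via_chain():
  for item in itertools.chain(a, b): pass

print("a + b          :", timeit.timeit(via_concat, number=500_000))
print("itertools.chain:", timeit.timeit(via_chain, number=500_000))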
@@ -727,8 +729,8 @@ class TrackedPattenMatcher(PatternMatcher):

 def rewrite(self, uop:UOp) -> Optional[UOp]:
 ret = None
-ler = set([(u.op, u.arg) for u in uop.src] + [(u.op, None) for u in uop.src])
-for p,fxn,early_reject in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
+ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
+for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + self.pdict[(uop.op, None)]:
 st = time.perf_counter()
 if not early_reject.issubset(ler):
 match_stats[p][2] += time.perf_counter()-st
@@ -765,7 +767,7 @@ class RewriteContext:
 self.replace: Dict[UOp, UOp] = {}
 def rewrite(self, n:UOp) -> UOp:
 if rn := self.replace.get(n): return rn
-replace_source = (n.op, n.dtype, new_src:=tuple(self.rewrite(y) for y in n.src), n.arg)
+replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg)
 if found := self.nodes.get(replace_source): self.replace[n] = found
 else:
 x = UOp(*replace_source) if new_src != n.src else n
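Finally, RewriteContext.rewrite builds the rewritten source tuple with tuple(map(self.rewrite, n.src)) instead of a generator expression; with an already-bound callable, map avoids resuming a generator frame per element. A rough sketch of the difference with an illustrative callable:

# Sketch: tuple(map(f, xs)) vs. tuple(f(x) for x in xs) for a cheap callable.
import timeit

xs = list(range(8))                       # UOp.src tuples are typically short
f = abs                                   # stand-in for a bound method like self.rewrite

print("tuple(map(f, xs))      :", timeit.timeit(lambda: tuple(map(f, xs)), number=500_000))
print("tuple(f(x) for x in xs):", timeit.timeit(lambda: tuple(f(x) for x in xs), number=500_000))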