Misc rewrite perf improvements (#6500)

* Make UOp a normal class and use __slots__

* Use __slots__ in UPat

* Cache dtypes.{min,max}

* Use faster iterables in ops.py

* extend is a lot faster than nested listcomp

Co-authored-by: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com>

---------

Co-authored-by: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com>
Author: Tim Becker
Date: 2024-09-12 20:31:50 -07:00
Committed by: GitHub
Parent: 8c4cab8d6e
Commit: 7c078191ce
3 changed files with 25 additions and 21 deletions

@@ -1,6 +1,6 @@
from __future__ import annotations
import itertools, functools
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
@@ -639,9 +639,9 @@ class Kernel:
# for locals, we use the ShapeTracker that's in the srcs
st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
-if op.op is UOps.CONST: return replace(op, src=(st_uop,))
-if op.op is UOps.STORE: return replace(op, src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
-return replace(op, src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
+if op.op is UOps.CONST: return op.replace(src=(st_uop,))
+if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
+return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
if op.op is UOps.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
alu_op: BinaryOps = op.arg[0]
@@ -713,7 +713,7 @@ class Kernel:
else:
ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes)
-return replace(op, src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
+return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
if self.group_for_reduces:
start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
second_axis = tuple(i for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces) \
@@ -734,7 +734,7 @@ class Kernel:
arg = (alu_op, axis)
elif op.op is UOps.SINK:
arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
-return replace(op, src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
+return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
return fixup_ast(self.ast)
# **** this is the lowerer ****
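
The `replace(op, ...)` → `op.replace(...)` change above only works because UOp stops being a dataclass (see the UOp class hunk further down): `dataclasses.replace` rebuilds the keyword arguments by walking `fields()` on every call, while a hand-written method forwards straight to the constructor. A rough standalone sketch of the difference, using simplified stand-in classes rather than the real UOp:

import dataclasses, timeit
from dataclasses import dataclass
from typing import Any, Optional, Tuple

@dataclass(frozen=True, eq=False)
class DataclassNode:                       # stand-in for the old dataclass-based UOp
  op: int
  dtype: int = 0
  src: Tuple[Any, ...] = tuple()
  arg: Any = None

class SlotsNode:                           # stand-in for the new plain class with __slots__
  __slots__ = ["op", "dtype", "src", "arg"]
  def __init__(self, op:int, dtype:int=0, src:Tuple[Any, ...]=tuple(), arg:Any=None):
    self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
  def replace(self, op:Optional[int]=None, dtype:Optional[int]=None, src:Optional[Tuple[Any, ...]]=None, arg:Any=None):
    # forwards straight to __init__, no fields() introspection
    return SlotsNode(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)

d, s = DataclassNode(1), SlotsNode(1)
print(timeit.timeit(lambda: dataclasses.replace(d, arg=2), number=100_000))  # walks fields() each call
print(timeit.timeit(lambda: s.replace(arg=2), number=100_000))               # plain method call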

@@ -62,10 +62,12 @@ class dtypes:
return tuple(dtypes.as_const(x, dtype) for x in val)
return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
@staticmethod
+@functools.lru_cache(None)
def min(dtype:DType):
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
return -float("inf") if dtypes.is_float(dtype) else False
@staticmethod
+@functools.lru_cache(None)
def max(dtype:DType):
if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
return float("inf") if dtypes.is_float(dtype) else True

@@ -344,12 +344,12 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
-@dataclass(frozen=True, eq=False)
class UOp(MathTrait):
-  op: UOps
-  dtype: DType = dtypes.void
-  src: Tuple[UOp, ...] = tuple()
-  arg: Any = None
+  __slots__ = ["op", "dtype", "src", "arg"]
+  def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
+    self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
+  def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
+    return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
from tinygrad.shape.shapetracker import ShapeTracker
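
On the UOp change itself: declaring `__slots__` on a plain class reserves fixed storage for the named attributes, so (absent a `__dict__`-bearing base) instances skip the per-object dict and attribute loads go through slot descriptors. Since graph rewrite allocates and touches a large number of UOps, both the memory and the lookup cost add up. A minimal sketch with toy classes, not UOp itself:

import sys, timeit

class WithDict:
  def __init__(self, op, dtype, src, arg):
    self.op, self.dtype, self.src, self.arg = op, dtype, src, arg

class WithSlots:
  __slots__ = ["op", "dtype", "src", "arg"]
  def __init__(self, op, dtype, src, arg):
    self.op, self.dtype, self.src, self.arg = op, dtype, src, arg

a, b = WithDict(1, 0, (), None), WithSlots(1, 0, (), None)
# the dict-based instance pays for the object plus its __dict__
print(sys.getsizeof(a) + sys.getsizeof(a.__dict__), sys.getsizeof(b))
# attribute reads: slot descriptor vs dict lookup
print(timeit.timeit(lambda: a.arg, number=1_000_000))
print(timeit.timeit(lambda: b.arg, number=1_000_000))
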
@@ -619,6 +619,7 @@ def get_location() -> Tuple[str, int]:
def lines(fn) -> List[str]: return open(fn).readlines()
class UPat(MathTrait):
+__slots__ = ["op", "dtype", "arg", "name", "src"]
def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None,
@@ -626,7 +627,6 @@ class UPat(MathTrait):
self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
self.arg, self.name = arg, name
-self.in_src = src
self.src: Any = None
# try all permutations if it's a list
@@ -641,7 +641,7 @@ class UPat(MathTrait):
if custom_early_reject is not None: self.early_reject = custom_early_reject
else:
-upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0])
+upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
@staticmethod
@@ -691,9 +691,11 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
if pat.src is None: return [store]
res: List[Dict[str, UOp]] = []
for vp in pat.src:
-new_stores = [store.copy()]
-for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
-res.extend(new_stores)
+stores, new_stores = [store.copy()], []
+for uu, vv in zip(uop.src, vp):
+  for s in stores: new_stores.extend(_match(uu, vv, s))
+  stores, new_stores = new_stores, []
+res.extend(stores)
return res
class PatternMatcher:
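
This hunk is the "extend is a lot faster than nested listcomp" change: instead of rebuilding `new_stores` with a comprehension that appends one element at a time, each list returned by `_match` is copied in bulk with `list.extend`. A standalone sketch of the general flatten pattern (not the `_match` code itself):

import timeit

chunks = [list(range(10)) for _ in range(1000)]

def flatten_listcomp(chunks):
  return [x for chunk in chunks for x in chunk]   # one element at a time

def flatten_extend(chunks):
  out = []
  for chunk in chunks: out.extend(chunk)          # bulk copy per chunk
  return out

print(timeit.timeit(lambda: flatten_listcomp(chunks), number=2_000))
print(timeit.timeit(lambda: flatten_extend(chunks), number=2_000))
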
@@ -709,8 +711,8 @@ class PatternMatcher:
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp) -> Optional[UOp]:
-ler = set([(u.op, u.arg) for u in uop.src] + [(u.op, None) for u in uop.src])
-for p,fxn,early_reject in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
+ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
+for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + self.pdict[(uop.op, None)]:
if not early_reject.issubset(ler): continue
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
return None
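
Two "faster iterables" swaps in `rewrite`: `ler` is now built in a single pass instead of concatenating two list comprehensions, and the two small `pdict` lists are joined with `+` rather than `itertools.chain`, avoiding per-item iterator overhead in a short hot loop. Whether `+` beats `chain` depends on list sizes, so treat this as something to measure; a rough sketch with made-up data:

import itertools, timeit

src = [(1, None), (2, "x"), (3, 7)]          # stand-in for (u.op, u.arg) pairs
a, b = list(range(8)), list(range(8))        # stand-in for the two pdict entries

def ler_two_listcomps(): return set([(op, arg) for op, arg in src] + [(op, None) for op, arg in src])
def ler_one_pass():      return set([v for op, arg in src for v in ((op, arg), (op, None))])

def loop_chain(): return [x for x in itertools.chain(a, b)]
def loop_plus():  return [x for x in a + b]

for f in (ler_two_listcomps, ler_one_pass, loop_chain, loop_plus):
  print(f.__name__, timeit.timeit(f, number=100_000))
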
@@ -727,8 +729,8 @@ class TrackedPattenMatcher(PatternMatcher):
def rewrite(self, uop:UOp) -> Optional[UOp]:
ret = None
-ler = set([(u.op, u.arg) for u in uop.src] + [(u.op, None) for u in uop.src])
-for p,fxn,early_reject in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
+ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
+for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + self.pdict[(uop.op, None)]:
st = time.perf_counter()
if not early_reject.issubset(ler):
match_stats[p][2] += time.perf_counter()-st
@@ -765,7 +767,7 @@ class RewriteContext:
self.replace: Dict[UOp, UOp] = {}
def rewrite(self, n:UOp) -> UOp:
if rn := self.replace.get(n): return rn
-replace_source = (n.op, n.dtype, new_src:=tuple(self.rewrite(y) for y in n.src), n.arg)
+replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg)
if found := self.nodes.get(replace_source): self.replace[n] = found
else:
x = UOp(*replace_source) if new_src != n.src else n
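
And the last hunk: `tuple(map(self.rewrite, n.src))` keeps the iteration loop in C instead of resuming a generator frame per element, a small but measurable win when `rewrite` runs for every node. A quick standalone comparison of the two forms:

import timeit

f, xs = abs, tuple(range(16))
print(timeit.timeit(lambda: tuple(f(x) for x in xs), number=200_000))  # generator expression
print(timeit.timeit(lambda: tuple(map(f, xs)), number=200_000))        # map: loop stays in C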