From 7c078191ce93c2a9fdb110f9eb4ac59fefc39397 Mon Sep 17 00:00:00 2001 From: Tim Becker Date: Thu, 12 Sep 2024 20:31:50 -0700 Subject: [PATCH] Misc rewrite perf improvements (#6500) * Make UOp a normal class and use __slots__ * Use __slots__ in UPat * Cache dtypes.{min,max} * Use faster iterables in ops.py * extend is a lot faster than nested listcomp Co-authored-by: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com> --------- Co-authored-by: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com> --- tinygrad/codegen/kernel.py | 12 ++++++------ tinygrad/dtype.py | 2 ++ tinygrad/ops.py | 32 +++++++++++++++++--------------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index e182443954..0488924d3a 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -1,6 +1,6 @@ from __future__ import annotations import itertools, functools -from dataclasses import dataclass, replace +from dataclasses import dataclass from collections import defaultdict from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict @@ -639,9 +639,9 @@ class Kernel: # for locals, we use the ShapeTracker that's in the srcs st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)] st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop() - if op.op is UOps.CONST: return replace(op, src=(st_uop,)) - if op.op is UOps.STORE: return replace(op, src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st))) - return replace(op, src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]])) + if op.op is UOps.CONST: return op.replace(src=(st_uop,)) + if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st))) + return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]])) if op.op is UOps.REDUCE_AXIS: reduce_idx = len(self.bufs) + self.reduceops.index(op)*2 alu_op: BinaryOps = op.arg[0] @@ -713,7 +713,7 @@ class Kernel: else: ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg) new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes) - return replace(op, src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret + return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret if self.group_for_reduces: start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis)) second_axis = tuple(i for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces) \ @@ -734,7 +734,7 @@ class Kernel: arg = (alu_op, axis) elif op.op is UOps.SINK: arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals) - return replace(op, src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg) + return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg) return fixup_ast(self.ast) # **** this is the lowerer **** diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 08aaf7c5e7..ca1df1f006 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -62,10 +62,12 @@ class dtypes: return tuple(dtypes.as_const(x, dtype) for x in val) return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val) @staticmethod + @functools.lru_cache(None) def min(dtype:DType): if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1) return -float("inf") if dtypes.is_float(dtype) else False @staticmethod + @functools.lru_cache(None) def max(dtype:DType): if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1 return float("inf") if dtypes.is_float(dtype) else True diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 18774285db..d467653af1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -344,12 +344,12 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST} COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR} END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)} -@dataclass(frozen=True, eq=False) class UOp(MathTrait): - op: UOps - dtype: DType = dtypes.void - src: Tuple[UOp, ...] = tuple() - arg: Any = None + __slots__ = ["op", "dtype", "src", "arg"] + def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): + self.op, self.dtype, self.src, self.arg = op, dtype, src, arg + def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None): + return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg) @functools.cached_property def st(self) -> Optional[ShapeTracker]: from tinygrad.shape.shapetracker import ShapeTracker @@ -619,6 +619,7 @@ def get_location() -> Tuple[str, int]: def lines(fn) -> List[str]: return open(fn).readlines() class UPat(MathTrait): + __slots__ = ["op", "dtype", "arg", "name", "src"] def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None, name:Optional[str]=None, allow_any_len:bool=False, location=None, @@ -626,7 +627,6 @@ class UPat(MathTrait): self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype self.arg, self.name = arg, name - self.in_src = src self.src: Any = None # try all permutations if it's a list @@ -641,7 +641,7 @@ class UPat(MathTrait): if custom_early_reject is not None: self.early_reject = custom_early_reject else: - upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0]) + upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0]) self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1) @staticmethod @@ -691,9 +691,11 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: if pat.src is None: return [store] res: List[Dict[str, UOp]] = [] for vp in pat.src: - new_stores = [store.copy()] - for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)] - res.extend(new_stores) + stores, new_stores = [store.copy()], [] + for uu, vv in zip(uop.src, vp): + for s in stores: new_stores.extend(_match(uu, vv, s)) + stores, new_stores = new_stores, [] + res.extend(stores) return res class PatternMatcher: @@ -709,8 +711,8 @@ class PatternMatcher: def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns) def rewrite(self, uop:UOp) -> Optional[UOp]: - ler = set([(u.op, u.arg) for u in uop.src] + [(u.op, None) for u in uop.src]) - for p,fxn,early_reject in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]): + ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))]) + for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + self.pdict[(uop.op, None)]: if not early_reject.issubset(ler): continue if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match return None @@ -727,8 +729,8 @@ class TrackedPattenMatcher(PatternMatcher): def rewrite(self, uop:UOp) -> Optional[UOp]: ret = None - ler = set([(u.op, u.arg) for u in uop.src] + [(u.op, None) for u in uop.src]) - for p,fxn,early_reject in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]): + ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))]) + for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + self.pdict[(uop.op, None)]: st = time.perf_counter() if not early_reject.issubset(ler): match_stats[p][2] += time.perf_counter()-st @@ -765,7 +767,7 @@ class RewriteContext: self.replace: Dict[UOp, UOp] = {} def rewrite(self, n:UOp) -> UOp: if rn := self.replace.get(n): return rn - replace_source = (n.op, n.dtype, new_src:=tuple(self.rewrite(y) for y in n.src), n.arg) + replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg) if found := self.nodes.get(replace_source): self.replace[n] = found else: x = UOp(*replace_source) if new_src != n.src else n