minor uop speedups [pr] (#8002)

* minor uop cleaner [pr]

* free uop creation speed by removing WeakValueDictionary

* a little faster

* disable that test

* lines

* and it doesn't print non-hit patterns
This commit is contained in:
George Hotz
2024-12-03 17:04:48 +08:00
committed by GitHub
parent 1028b34a20
commit b8bf5b2787
4 changed files with 35 additions and 23 deletions

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
import numpy as np
from dataclasses import replace
from typing import DefaultDict, Dict, List, Tuple
from tinygrad.ops import END_FOR_UOP, UOp, print_uops
from tinygrad.ops import UOp, print_uops, Ops
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import DEBUG, colored
@@ -11,6 +11,8 @@ from tinygrad.ops import Variable
from tinygrad.tensor import _to_np_dtype
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
def fuzz_uops(uops:List[UOp]) -> List[Tuple[UOp, ...]]:
blocks: List[List[UOp]] = [[]]
for u in uops:

View File

@@ -3,7 +3,7 @@ import unittest, math
import numpy as np
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.helpers import CI, DEBUG, getenv, Context, Timing
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
@@ -463,5 +463,19 @@ class TestUPatHelpers(unittest.TestCase):
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
class TestUopsObject(unittest.TestCase):
# LOL, running this test breaks all instances of "4"
"""
@unittest.expectedFailure
def test_immutable(self):
const_4 = UOp.const(dtypes.int, 4)
with self.assertRaises(Exception):
const_4.arg = 5
"""
def test_timing(self):
with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)]
assert len(ret) == 10000
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -113,8 +113,7 @@ class Kernel:
ret.opts, ret.ast = self.opts, self.ast
# things downstream of the AST
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
self.reduceops, self.vars, self.bufs, self.full_buf_index
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
# parameters for optimizations

View File

@@ -1,10 +1,9 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from collections import defaultdict
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
if TYPE_CHECKING:
@@ -194,7 +193,6 @@ def can_pad(u:UOp, edges:Dict[UOp, UOp], visisted:Set[UOp]) -> bool:
visisted.add(u)
return all(can_pad(x.base, edges, visisted) for x in u.src)
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
# With True as the default, this matches the old symbolic behavior
def resolve(x, default:bool=True):
@@ -226,21 +224,24 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
class UOpMetaClass(type):
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
ucache:Dict[Tuple, weakref.ReferenceType[UOp]] = {}
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
return ret
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
return created
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(MathTrait, metaclass=UOpMetaClass):
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
op:Ops
dtype:DType = dtypes.void
src:Tuple[UOp, ...] = tuple()
arg:Any = None
def __del__(self): del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
def replace(self, **kwargs) -> UOp:
for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
return UOp(*new_args)
@functools.cached_property
@@ -568,7 +569,7 @@ def lines(fn) -> List[str]:
with open(fn) as f: return f.readlines()
class UPat(MathTrait):
__slots__ = ["op", "dtype", "arg", "name", "src"]
__slots__ = ("op", "dtype", "arg", "name", "src")
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
@@ -720,15 +721,11 @@ def track_rewrites(named=False):
return _decorator
class TrackedPatternMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
super().__init__(patterns)
for p,_ in self.patterns:
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ret = None
ler = {u.op for u in uop.src}
for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
st = time.perf_counter()
if not early_reject.issubset(ler):
match_stats[p][2] += time.perf_counter()-st