mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
minor uop speedups [pr] (#8002)
* minor uop cleaner [pr] * free uop creation speed by removing WeakValueDictionary * a lil faster * disable that test * lines * and it doesn't print non hit patterns
This commit is contained in:
4
test/external/fuzz_uops.py
vendored
4
test/external/fuzz_uops.py
vendored
@@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
import numpy as np
|
||||
from dataclasses import replace
|
||||
from typing import DefaultDict, Dict, List, Tuple
|
||||
from tinygrad.ops import END_FOR_UOP, UOp, print_uops
|
||||
from tinygrad.ops import UOp, print_uops, Ops
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
@@ -11,6 +11,8 @@ from tinygrad.ops import Variable
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
|
||||
|
||||
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
|
||||
|
||||
def fuzz_uops(uops:List[UOp]) -> List[Tuple[UOp, ...]]:
|
||||
blocks: List[List[UOp]] = [[]]
|
||||
for u in uops:
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context, Timing
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
|
||||
@@ -463,5 +463,19 @@ class TestUPatHelpers(unittest.TestCase):
|
||||
test_upat = UPat(Ops.CONST, dtypes.bool)
|
||||
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
|
||||
|
||||
class TestUopsObject(unittest.TestCase):
|
||||
# LOL, running this test breaks all instances of "4"
|
||||
"""
|
||||
@unittest.expectedFailure
|
||||
def test_immutable(self):
|
||||
const_4 = UOp.const(dtypes.int, 4)
|
||||
with self.assertRaises(Exception):
|
||||
const_4.arg = 5
|
||||
"""
|
||||
|
||||
def test_timing(self):
|
||||
with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)]
|
||||
assert len(ret) == 10000
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -113,8 +113,7 @@ class Kernel:
|
||||
ret.opts, ret.ast = self.opts, self.ast
|
||||
|
||||
# things downstream of the AST
|
||||
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
|
||||
self.reduceops, self.vars, self.bufs, self.full_buf_index
|
||||
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
|
||||
ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
|
||||
|
||||
# parameters for optimizations
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
|
||||
if TYPE_CHECKING:
|
||||
@@ -194,7 +193,6 @@ def can_pad(u:UOp, edges:Dict[UOp, UOp], visisted:Set[UOp]) -> bool:
|
||||
visisted.add(u)
|
||||
return all(can_pad(x.base, edges, visisted) for x in u.src)
|
||||
|
||||
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
|
||||
|
||||
# With True as the default, this matches the old symbolic behavior
|
||||
def resolve(x, default:bool=True):
|
||||
@@ -226,21 +224,24 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
||||
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
||||
|
||||
class UOpMetaClass(type):
|
||||
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
|
||||
ucache:Dict[Tuple, weakref.ReferenceType[UOp]] = {}
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
|
||||
return ret
|
||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
|
||||
return created
|
||||
|
||||
# NOTE: this should be frozen, but frozen is slower
|
||||
@dataclass(eq=False, slots=True)
|
||||
class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
__slots__ = ["op", "dtype", "src", "arg"]
|
||||
def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
# TODO: instant check rules here make debugging easier
|
||||
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
|
||||
op:Ops
|
||||
dtype:DType = dtypes.void
|
||||
src:Tuple[UOp, ...] = tuple()
|
||||
arg:Any = None
|
||||
def __del__(self): del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
|
||||
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
|
||||
new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
|
||||
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
|
||||
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
|
||||
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
|
||||
return UOp(*new_args)
|
||||
@functools.cached_property
|
||||
@@ -568,7 +569,7 @@ def lines(fn) -> List[str]:
|
||||
with open(fn) as f: return f.readlines()
|
||||
|
||||
class UPat(MathTrait):
|
||||
__slots__ = ["op", "dtype", "arg", "name", "src"]
|
||||
__slots__ = ("op", "dtype", "arg", "name", "src")
|
||||
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
|
||||
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
|
||||
@@ -720,15 +721,11 @@ def track_rewrites(named=False):
|
||||
return _decorator
|
||||
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
super().__init__(patterns)
|
||||
for p,_ in self.patterns:
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ret = None
|
||||
ler = {u.op for u in uop.src}
|
||||
for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
st = time.perf_counter()
|
||||
if not early_reject.issubset(ler):
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
|
||||
Reference in New Issue
Block a user