diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index b16a8391d6..af690dd6bb 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -1,7 +1,8 @@ from extra.models.resnet import ResNet50 from tinygrad import Tensor -from tinygrad.helpers import Profiling, Timing, getenv -from tinygrad.engine.realize import lower_schedule +from tinygrad.helpers import Profiling, Timing, getenv, dedup +from tinygrad.ops import MetaOps +from tinygrad.codegen.kernel import Kernel if __name__ == "__main__": mdl = ResNet50() @@ -19,10 +20,29 @@ if __name__ == "__main__": with Timing("***** model schedule in "): sched = out.schedule() - # snakeviz /tmp/schedule.prof + asts = dedup([x.ast for x in sched if x.ast.op is MetaOps.SINK]) + uops = [] + with Profiling(PROFILE): + with Timing("***** model uops in "): + for ast in asts: + k = Kernel(ast) + k.hand_coded_optimizations() + k.linearize() + uops.append((k.name, k.uops)) + with Profiling(PROFILE, fn="/tmp/schedule.prof"): - with Timing("***** model lower in "): - eis = list(lower_schedule(sched)) + with Timing("***** model linearize in "): + for _,u in uops: u.linearize() + + #renderer = Device[Device.DEFAULT].renderer + #with Profiling(PROFILE, fn="/tmp/schedule.prof"): + # with Timing("***** model render in "): + # for n,u in uops: renderer.render(n, u) + + # snakeviz /tmp/schedule.prof + #with Profiling(PROFILE, fn="/tmp/schedule.prof"): + # with Timing("***** model lower in "): + # eis = list(lower_schedule(sched)) # random makes this slow #with Profiling(PROFILE): diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py index 8d164ca240..b601516481 100644 --- a/test/test_pattern_matcher.py +++ b/test/test_pattern_matcher.py @@ -1,7 +1,7 @@ import unittest from test.helpers import TestUOps from tinygrad.dtype import dtypes -from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps +from tinygrad.ops import BinaryOps, TernaryOps from tinygrad.codegen.uops import UOps, UOp from tinygrad.codegen.uopgraph import UOpGraph, PatternMatcher, UPat, _match @@ -47,6 +47,7 @@ class TestPatternMatcher(TestUOps): self.assertEqual(matcher.rewrite(c4), None) self.assertEqual(matcher.rewrite(c5), None) + @unittest.skip("this is not supported any more") def test_arg_set(self): matcher = PatternMatcher([(UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, {-1, 1}), UPat(UOps.CONST, 2)), name="x"), lambda x: x)]) y1 = UOp(UOps.CONST, dtypes.int, arg=1) @@ -123,14 +124,14 @@ class TestPatternMatcher(TestUOps): self.assertEqual(matcher.rewrite(c4), None) def test_allow_len(self): - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)]) + matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c3 = UOp(UOps.CONST, dtypes.float, arg=3.0) - c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG) + #c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG) c5 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) c6 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC) - self.assertEqual(matcher.rewrite(c4), c4) + #self.assertEqual(matcher.rewrite(c4), c4) self.assertEqual(matcher.rewrite(c5), None) self.assertEqual(matcher.rewrite(c6), c6) diff --git a/test/test_print_tree.py b/test/test_print_tree.py index 150ac936fd..623321ba48 100644 --- a/test/test_print_tree.py +++ b/test/test_print_tree.py @@ -3,9 +3,7 @@ import unittest from tinygrad.engine.graph import print_tree from tinygrad import Tensor, dtypes -from tinygrad.codegen.uops import UOps, UOp -from tinygrad.codegen.uopgraph import UPat -from tinygrad.ops import BinaryOps +from tinygrad.codegen.uops import UOp import sys, io @@ -43,6 +41,7 @@ ker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19) 4 ┗━┳ UOps.ALU UnaryOps.NEG\n\ 5 ┗━━ UOps.CONST 2\n' + """ x = UPat(UOp.alu(BinaryOps.ADD, UOp.var("x", dtypes.int), UOp.var("x", dtypes.int))) assert self._capture_print(lambda: print_tree(x)) == '\ 0 ━━ UOps.ALU : dtypes.int [, ] BinaryOps.ADD None\n' @@ -62,6 +61,7 @@ ker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19) 9 ┃ ┗━━ None None\n\ 10 ┗━┳ UOps.GEP 3\n\ 11 ┗━━ None None\n' + """ if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index a1c230d4ca..af0b79dc5f 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -1,8 +1,7 @@ from __future__ import annotations -from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar, TYPE_CHECKING +from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TYPE_CHECKING import functools, itertools, heapq, math from collections import defaultdict -from dataclasses import dataclass, field from tinygrad.dtype import dtypes, DType, PtrDType, ImageDType from tinygrad.shape.symbolic import Variable from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu @@ -15,15 +14,25 @@ if TYPE_CHECKING: # *** simplification logic *** -@dataclass(frozen=True) class UPat: - op: Optional[Union[UOps, Set[UOps]]] = None - arg: Any = None - src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None - name: Optional[str] = None - dtype: Optional[Union[DType, Set[DType]]] = None - allow_len: Set[int] = field(default_factory=set) - allow_any_len: bool = False + def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, + name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False): + self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,)) + self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,)) + self.arg = arg + self.src: Any = None + if isinstance(src, list): + # try all permutations if it's a list + self.src = list(itertools.permutations(src)) + elif isinstance(src, tuple): + # only one if it's a tuple + self.src = [src] + elif isinstance(src, UPat): + # repeat if it's a UPat + self.src = [itertools.repeat(src)] + allow_any_len = True + self.name: Optional[str] = name + self.allow_any_len: bool = allow_any_len @staticmethod def compile(u: UOp, name:Optional[str]=None) -> UPat: @@ -31,21 +40,15 @@ class UPat: return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype, allow_any_len=(isinstance(name, str) and 'allow_any_len' in name)) -T = TypeVar("T") -def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1 - def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: - if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return [] - if pat.arg is not None and __unmatch(pat.arg, uop.arg): return [] - if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return [] - if pat.op is not None and __unmatch(pat.op, uop.op): return [] + if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \ + (pat.dtype is not None and uop.dtype is not None and uop.dtype not in pat.dtype) or \ + (pat.arg is not None and pat.arg != uop.arg) or \ + (pat.op is not None and uop.op not in pat.op): return [] if pat.src is None: return [store] - # only one if it's a tuple - # try all permutations if it's a list - # repeat if it's a UPat res: List[Dict[str, UOp]] = [] - for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]): - if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return [] + for vp in pat.src: + if not pat.allow_any_len and len(uop.src) != len(vp): return [] new_stores = [store.copy()] for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)] res.extend(new_stores) @@ -59,10 +62,7 @@ class PatternMatcher: for p,fxn in self.patterns: if isinstance(p, UOp): p = UPat.compile(p) assert p.op is not None - if isinstance(p.op, set): - for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn)) - else: - self.pdict[(p.op, p.arg)].append((p, fxn)) + for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn)) def rewrite(self, uop:UOp) -> Optional[UOp]: for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]): diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 493280ef3d..0ce3d41922 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -227,13 +227,16 @@ class PTXRenderer(Renderer): return self.render_kernel(kernel, name, bufs, c.items()) +shiftable_consts = set([2**i for i in range(64)]) ptx_matcher = PatternMatcher([ (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]), - src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]), - lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)), + src=[UPat(UOps.CONST, name="const"), UPat(name="mul")]), + lambda root, mul, const: UOp(UOps.ALU, root.dtype, + (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None), (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]), - src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]), - lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)), + src=[UPat(UOps.CONST, name="const"), UPat(name="div")]), + lambda root, div, const: UOp(UOps.ALU, root.dtype, + (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None), (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)), (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"), lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)), @@ -255,17 +258,17 @@ ptx_matcher = PatternMatcher([ (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))), lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)), # ptr_ar (load/store) - (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), + (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))), lambda root, alu, const: UOp(root.op, root.dtype, (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64), UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])), - (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), + (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), UPat(UOps.CONST, name="const"))), lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64), UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])), - (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), + (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), UPat(name="alu"))), # no const here lambda root, alu: UOp(root.op, root.dtype, (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),