mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@@ -390,6 +390,11 @@ jobs:
|
||||
test/test_dtype_alu.py test/test_conv.py test/test_conv_shapetracker.py test/test_nn.py test/test_ops.py test/test_optim.py \
|
||||
test/test_jit.py test/test_randomness.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_uops_stats.py test/test_uops.py \
|
||||
--durations=20
|
||||
- name: Run process replay tests
|
||||
run: |
|
||||
export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
|
||||
export COMMIT_MESSAGE=$(git show -s --format=%B ${{ github.event.pull_request.head.sha }})
|
||||
cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
|
||||
|
||||
testmetal:
|
||||
name: Metal Tests
|
||||
|
||||
@@ -3,7 +3,6 @@ from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.engine.search import bufs_from_lin
|
||||
from tinygrad.helpers import Timing
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
@@ -38,8 +37,9 @@ bufs = bufs_from_lin(k)
|
||||
|
||||
prg = CompiledRunner(k.to_program())
|
||||
|
||||
with Timing("run "):
|
||||
prg(bufs, var_vals={}, wait=True)
|
||||
for i in range(10):
|
||||
speed = prg(bufs, var_vals={}, wait=True)
|
||||
print(f"kernel time: {speed*1e3:.2f} ms")
|
||||
|
||||
# on M1 Max
|
||||
# 11ms before block 9b0859d71780fef5cf3831e317f74e53f2483229
|
||||
|
||||
4
test/external/fuzz_uops.py
vendored
4
test/external/fuzz_uops.py
vendored
@@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
import numpy as np
|
||||
from dataclasses import replace
|
||||
from typing import DefaultDict, Dict, List, Tuple
|
||||
from tinygrad.ops import END_FOR_UOP, UOp, print_uops
|
||||
from tinygrad.ops import UOp, print_uops, Ops
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
@@ -11,6 +11,8 @@ from tinygrad.ops import Variable
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
|
||||
|
||||
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
|
||||
|
||||
def fuzz_uops(uops:List[UOp]) -> List[Tuple[UOp, ...]]:
|
||||
blocks: List[List[UOp]] = [[]]
|
||||
for u in uops:
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context, Timing
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
|
||||
@@ -463,5 +463,19 @@ class TestUPatHelpers(unittest.TestCase):
|
||||
test_upat = UPat(Ops.CONST, dtypes.bool)
|
||||
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
|
||||
|
||||
class TestUopsObject(unittest.TestCase):
|
||||
# LOL, running this test breaks all instances of "4"
|
||||
"""
|
||||
@unittest.expectedFailure
|
||||
def test_immutable(self):
|
||||
const_4 = UOp.const(dtypes.int, 4)
|
||||
with self.assertRaises(Exception):
|
||||
const_4.arg = 5
|
||||
"""
|
||||
|
||||
def test_timing(self):
|
||||
with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)]
|
||||
assert len(ret) == 10000
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -113,8 +113,7 @@ class Kernel:
|
||||
ret.opts, ret.ast = self.opts, self.ast
|
||||
|
||||
# things downstream of the AST
|
||||
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
|
||||
self.reduceops, self.vars, self.bufs, self.full_buf_index
|
||||
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
|
||||
ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
|
||||
|
||||
# parameters for optimizations
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
@@ -19,21 +20,23 @@ class BasicBlock:
|
||||
ctx: Tuple[UOp, ...]
|
||||
lst: Tuple[UOp, ...]
|
||||
end: Optional[UOp] = None
|
||||
def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
|
||||
def __repr__(self):
|
||||
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
|
||||
f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
|
||||
|
||||
def append_to_block(ctx, x:UOp):
|
||||
def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]], x:UOp):
|
||||
block_ctxs, children = ctx
|
||||
new_srcs: List[UOp] = []
|
||||
to_append: List[UOp] = []
|
||||
new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {}
|
||||
in_this_block = set(x.arg.lst)
|
||||
bb: BasicBlock = x.arg
|
||||
in_this_block = set(bb.lst)
|
||||
for u in x.src:
|
||||
if u.op in DONT_PLACE_IN_BLOCK or len([y for y in children[u] if y not in in_this_block]) > 0:
|
||||
# if it's a fork or not placed, we don't place it
|
||||
new_srcs.append(u)
|
||||
elif (block_ctx:=block_ctxs[u]) == x.arg.ctx:
|
||||
elif (block_ctx:=block_ctxs[u]) == bb.ctx:
|
||||
# if it's the same context, we place the UOp in this block and append the parents to it's srcs
|
||||
new_srcs += list(u.src)
|
||||
to_append.append(u)
|
||||
@@ -46,12 +49,12 @@ def append_to_block(ctx, x:UOp):
|
||||
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, tuple(lst)))
|
||||
lrng = list(rng)
|
||||
for r in rng[::-1]:
|
||||
if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
|
||||
if r not in bb.ctx and r.op is not Ops.BLOCKSTART:
|
||||
lrng.remove(r)
|
||||
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
|
||||
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
|
||||
new_srcs.append(new_block)
|
||||
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
|
||||
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(bb.ctx, tuple(to_append)+bb.lst))
|
||||
|
||||
make_basic_blocks = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),
|
||||
|
||||
@@ -38,7 +38,6 @@ class LazyBuffer(MathTrait):
|
||||
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
|
||||
else:
|
||||
self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
|
||||
self.buffer.ref(1)
|
||||
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
|
||||
self.forced_realize = False
|
||||
else:
|
||||
|
||||
@@ -78,6 +78,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
|
||||
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
|
||||
if op is not None:
|
||||
buf.buffer.ref(1)
|
||||
ctx.lazybufs[ubuf] = buf
|
||||
ctx.allbufs[ubuf] = ret
|
||||
for x in op.src:
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
|
||||
if TYPE_CHECKING:
|
||||
@@ -194,7 +193,6 @@ def can_pad(u:UOp, edges:Dict[UOp, UOp], visisted:Set[UOp]) -> bool:
|
||||
visisted.add(u)
|
||||
return all(can_pad(x.base, edges, visisted) for x in u.src)
|
||||
|
||||
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
|
||||
|
||||
# With True as the default, this matches the old symbolic behavior
|
||||
def resolve(x, default:bool=True):
|
||||
@@ -226,21 +224,24 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
||||
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
||||
|
||||
class UOpMetaClass(type):
|
||||
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
|
||||
ucache:Dict[Tuple, weakref.ReferenceType[UOp]] = {}
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
|
||||
return ret
|
||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
|
||||
return created
|
||||
|
||||
# NOTE: this should be frozen, but frozen is slower
|
||||
@dataclass(eq=False, slots=True)
|
||||
class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
__slots__ = ["op", "dtype", "src", "arg"]
|
||||
def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
# TODO: instant check rules here make debugging easier
|
||||
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
|
||||
op:Ops
|
||||
dtype:DType = dtypes.void
|
||||
src:Tuple[UOp, ...] = tuple()
|
||||
arg:Any = None
|
||||
def __del__(self): del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
|
||||
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
|
||||
new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
|
||||
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
|
||||
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
|
||||
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
|
||||
return UOp(*new_args)
|
||||
@functools.cached_property
|
||||
@@ -318,7 +319,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ret
|
||||
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
||||
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
|
||||
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape)
|
||||
def broadcast(self, count:int):
|
||||
assert self.dtype.count == 1
|
||||
if count == 1: return self
|
||||
@@ -368,8 +369,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
|
||||
def view(self, new_st:ShapeTracker) -> UOp:
|
||||
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
|
||||
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
|
||||
return UOp.const_with_shape(self.dtype, 0, new_st.shape)
|
||||
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return self.const_like(0)
|
||||
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
|
||||
return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
|
||||
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
|
||||
@@ -568,7 +568,7 @@ def lines(fn) -> List[str]:
|
||||
with open(fn) as f: return f.readlines()
|
||||
|
||||
class UPat(MathTrait):
|
||||
__slots__ = ["op", "dtype", "arg", "name", "src"]
|
||||
__slots__ = ("op", "dtype", "arg", "name", "src")
|
||||
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
|
||||
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
|
||||
@@ -720,15 +720,11 @@ def track_rewrites(named=False):
|
||||
return _decorator
|
||||
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
super().__init__(patterns)
|
||||
for p,_ in self.patterns:
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ret = None
|
||||
ler = {u.op for u in uop.src}
|
||||
for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
st = time.perf_counter()
|
||||
if not early_reject.issubset(ler):
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
|
||||
Reference in New Issue
Block a user