Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2024-12-03 05:55:35 -05:00
9 changed files with 54 additions and 35 deletions

View File

@@ -390,6 +390,11 @@ jobs:
test/test_dtype_alu.py test/test_conv.py test/test_conv_shapetracker.py test/test_nn.py test/test_ops.py test/test_optim.py \
test/test_jit.py test/test_randomness.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_uops_stats.py test/test_uops.py \
--durations=20
- name: Run process replay tests
run: |
export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
export COMMIT_MESSAGE=$(git show -s --format=%B ${{ github.event.pull_request.head.sha }})
cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
testmetal:
name: Metal Tests

View File

@@ -3,7 +3,6 @@ from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.engine.search import bufs_from_lin
from tinygrad.helpers import Timing
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -38,8 +37,9 @@ bufs = bufs_from_lin(k)
prg = CompiledRunner(k.to_program())
with Timing("run "):
prg(bufs, var_vals={}, wait=True)
for i in range(10):
speed = prg(bufs, var_vals={}, wait=True)
print(f"kernel time: {speed*1e3:.2f} ms")
# on M1 Max
# 11ms before block 9b0859d71780fef5cf3831e317f74e53f2483229

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
import numpy as np
from dataclasses import replace
from typing import DefaultDict, Dict, List, Tuple
from tinygrad.ops import END_FOR_UOP, UOp, print_uops
from tinygrad.ops import UOp, print_uops, Ops
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import DEBUG, colored
@@ -11,6 +11,8 @@ from tinygrad.ops import Variable
from tinygrad.tensor import _to_np_dtype
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
def fuzz_uops(uops:List[UOp]) -> List[Tuple[UOp, ...]]:
blocks: List[List[UOp]] = [[]]
for u in uops:

View File

@@ -3,7 +3,7 @@ import unittest, math
import numpy as np
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.helpers import CI, DEBUG, getenv, Context, Timing
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
@@ -463,5 +463,19 @@ class TestUPatHelpers(unittest.TestCase):
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
class TestUopsObject(unittest.TestCase):
# Tests on UOp objects themselves: one disabled mutation check and one construction timing check.
# LOL, running this test breaks all instances of "4"
# NOTE(review): presumably because the mutated const object is shared with every other user of the same UOp -- confirm.
# The immutability test below is wrapped in a triple-quoted string, so it is never defined as a method and never runs.
"""
@unittest.expectedFailure
def test_immutable(self):
const_4 = UOp.const(dtypes.int, 4)
with self.assertRaises(Exception):
const_4.arg = 5
"""
def test_timing(self):
# Build 10000 CONST UOps with distinct int args (10000000+i) inside a Timing context, then check that all 10000 were created.
with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)]
assert len(ret) == 10000
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -113,8 +113,7 @@ class Kernel:
ret.opts, ret.ast = self.opts, self.ast
# things downstream of the AST
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
self.reduceops, self.vars, self.bufs, self.full_buf_index
ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
# parameters for optimizations

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
from typing import List, Dict, Tuple, Optional
import collections
from dataclasses import dataclass
@@ -19,21 +20,23 @@ class BasicBlock:
ctx: Tuple[UOp, ...]
lst: Tuple[UOp, ...]
end: Optional[UOp] = None
def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
def __repr__(self):
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
def append_to_block(ctx, x:UOp):
def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]], x:UOp):
block_ctxs, children = ctx
new_srcs: List[UOp] = []
to_append: List[UOp] = []
new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {}
in_this_block = set(x.arg.lst)
bb: BasicBlock = x.arg
in_this_block = set(bb.lst)
for u in x.src:
if u.op in DONT_PLACE_IN_BLOCK or len([y for y in children[u] if y not in in_this_block]) > 0:
# if it's a fork or not placed, we don't place it
new_srcs.append(u)
elif (block_ctx:=block_ctxs[u]) == x.arg.ctx:
elif (block_ctx:=block_ctxs[u]) == bb.ctx:
# if it's the same context, we place the UOp in this block and append the parents to its srcs
new_srcs += list(u.src)
to_append.append(u)
@@ -46,12 +49,12 @@ def append_to_block(ctx, x:UOp):
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, tuple(lst)))
lrng = list(rng)
for r in rng[::-1]:
if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
if r not in bb.ctx and r.op is not Ops.BLOCKSTART:
lrng.remove(r)
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
new_srcs.append(new_block)
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(bb.ctx, tuple(to_append)+bb.lst))
make_basic_blocks = PatternMatcher([
(UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),

View File

@@ -38,7 +38,6 @@ class LazyBuffer(MathTrait):
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
else:
self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
self.buffer.ref(1)
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
self.forced_realize = False
else:

View File

@@ -78,6 +78,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
if op is not None:
buf.buffer.ref(1)
ctx.lazybufs[ubuf] = buf
ctx.allbufs[ubuf] = ret
for x in op.src:

View File

@@ -1,10 +1,9 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from collections import defaultdict
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
if TYPE_CHECKING:
@@ -194,7 +193,6 @@ def can_pad(u:UOp, edges:Dict[UOp, UOp], visisted:Set[UOp]) -> bool:
visisted.add(u)
return all(can_pad(x.base, edges, visisted) for x in u.src)
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
# With True as the default, this matches the old symbolic behavior
def resolve(x, default:bool=True):
@@ -226,21 +224,24 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
class UOpMetaClass(type):
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
ucache:Dict[Tuple, weakref.ReferenceType[UOp]] = {}
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
return ret
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
return created
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(MathTrait, metaclass=UOpMetaClass):
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
op:Ops
dtype:DType = dtypes.void
src:Tuple[UOp, ...] = tuple()
arg:Any = None
def __del__(self): del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
def replace(self, **kwargs) -> UOp:
for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
return UOp(*new_args)
@functools.cached_property
@@ -318,7 +319,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
@@ -368,8 +369,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
return UOp.const_with_shape(self.dtype, 0, new_st.shape)
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return self.const_like(0)
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
@@ -568,7 +568,7 @@ def lines(fn) -> List[str]:
with open(fn) as f: return f.readlines()
class UPat(MathTrait):
__slots__ = ["op", "dtype", "arg", "name", "src"]
__slots__ = ("op", "dtype", "arg", "name", "src")
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
@@ -720,15 +720,11 @@ def track_rewrites(named=False):
return _decorator
class TrackedPatternMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
super().__init__(patterns)
for p,_ in self.patterns:
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ret = None
ler = {u.op for u in uop.src}
for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
st = time.perf_counter()
if not early_reject.issubset(ler):
match_stats[p][2] += time.perf_counter()-st