mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
replace mop arg with vectorized index (#12695)
* replace mop arg with vectorized index
* tests passing
* better viz
* no compile4
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -386,8 +386,6 @@ jobs:
|
||||
# run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/35ff4f4577002f2685e50c8346addae33fe8da27a41dd4d6a0f14d1f4b1af81b
|
||||
- name: Test openpilot LLVM compile
|
||||
run: CPU=1 CPU_LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot compile4
|
||||
run: NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 DEBUG=2 python3 examples/openpilot/compile4.py
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import sys
from tinygrad import Tensor, fetch, GlobalCounters, dtypes
from tinygrad.uop.ops import UOp
from tinygrad.nn.onnx import OnnxRunner
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.engine.realize import run_schedule

# NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 VIZ=1 DEBUG=2 python3 examples/openpilot/compile4.py

# optional positional overrides: argv[1] is the model location, argv[2] the output path
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) >= 2 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
OUTPUT = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/openpilot.pkl"

if __name__ == "__main__":
  runner = OnnxRunner(fetch(OPENPILOT_MODEL))

  # run the model on empty inputs; only the first output is kept, moved to 'cpu'
  inputs = runner.get_empty_input_data("npy", dtypes.float32)
  model_outputs = runner({name:t.to(None) for name,t in inputs.items()})
  out: Tensor = next(iter(model_outputs.values())).to('cpu')
  root = out.uop
  targets = [t.uop for t in inputs.values()]
  print(targets)

  # TODO: abstract this from gradient?

  # compute the target path (top down)
  # a uop is on the target path if any of its sources is a target or is itself on the path.
  # this reads in_target_path[x] for the sources, so it relies on toposort() yielding sources before users
  order = root.toposort()
  in_target_path: dict[UOp, bool] = {}
  for u in order: in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)

  # sources of on-path uops that are not on the path themselves (dict used as an insertion-ordered set)
  independent_set = {s:None for u in order if in_target_path[u] for s in u.src if not in_target_path[s]}

  # schedule and run everything collected above first
  independent = UOp.sink(*independent_set)
  kernelized = get_rangeify_map(independent)
  independent = independent.substitute(kernelized)
  schedule, var_vals = create_schedule_with_vars(independent)
  run_schedule(schedule)

  print("**** real ****")
  GlobalCounters.reset()
  out.uop = root.substitute(kernelized)
  out.kernelize()

  # realize
  out.realize()
|
||||
@@ -2160,8 +2160,8 @@ class TestCopyFolding(unittest.TestCase):
|
||||
a = Tensor.ones((4,)).to("CPU")
|
||||
b = Tensor.empty(4, device="CPU")
|
||||
add = a+b
|
||||
add.kernelize()
|
||||
assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.src]}"
|
||||
add.kernelize()
|
||||
|
||||
def test_alu_before_copy(self):
|
||||
buf = Tensor.ones(1).contiguous().realize()
|
||||
|
||||
@@ -4,8 +4,6 @@ from tinygrad.uop.ops import UPat, Ops, UOp
|
||||
|
||||
# NOTE: unlike before base for a realized tensor is always a BUFFER
|
||||
realized_pattern = UPat(Ops.BUFFER)
|
||||
# after realization, base tensor uops become RESHAPE(BUFFER)
|
||||
buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))
|
||||
def is_pattern_uop(u:UOp, pat:UPat):
  """Assert that `pat` matches the uop `u`; the failure message prints both of them."""
  matches = pat.match(u, {})
  assert matches, f"{u}\nis not\n{pat}"
|
||||
def is_pattern(ten:Tensor, pat:UPat):
  """Assert that the uop behind the tensor `ten` matches `pat`."""
  uop = ten.uop
  is_pattern_uop(uop, pat)
|
||||
|
||||
@@ -57,13 +55,5 @@ class TestTensorUopRepresentation(unittest.TestCase):
|
||||
is_pattern(c, UPat(Ops.ADD))
|
||||
for s in c.uop.src: is_pattern_uop(s.base, realized_pattern)
|
||||
|
||||
def test_empty_buf(self):
  # an empty tensor with a concrete shape is a RESHAPE directly on a BUFFER
  plain = Tensor.empty(3, 3)
  is_pattern(plain, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
  # with a bound variable as a dim, a SHRINK sits between the RESHAPE and the BUFFER
  bound_i = UOp.variable("i", 1, 3).bind(1)
  symbolic = Tensor.empty(3, bound_i)
  is_pattern(symbolic, UPat(Ops.RESHAPE, src=(UPat(Ops.SHRINK, src=(UPat(Ops.BUFFER),))),))
  # the backing buffer holds 9 elements (presumably 3 * the variable's max of 3 -- TODO confirm)
  self.assertEqual(symbolic.uop.base.buffer.size, 9)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -31,10 +31,10 @@ pm_gradient = PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
|
||||
(UPat((Ops.CONTIGUOUS, Ops.FUSE)), lambda ctx: (ctx,)),
|
||||
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
|
||||
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.shape)) if si!=so)),)),
|
||||
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])),)),
|
||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])),)),
|
||||
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n)), None)),
|
||||
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
|
||||
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.marg),)),
|
||||
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
|
||||
|
||||
@@ -147,6 +147,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
ending_ranges: dict[UOp, bool] = {}
|
||||
for x in tsink_reverse_toposort:
|
||||
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
|
||||
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
|
||||
ending_ranges[x] = any(ending_ranges[u] for u in consumer_map[x])
|
||||
|
||||
# if this element has weight and it's ending a range, we (force) realize it
|
||||
|
||||
@@ -111,9 +111,10 @@ replace_allreduce = PatternMatcher([
|
||||
# MSELECT on MSTACK is replaced with nothing
|
||||
(UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
|
||||
# move shrink before MSTACK
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), name="shrink"), mstack_early_shrink),
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), allow_any_len=True, name="shrink"), mstack_early_shrink),
|
||||
# move MSELECT before movement ops
|
||||
(UPat(Ops.MSELECT, src=(UPat(GroupOp.Movement, src=(UPat.var("s"),), name="v"),), name="ms"), lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),))),
|
||||
(UPat(Ops.MSELECT, src=(UPat(GroupOp.Movement, src=(UPat.var("s"),), allow_any_len=True, name="v"),), name="ms"),
|
||||
lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),)+v.src[1:])),
|
||||
])
|
||||
|
||||
# ***** multi functions *****
|
||||
@@ -203,11 +204,11 @@ def passthrough_multi(root:UOp, multi:UOp):
|
||||
multi_pm = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
|
||||
(UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
|
||||
(UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), reshape_multi),
|
||||
(UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), expand_multi),
|
||||
(UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), pad_multi),
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), shrink_multi),
|
||||
(UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
|
||||
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
|
||||
|
||||
@@ -46,10 +46,11 @@ earliest_rewrites = PatternMatcher([
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# merge adjacent RESHAPES, safe because they are not tagged
|
||||
(UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, name="x"), lambda x,x2: x.replace(src=(x2.src[0],)) if x.tag is None and x2.tag is None else None),
|
||||
(UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
|
||||
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
|
||||
|
||||
# remove CONTIGUOUS if the BUFFER is already contiguous
|
||||
(UPat(Ops.BUFFER).f(Ops.RESHAPE, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
|
||||
# split_reduceop
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
|
||||
@@ -434,6 +435,7 @@ split_kernels = PatternMatcher([
|
||||
|
||||
def tag_uop(ctx:list[UOp], x:UOp):
  """Append `x` to `ctx` and return a copy of `x` tagged with its position in `ctx`.

  Returns None, leaving `ctx` untouched, when `x` already carries a tag or when
  its scalar dtype is `dtypes.index`.
  """
  already_tagged = x.tag is not None
  if already_tagged or x.dtype.scalar() == dtypes.index: return None
  position = len(ctx)
  ctx.append(x)
  return x.replace(tag=(position,))
|
||||
add_tags = PatternMatcher([
|
||||
|
||||
@@ -509,37 +509,54 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
@functools.cached_property
def marg(self):
  """The movement-op argument of this uop, read from `self.arg`.

  RESHAPE/EXPAND give a tuple of simplified entries, PAD/SHRINK a tuple of simplified pairs,
  PERMUTE/FLIP the arg unchanged. Raises RuntimeError for any other op.
  """
  # TODO: replace these args with srcs
  op = self.op
  if op in (Ops.RESHAPE, Ops.EXPAND): return tuple([ssimplify(d) for d in self.arg])
  if op in (Ops.PAD, Ops.SHRINK): return tuple([(ssimplify(a), ssimplify(b)) for a,b in self.arg])
  if op in (Ops.PERMUTE, Ops.FLIP): return self.arg
  raise RuntimeError(f"{self.op} is not a MovementOp")
|
||||
|
||||
@property
def base(self) -> UOp:
  """The uop reached by following `src[0]` through movement ops and MULTI; `self` if it is neither."""
  # MULTI is really a VIEW, so it is looked through the same way as a movement op
  if self.op in GroupOp.Movement or self.op is Ops.MULTI: return self.src[0].base
  return self
|
||||
|
||||
def _mop(self, op:Ops, arg, no_reshape_is_no_op:bool=False) -> UOp:
|
||||
ret = UOp(op, self.dtype, (self,), arg)
|
||||
def sgep(self, i:int) -> sint:
  """like gep, but might return an integer

  CONST returns its arg whatever `i` is, VCONST returns `arg[i]`, VECTORIZE returns the
  simplified `src[i]`. Raises RuntimeError for any other op.
  """
  op = self.op
  if op == Ops.CONST: return self.arg
  if op == Ops.VCONST: return self.arg[i]
  if op == Ops.VECTORIZE: return cast(sint, self.src[i].ssimplify())
  raise RuntimeError(f"no sgep on {self.op}")
|
||||
|
||||
@functools.cached_property
def marg(self):
  """The movement-op argument of this uop, rebuilt from its sources.

  RESHAPE/EXPAND read one element per lane of `src[1]`, PAD/SHRINK pair up the lanes of
  `src[1]` and `src[2]`, PERMUTE/FLIP return `self.arg`. Raises RuntimeError for any other op.
  """
  op, srcs = self.op, self.src
  if op in (Ops.RESHAPE, Ops.EXPAND):
    return tuple(srcs[1].sgep(i) for i in range(srcs[1].dtype.count))
  if op in (Ops.PAD, Ops.SHRINK):
    # the lane count is taken from src[1] only
    return tuple((srcs[1].sgep(i), srcs[2].sgep(i)) for i in range(srcs[1].dtype.count))
  if op in (Ops.PERMUTE, Ops.FLIP): return self.arg
  raise RuntimeError(f"{self.op} is not a MovementOp")
|
||||
|
||||
def _mop(self, op:Ops, arg, same_shape_noop:bool=False) -> UOp:
|
||||
match op:
|
||||
case Ops.RESHAPE | Ops.EXPAND: src_args = [arg]
|
||||
case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
|
||||
case Ops.PERMUTE | Ops.FLIP: src_args = []
|
||||
case _: raise RuntimeError(f"{op} is not a MovementOp")
|
||||
usrcs = []
|
||||
for arg in src_args:
|
||||
if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
|
||||
elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
|
||||
else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
|
||||
ret = UOp(op, self.dtype, (self,)+tuple(usrcs), arg if len(usrcs) == 0 else None)
|
||||
# for all movement ops, we check shape property
|
||||
if ret.shape == self.shape and no_reshape_is_no_op: return self
|
||||
if ret.shape == self.shape and same_shape_noop: return self
|
||||
return ret
|
||||
|
||||
# in these four, if the shape doesn't change we can return self
|
||||
def forced_reshape(self, arg:tuple[sint, ...]):
  """RESHAPE to `arg`, with the same-shape shortcut switched off."""
  return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=False)
def reshape(self, arg:tuple[sint, ...]):
  """RESHAPE to `arg`, with the same-shape shortcut switched on."""
  return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=True)
def expand(self, arg:tuple[sint, ...]):
  """EXPAND to `arg`, with the same-shape shortcut switched on."""
  return self._mop(Ops.EXPAND, arg, no_reshape_is_no_op=True)
def shrink(self, arg:tuple[tuple[sint, sint], ...]):
  """SHRINK by the pairs in `arg`, with the same-shape shortcut switched on."""
  return self._mop(Ops.SHRINK, arg, no_reshape_is_no_op=True)
def pad(self, arg:tuple[tuple[sint, sint], ...]):
  """PAD by the pairs in `arg`, with the same-shape shortcut switched on."""
  return self._mop(Ops.PAD, arg, no_reshape_is_no_op=True)
|
||||
def forced_reshape(self, arg:tuple[sint, ...]):
  """RESHAPE to `arg`, with the same-shape shortcut switched off."""
  return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
def reshape(self, arg:tuple[sint, ...]):
  """RESHAPE to `arg`, with the same-shape shortcut switched on."""
  return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
def expand(self, arg:tuple[sint, ...]):
  """EXPAND to `arg`, with the same-shape shortcut switched on."""
  return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
def shrink(self, arg:tuple[tuple[sint, sint], ...]):
  """SHRINK by the pairs in `arg`, with the same-shape shortcut switched on."""
  return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
def pad(self, arg:tuple[tuple[sint, sint], ...]):
  """PAD by the pairs in `arg`, with the same-shape shortcut switched on."""
  return self._mop(Ops.PAD, arg, same_shape_noop=True)
|
||||
|
||||
# in these two, we have custom logic to check if they are a no-op
|
||||
def permute(self, arg:tuple[int, ...]):
  """PERMUTE the axes by `arg`; the identity permutation returns `self`."""
  if arg == tuple(range(len(self.shape))): return self
  return UOp(Ops.PERMUTE, self.dtype, (self,), arg)
def flip(self, arg:tuple[bool, ...]):
  """FLIP the axes flagged in `arg`; returns `self` if nothing is flagged or `arg` has the wrong length."""
  if not any(arg) or len(arg) != len(self.shape): return self
  return UOp(Ops.FLIP, self.dtype, (self,), arg)
|
||||
def permute(self, arg:tuple[int, ...]):
  """PERMUTE the axes by `arg`; the identity permutation returns `self`."""
  if arg == tuple(range(len(self.shape))): return self
  return self._mop(Ops.PERMUTE, arg, same_shape_noop=False)
def flip(self, arg:tuple[bool, ...]):
  """FLIP the axes flagged in `arg`; returns `self` if nothing is flagged or `arg` has the wrong length."""
  if not any(arg) or len(arg) != len(self.shape): return self
  return self._mop(Ops.FLIP, arg, same_shape_noop=False)
|
||||
|
||||
# *** uop UNIQUE ***
|
||||
|
||||
|
||||
@@ -82,13 +82,13 @@ assign_spec = PatternMatcher([
|
||||
# *** this is the spec of a Tensor in UOp ***
|
||||
|
||||
tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
|
||||
# naturally correct
|
||||
lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
|
||||
# "make things that can't be images not images" can change the buffer dtype
|
||||
# this is fine as long as it's a realized buffer or const and base dtypes match.
|
||||
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base \
|
||||
and x.base.op in {Ops.BUFFER,Ops.ASSIGN,Ops.CONST})),
|
||||
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
|
||||
|
||||
# inputs to movement ops
|
||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
|
||||
|
||||
@@ -63,6 +63,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
for u in (toposort:=x.toposort()):
|
||||
# always exclude DEVICE/CONST/UNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
|
||||
for u in toposort:
|
||||
if u in excluded: continue
|
||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||
|
||||
Reference in New Issue
Block a user