replace mop arg with vectorized index (#12695)

* replace mop arg with vectorized index

* tests passing

* better viz

* no compile4
George Hotz
2025-10-15 20:50:06 +08:00
committed by GitHub
parent 9ec4c06d7d
commit 612e3d6143
11 changed files with 61 additions and 98 deletions

View File

@@ -386,8 +386,6 @@ jobs:
       # run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/35ff4f4577002f2685e50c8346addae33fe8da27a41dd4d6a0f14d1f4b1af81b
       - name: Test openpilot LLVM compile
         run: CPU=1 CPU_LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
-      - name: Test openpilot compile4
-        run: NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 DEBUG=2 python3 examples/openpilot/compile4.py
       - name: Run process replay tests
         uses: ./.github/actions/process-replay

View File

@@ -1,47 +0,0 @@
-import sys
-from tinygrad import Tensor, fetch, GlobalCounters, dtypes
-from tinygrad.uop.ops import UOp
-from tinygrad.nn.onnx import OnnxRunner
-from tinygrad.schedule.rangeify import get_rangeify_map
-from tinygrad.engine.schedule import create_schedule_with_vars
-from tinygrad.engine.realize import run_schedule
-# NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 VIZ=1 DEBUG=2 python3 examples/openpilot/compile4.py
-OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
-OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
-if __name__ == "__main__":
-  onnx_file = fetch(OPENPILOT_MODEL)
-  run_onnx = OnnxRunner(onnx_file)
-  inputs = run_onnx.get_empty_input_data("npy", dtypes.float32)
-  out: Tensor = next(iter(run_onnx({k:v.to(None) for k,v in inputs.items()}).values())).to('cpu')
-  root = out.uop
-  targets = [x.uop for x in inputs.values()]
-  print(targets)
-  # TODO: abstract this from gradient?
-  # compute the target path (top down)
-  in_target_path: dict[UOp, bool] = {}
-  for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
-  independent_set = {}
-  for u in root.toposort():
-    if in_target_path[u]:
-      for s in u.src:
-        if not in_target_path[s]:
-          independent_set[s] = None
-  independent = UOp.sink(*independent_set.keys())
-  kernelized = get_rangeify_map(independent)
-  independent = independent.substitute(kernelized)
-  schedule, var_vals = create_schedule_with_vars(independent)
-  run_schedule(schedule)
-  print("**** real ****")
-  GlobalCounters.reset()
-  out.uop = root.substitute(kernelized)
-  out.kernelize()
-  # realize
-  out.realize()

View File

@@ -2160,8 +2160,8 @@ class TestCopyFolding(unittest.TestCase):
     a = Tensor.ones((4,)).to("CPU")
     b = Tensor.empty(4, device="CPU")
     add = a+b
-    add.kernelize()
     assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.uop.src]}"
+    add.kernelize()

   def test_alu_before_copy(self):
     buf = Tensor.ones(1).contiguous().realize()

View File

@@ -4,8 +4,6 @@ from tinygrad.uop.ops import UPat, Ops, UOp
 # NOTE: unlike before base for a realized tensor is always a BUFFER
 realized_pattern = UPat(Ops.BUFFER)
-# after realization, base tensor uops become RESHAPE(BUFFER)
-buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))
 def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
 def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)
@@ -57,13 +55,5 @@ class TestTensorUopRepresentation(unittest.TestCase):
     is_pattern(c, UPat(Ops.ADD))
     for s in c.uop.src: is_pattern_uop(s.base, realized_pattern)
-  def test_empty_buf(self):
-    a = Tensor.empty(3, 3)
-    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
-    vi = UOp.variable("i", 1, 3).bind(1)
-    a = Tensor.empty(3, vi)
-    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.SHRINK, src=(UPat(Ops.BUFFER),))),))
-    self.assertEqual(a.uop.base.buffer.size, 9)
 if __name__ == '__main__':
   unittest.main()
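
Note: the deleted test pinned the old one-src shape of RESHAPE(BUFFER). As an illustration only (this snippet is not part of the commit), the same check can tolerate the new index-vector src by allowing extra srcs:

from tinygrad import Tensor
from tinygrad.uop.ops import UPat, Ops

# an empty tensor should still present as RESHAPE(BUFFER, ...) after this commit
a = Tensor.empty(3, 3)
pat = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),), allow_any_len=True)
assert pat.match(a.uop, {}), f"{a.uop}\nis not\n{pat}"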

View File

@@ -31,10 +31,10 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
   (UPat((Ops.CONTIGUOUS, Ops.FUSE)), lambda ctx: (ctx,)),
   (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
-  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
-  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.shape)) if si!=so)),)),
-  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])),)),
-  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])),)),
+  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
+  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n)), None)),
+  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
+  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.marg),)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
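
Note: gradient rules return one entry per src, so the new index-vector srcs get None (no gradient flows into a shape or pad amount). A quick end-to-end check, assuming a tinygrad build with this commit:

from tinygrad import Tensor

x = Tensor.ones(2, 3, requires_grad=True)
x.reshape(6).sum().backward()  # gradient passes through src[0] only
print(x.grad.shape)            # (2, 3); the shape-vector src received None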

View File

@@ -147,6 +147,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
   ending_ranges: dict[UOp, bool] = {}
   for x in tsink_reverse_toposort:
     if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
+    if x.dtype.scalar() == dtypes.index: continue  # TODO: why do I need this?
     ending_ranges[x] = any(ending_ranges[u] for u in consumer_map[x])
     # if this element has weight and it's ending a range, we (force) realize it
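
Note: the skipped UOps are the index-dtype shape/pad vectors that movement ops now pull into the graph; dtype.scalar() catches both the scalar and vectorized forms. A minimal sketch of the guard:

from tinygrad import dtypes

assert dtypes.index.scalar() == dtypes.index          # a scalar index value
assert dtypes.index.vec(3).scalar() == dtypes.index   # a vectorized index, e.g. a shape src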

View File

@@ -111,9 +111,10 @@ replace_allreduce = PatternMatcher([
   # MSELECT on MSTACK is replaced with nothing
   (UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
   # move shrink before MSTACK
-  (UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), name="shrink"), mstack_early_shrink),
+  (UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), allow_any_len=True, name="shrink"), mstack_early_shrink),
   # move MSELECT before movement ops
-  (UPat(Ops.MSELECT, src=(UPat(GroupOp.Movement, src=(UPat.var("s"),), name="v"),), name="ms"), lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),))),
+  (UPat(Ops.MSELECT, src=(UPat(GroupOp.Movement, src=(UPat.var("s"),), allow_any_len=True, name="v"),), name="ms"),
+   lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),)+v.src[1:])),
 ])
# ***** multi functions *****
@@ -203,11 +204,11 @@ def passthrough_multi(root:UOp, multi:UOp):
 multi_pm = PatternMatcher([
   (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
-  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
-  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
-  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
+  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), reshape_multi),
+  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), expand_multi),
+  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), pad_multi),
+  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), shrink_multi),
   (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
-  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
   (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
   (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
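
Note: since RESHAPE/EXPAND gained one index src and PAD/SHRINK gained two, fixed single-src patterns stop matching; this hunk adds a trailing UPat() per new src (or allow_any_len=True). A hedged demo of the matching behavior, assuming this commit:

from tinygrad import Tensor
from tinygrad.uop.ops import UPat, Ops

r = Tensor.empty(2, 3).reshape(6).uop                        # RESHAPE now has two srcs
assert not UPat(Ops.RESHAPE, src=(UPat(),)).match(r, {})     # old one-src pattern misses
assert UPat(Ops.RESHAPE, src=(UPat(), UPat())).match(r, {})  # explicit placeholder matches
assert UPat(Ops.RESHAPE, src=(UPat(),), allow_any_len=True).match(r, {})  # so does allow_any_len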

View File

@@ -46,10 +46,11 @@ earliest_rewrites = PatternMatcher([
   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
   # merge adjacent RESHAPES, safe because they are not tagged
-  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, name="x"), lambda x,x2: x.replace(src=(x2.src[0],)) if x.tag is None and x2.tag is None else None),
+  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
+   lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
   # remove CONTIGUOUS if the BUFFER is already contiguous
-  (UPat(Ops.BUFFER).f(Ops.RESHAPE, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
+  (UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
   # split_reduceop
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
@@ -434,6 +435,7 @@ split_kernels = PatternMatcher([
 def tag_uop(ctx:list[UOp], x:UOp):
   if x.tag is not None: return None
+  if x.dtype.scalar() == dtypes.index: return None
   ctx.append(x)
   return x.replace(tag=(len(ctx)-1,))
 add_tags = PatternMatcher([
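
Note: when two RESHAPEs merge, only the outer shape vector (x.src[1]) survives, since a reshape-of-reshape depends only on the final shape; tag_uop likewise skips index UOps so shape vectors never receive kernel tags. A hypothetical standalone demo of the merge rule (simplified: it omits the tag check from the real rewrite):

from tinygrad import Tensor
from tinygrad.uop.ops import UPat, Ops, PatternMatcher, graph_rewrite

merge = PatternMatcher([
  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
   lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),  # keep inner input, outer shape vector
])
u = Tensor.empty(2, 3).reshape(6).reshape(3, 2).uop
print(graph_rewrite(u, merge))  # nested RESHAPEs collapse to a single RESHAPE to (3, 2)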

View File

@@ -509,37 +509,54 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   # *** uop movement ops ***
-  @functools.cached_property
-  def marg(self):
-    match self.op:
-      # TODO: replace these args with srcs
-      case Ops.RESHAPE | Ops.EXPAND: return tuple([ssimplify(x) for x in self.arg])
-      case Ops.PAD | Ops.SHRINK: return tuple([(ssimplify(x), ssimplify(y)) for x,y in self.arg])
-      case Ops.PERMUTE | Ops.FLIP: return self.arg
-      case _: raise RuntimeError(f"{self.op} is not a MovementOp")
   @property
   def base(self) -> UOp:
     if self.op in GroupOp.Movement: return self.src[0].base
     if self.op is Ops.MULTI: return self.src[0].base  # MULTI is really a VIEW
     return self
-  def _mop(self, op:Ops, arg, no_reshape_is_no_op:bool=False) -> UOp:
-    ret = UOp(op, self.dtype, (self,), arg)
+  # like gep, but might return an integer
+  def sgep(self, i:int) -> sint:
+    match self.op:
+      case Ops.CONST: return self.arg
+      case Ops.VCONST: return self.arg[i]
+      case Ops.VECTORIZE: return cast(sint, self.src[i].ssimplify())
+      case _: raise RuntimeError(f"no sgep on {self.op}")
+  @functools.cached_property
+  def marg(self):
+    match self.op:
+      case Ops.RESHAPE | Ops.EXPAND: return tuple(self.src[1].sgep(i) for i in range(self.src[1].dtype.count))
+      case Ops.PAD | Ops.SHRINK: return tuple((self.src[1].sgep(i), self.src[2].sgep(i)) for i in range(self.src[1].dtype.count))
+      case Ops.PERMUTE | Ops.FLIP: return self.arg
+      case _: raise RuntimeError(f"{self.op} is not a MovementOp")
+  def _mop(self, op:Ops, arg, same_shape_noop:bool=False) -> UOp:
+    match op:
+      case Ops.RESHAPE | Ops.EXPAND: src_args = [arg]
+      case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
+      case Ops.PERMUTE | Ops.FLIP: src_args = []
+      case _: raise RuntimeError(f"{op} is not a MovementOp")
+    usrcs = []
+    for arg in src_args:
+      if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
+      elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
+      else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
+    ret = UOp(op, self.dtype, (self,)+tuple(usrcs), arg if len(usrcs) == 0 else None)
     # for all movement ops, we check shape property
-    if ret.shape == self.shape and no_reshape_is_no_op: return self
+    if ret.shape == self.shape and same_shape_noop: return self
     return ret
   # in these four, if the shape doesn't change we can return self
-  def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=False)
-  def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=True)
-  def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, no_reshape_is_no_op=True)
-  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, no_reshape_is_no_op=True)
-  def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, no_reshape_is_no_op=True)
+  def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
+  def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
+  def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
+  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
+  def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
   # in these two, we have custom logic to check if they are a no-op
-  def permute(self, arg:tuple[int, ...]): return UOp(Ops.PERMUTE, self.dtype, (self,), arg) if arg != tuple(range(len(self.shape))) else self
-  def flip(self, arg:tuple[bool, ...]): return UOp(Ops.FLIP, self.dtype, (self,), arg) if any(arg) and len(arg) == len(self.shape) else self
+  def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
+  def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self

   # *** uop UNIQUE ***
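
Note: this is the core of the change. _mop now encodes RESHAPE/EXPAND/PAD/SHRINK arguments as index-dtype srcs (CONST/VCONST for static ints, VECTORIZE when symbolic sints are mixed in), and marg reconstructs the old tuple through sgep. A hedged look at the new encoding, assuming this commit:

from tinygrad import Tensor

u = Tensor.empty(2, 3).uop.reshape((6,))
print(u.op, u.arg)  # Ops.RESHAPE None: the tuple arg is gone for shaped movement ops
print(u.src[1])     # an index-dtype constant carrying the new shape
print(u.marg)       # (6,): rebuilt element-wise through sgep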

View File

@@ -82,13 +82,13 @@ assign_spec = PatternMatcher([
 # *** this is the spec of a Tensor in UOp ***
 tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
-  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
-   # naturally correct
-   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
-   # "make things that can't be images not images" can change the buffer dtype
-   # this is fine as long as it's a realized buffer or const and base dtypes match.
-   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base \
-    and x.base.op in {Ops.BUFFER,Ops.ASSIGN,Ops.CONST})),
+  (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
+  (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
+  (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
+  # inputs to movement ops
+  (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
+  (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
   # Tensor variable bindings
   (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
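
Note: the spec now pins the arity: RESHAPE/EXPAND carry one index vector, PAD/SHRINK carry two (the begin and end amounts, split by zip(*arg) in _mop), and PERMUTE/FLIP keep their tuple arg. A small sketch, assuming this commit:

from tinygrad import Tensor

u = Tensor.empty(4, 4).uop.pad(((1, 1), (0, 0)))
print(len(u.src))  # 3: the input plus begin vector (1, 0) and end vector (1, 0)
print(u.marg)      # ((1, 1), (0, 0)): the per-dim pairs, re-zipped by marg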

View File

@@ -63,6 +63,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
   for u in (toposort:=x.toposort()):
     # always exclude DEVICE/CONST/UNIQUE
     if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
+    if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
   for u in toposort:
     if u in excluded: continue
     argst = codecs.decode(str(u.arg), "unicode_escape")
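
Note: every movement op now drags index-dtype constant vectors into the graph, which would clutter the rendered view, so they are hidden the same way DEVICE/CONST/UNIQUE are (the "better viz" item in the commit message). A sketch of where such nodes come from, assuming this commit:

from tinygrad import Tensor, dtypes

u = Tensor.empty(4, 4).uop.shrink(((1, 3), (0, 2)))
print([s.op for s in u.src[1:]])  # likely two index-dtype VCONSTs: begins (1, 0) and ends (3, 2)
assert all(s.dtype.scalar() == dtypes.index for s in u.src[1:])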