diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 41e3eed05a..844c88e68f 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -386,8 +386,6 @@ jobs:
     #  run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/35ff4f4577002f2685e50c8346addae33fe8da27a41dd4d6a0f14d1f4b1af81b
     - name: Test openpilot LLVM compile
       run: CPU=1 CPU_LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
-    - name: Test openpilot compile4
-      run: NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 DEBUG=2 python3 examples/openpilot/compile4.py
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
diff --git a/examples/openpilot/compile4.py b/examples/openpilot/compile4.py
deleted file mode 100644
index e67bc70d94..0000000000
--- a/examples/openpilot/compile4.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import sys
-from tinygrad import Tensor, fetch, GlobalCounters, dtypes
-from tinygrad.uop.ops import UOp
-from tinygrad.nn.onnx import OnnxRunner
-from tinygrad.schedule.rangeify import get_rangeify_map
-from tinygrad.engine.schedule import create_schedule_with_vars
-from tinygrad.engine.realize import run_schedule
-
-# NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 VIZ=1 DEBUG=2 python3 examples/openpilot/compile4.py
-
-OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
-OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
-
-if __name__ == "__main__":
-  onnx_file = fetch(OPENPILOT_MODEL)
-  run_onnx = OnnxRunner(onnx_file)
-
-  inputs = run_onnx.get_empty_input_data("npy", dtypes.float32)
-  out: Tensor = next(iter(run_onnx({k:v.to(None) for k,v in inputs.items()}).values())).to('cpu')
-  root = out.uop
-  targets = [x.uop for x in inputs.values()]
-  print(targets)
-
-  # TODO: abstract this from gradient?
-
-  # compute the target path (top down)
-  in_target_path: dict[UOp, bool] = {}
-  for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
-  independent_set = {}
-  for u in root.toposort():
-    if in_target_path[u]:
-      for s in u.src:
-        if not in_target_path[s]:
-          independent_set[s] = None
-  independent = UOp.sink(*independent_set.keys())
-  kernelized = get_rangeify_map(independent)
-  independent = independent.substitute(kernelized)
-  schedule, var_vals = create_schedule_with_vars(independent)
-  run_schedule(schedule)
-
-  print("**** real ****")
-  GlobalCounters.reset()
-  out.uop = root.substitute(kernelized)
-  out.kernelize()
-
-  # realize
-  out.realize()
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 7761a0ca95..e9451257f0 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -2160,8 +2160,8 @@ class TestCopyFolding(unittest.TestCase):
     a = Tensor.ones((4,)).to("CPU")
     b = Tensor.empty(4, device="CPU")
     add = a+b
-    add.kernelize()
     assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.src]}"
+    add.kernelize()
 
   def test_alu_before_copy(self):
     buf = Tensor.ones(1).contiguous().realize()
diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py
index 9d53ae37e4..e3b173d639 100644
--- a/test/unit/test_tensor_uop_representation.py
+++ b/test/unit/test_tensor_uop_representation.py
@@ -4,8 +4,6 @@ from tinygrad.uop.ops import UPat, Ops, UOp
 
 # NOTE: unlike before base for a realized tensor is always a BUFFER
 realized_pattern = UPat(Ops.BUFFER)
-# after realization, base tensor uops become RESHAPE(BUFFER)
-buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))
 def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
 def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)
 
@@ -57,13 +55,5 @@ class TestTensorUopRepresentation(unittest.TestCase):
     is_pattern(c, UPat(Ops.ADD))
     for s in c.uop.src: is_pattern_uop(s.base, realized_pattern)
 
-  def test_empty_buf(self):
-    a = Tensor.empty(3, 3)
-    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
-    vi = UOp.variable("i", 1, 3).bind(1)
-    a = Tensor.empty(3, vi)
-    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.SHRINK, src=(UPat(Ops.BUFFER),))),))
-    self.assertEqual(a.uop.base.buffer.size, 9)
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index 01270fe5a8..47c500694f 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -31,10 +31,10 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
   (UPat((Ops.CONTIGUOUS, Ops.FUSE)), lambda ctx: (ctx,)),
   (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
-  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
-  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.shape)) if si!=so)),)),
-  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])),)),
-  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])),)),
+  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
+  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n)), None)),
+  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
+  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.marg),)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py
index 0257adf76d..baa0c4bb5b 100644
--- a/tinygrad/schedule/indexing.py
+++ b/tinygrad/schedule/indexing.py
@@ -147,6 +147,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
   ending_ranges: dict[UOp, bool] = {}
   for x in tsink_reverse_toposort:
     if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
+    if x.dtype.scalar() == dtypes.index: continue  # TODO: why do I need this?
     ending_ranges[x] = any(ending_ranges[u] for u in consumer_map[x])
 
     # if this element has weight and it's ending a range, we (force) realize it
diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index 1065cd6d2c..6db4c24f25 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -111,9 +111,10 @@ replace_allreduce = PatternMatcher([
   # MSELECT on MSTACK is replaced with nothing
   (UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
   # move shrink before MSTACK
-  (UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), name="shrink"), mstack_early_shrink),
+  (UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), allow_any_len=True, name="shrink"), mstack_early_shrink),
   # move MSELECT before movement ops
-  (UPat(Ops.MSELECT, src=(UPat(GroupOp.Movement, src=(UPat.var("s"),), name="v"),), name="ms"), lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),))),
+  (UPat(Ops.MSELECT, src=(UPat(GroupOp.Movement, src=(UPat.var("s"),), allow_any_len=True, name="v"),), name="ms"),
+   lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),)+v.src[1:])),
 ])
 
 # ***** multi functions *****
@@ -203,11 +204,11 @@ def passthrough_multi(root:UOp, multi:UOp):
 multi_pm = PatternMatcher([
   (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
-  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
-  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
-  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
+  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), reshape_multi),
+  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), expand_multi),
+  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), pad_multi),
+  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), shrink_multi),
   (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
-  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
   (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
   (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 14926bf531..90db992bca 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -46,10 +46,11 @@ earliest_rewrites = PatternMatcher([
   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
 
   # merge adjacent RESHAPES, safe because they are not tagged
-  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, name="x"), lambda x,x2: x.replace(src=(x2.src[0],)) if x.tag is None and x2.tag is None else None),
+  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
+   lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
 
   # remove CONTIGUOUS if the BUFFER is already contiguous
-  (UPat(Ops.BUFFER).f(Ops.RESHAPE, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
+  (UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
 
   # split_reduceop
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
@@ -434,6 +435,7 @@ split_kernels = PatternMatcher([
 
 def tag_uop(ctx:list[UOp], x:UOp):
   if x.tag is not None: return None
+  if x.dtype.scalar() == dtypes.index: return None
   ctx.append(x)
   return x.replace(tag=(len(ctx)-1,))
 add_tags = PatternMatcher([
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index ca3ed70f22..fe98b4af67 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -509,37 +509,54 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   # *** uop movement ops ***
 
-  @functools.cached_property
-  def marg(self):
-    match self.op:
-      # TODO: replace these args with srcs
-      case Ops.RESHAPE | Ops.EXPAND: return tuple([ssimplify(x) for x in self.arg])
-      case Ops.PAD | Ops.SHRINK: return tuple([(ssimplify(x), ssimplify(y)) for x,y in self.arg])
-      case Ops.PERMUTE | Ops.FLIP: return self.arg
-      case _: raise RuntimeError(f"{self.op} is not a MovementOp")
-
   @property
   def base(self) -> UOp:
     if self.op in GroupOp.Movement: return self.src[0].base
     if self.op is Ops.MULTI: return self.src[0].base  # MULTI is really a VIEW
     return self
 
-  def _mop(self, op:Ops, arg, no_reshape_is_no_op:bool=False) -> UOp:
-    ret = UOp(op, self.dtype, (self,), arg)
+  # like gep, but might return an integer
+  def sgep(self, i:int) -> sint:
+    match self.op:
+      case Ops.CONST: return self.arg
+      case Ops.VCONST: return self.arg[i]
+      case Ops.VECTORIZE: return cast(sint, self.src[i].ssimplify())
+      case _: raise RuntimeError(f"no sgep on {self.op}")
+
+  @functools.cached_property
+  def marg(self):
+    match self.op:
+      case Ops.RESHAPE | Ops.EXPAND: return tuple(self.src[1].sgep(i) for i in range(self.src[1].dtype.count))
+      case Ops.PAD | Ops.SHRINK: return tuple((self.src[1].sgep(i), self.src[2].sgep(i)) for i in range(self.src[1].dtype.count))
+      case Ops.PERMUTE | Ops.FLIP: return self.arg
+      case _: raise RuntimeError(f"{self.op} is not a MovementOp")
+
+  def _mop(self, op:Ops, arg, same_shape_noop:bool=False) -> UOp:
+    match op:
+      case Ops.RESHAPE | Ops.EXPAND: src_args = [arg]
+      case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
+      case Ops.PERMUTE | Ops.FLIP: src_args = []
+      case _: raise RuntimeError(f"{op} is not a MovementOp")
+    usrcs = []
+    for arg in src_args:
+      if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
+      elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
+      else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
+    ret = UOp(op, self.dtype, (self,)+tuple(usrcs), arg if len(usrcs) == 0 else None)
     # for all movement ops, we check shape property
-    if ret.shape == self.shape and no_reshape_is_no_op: return self
+    if ret.shape == self.shape and same_shape_noop: return self
     return ret
 
   # in these four, if the shape doesn't change we can return self
-  def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=False)
-  def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=True)
-  def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, no_reshape_is_no_op=True)
-  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, no_reshape_is_no_op=True)
-  def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, no_reshape_is_no_op=True)
+  def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
+  def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
+  def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
+  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
+  def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
   # in these two, we have custom logic to check if they are a no-op
-  def permute(self, arg:tuple[int, ...]): return UOp(Ops.PERMUTE, self.dtype, (self,), arg) if arg != tuple(range(len(self.shape))) else self
-  def flip(self, arg:tuple[bool, ...]): return UOp(Ops.FLIP, self.dtype, (self,), arg) if any(arg) and len(arg) == len(self.shape) else self
+  def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
+  def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self
 
   # *** uop UNIQUE ***
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index aadacfb76e..78be792353 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -82,13 +82,13 @@ assign_spec = PatternMatcher([
 # *** this is the spec of a Tensor in UOp ***
 
 tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
-  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
-   # naturally correct
-   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
-   # "make things that can't be images not images" can change the buffer dtype
-   # this is fine as long as it's a realized buffer or const and base dtypes match.
-   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base \
-    and x.base.op in {Ops.BUFFER,Ops.ASSIGN,Ops.CONST})),
+  (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
+  (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
+  (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
+
+  # inputs to movement ops
+  (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
+  (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
 
   # Tensor variable bindings
   (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 34d230d853..e2cce51fcd 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -63,6 +63,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
   for u in (toposort:=x.toposort()):
     # always exclude DEVICE/CONST/UNIQUE
     if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
+    if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
   for u in toposort:
     if u in excluded: continue
     argst = codecs.decode(str(u.arg), "unicode_escape")