mirror of https://github.com/tinygrad/tinygrad.git
replace mop arg with vectorized index (#12695)
* replace mop arg with vectorized index
* tests passing
* better viz
* no compile4
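What changed: movement ops (RESHAPE, EXPAND, PAD, SHRINK) no longer carry their shape/padding tuple as a Python `arg`; the tuple is encoded as extra `dtypes.index`-typed sources (a `VCONST` when every element is a concrete int, otherwise a `VECTORIZE` of index uops), so symbolic dimensions become ordinary nodes in the graph. Below is a minimal self-contained sketch of the encoding — a toy model for illustration only, not tinygrad's actual `UOp` class:

# toy model of the arg -> src encoding (illustration only, not tinygrad's real UOp)
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyUOp:
  op: str
  src: tuple = ()
  arg: object = None

def index_vec(dims):
  # all-int dims collapse to a single VCONST; symbolic dims stay child nodes
  if all(isinstance(d, int) for d in dims): return ToyUOp("VCONST", arg=tuple(dims))
  return ToyUOp("VECTORIZE", src=tuple(d if isinstance(d, ToyUOp) else ToyUOp("CONST", arg=d) for d in dims))

x = ToyUOp("BUFFER")
old_style = ToyUOp("RESHAPE", (x,), (3, 4))            # before: shape tuple in arg
new_style = ToyUOp("RESHAPE", (x, index_vec((3, 4))))  # after: index-vector src, arg=None
assert new_style.arg is None and new_style.src[1].arg == (3, 4)

PERMUTE and FLIP keep their tuple `arg`, since those are static int/bool tuples rather than (possibly symbolic) shape values.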
.github/workflows/test.yml
@@ -386,8 +386,6 @@ jobs:
       # run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/35ff4f4577002f2685e50c8346addae33fe8da27a41dd4d6a0f14d1f4b1af81b
     - name: Test openpilot LLVM compile
       run: CPU=1 CPU_LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
-    - name: Test openpilot compile4
-      run: NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 DEBUG=2 python3 examples/openpilot/compile4.py
     - name: Run process replay tests
       uses: ./.github/actions/process-replay

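(The compile4 CI step is dropped together with the example script; its full deletion follows.)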
examples/openpilot/compile4.py (deleted)
@@ -1,47 +0,0 @@
-import sys
-from tinygrad import Tensor, fetch, GlobalCounters, dtypes
-from tinygrad.uop.ops import UOp
-from tinygrad.nn.onnx import OnnxRunner
-from tinygrad.schedule.rangeify import get_rangeify_map
-from tinygrad.engine.schedule import create_schedule_with_vars
-from tinygrad.engine.realize import run_schedule
-
-# NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 VIZ=1 DEBUG=2 python3 examples/openpilot/compile4.py
-
-OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
-OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
-
-if __name__ == "__main__":
-  onnx_file = fetch(OPENPILOT_MODEL)
-  run_onnx = OnnxRunner(onnx_file)
-
-  inputs = run_onnx.get_empty_input_data("npy", dtypes.float32)
-  out: Tensor = next(iter(run_onnx({k:v.to(None) for k,v in inputs.items()}).values())).to('cpu')
-  root = out.uop
-  targets = [x.uop for x in inputs.values()]
-  print(targets)
-
-  # TODO: abstract this from gradient?
-
-  # compute the target path (top down)
-  in_target_path: dict[UOp, bool] = {}
-  for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
-  independent_set = {}
-  for u in root.toposort():
-    if in_target_path[u]:
-      for s in u.src:
-        if not in_target_path[s]:
-          independent_set[s] = None
-  independent = UOp.sink(*independent_set.keys())
-  kernelized = get_rangeify_map(independent)
-  independent = independent.substitute(kernelized)
-  schedule, var_vals = create_schedule_with_vars(independent)
-  run_schedule(schedule)
-
-  print("**** real ****")
-  GlobalCounters.reset()
-  out.uop = root.substitute(kernelized)
-  out.kernelize()
-
-  # realize
-  out.realize()

@@ -2160,8 +2160,8 @@ class TestCopyFolding(unittest.TestCase):
     a = Tensor.ones((4,)).to("CPU")
     b = Tensor.empty(4, device="CPU")
     add = a+b
-    add.kernelize()
     assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.src]}"
+    add.kernelize()

   def test_alu_before_copy(self):
     buf = Tensor.ones(1).contiguous().realize()

@@ -4,8 +4,6 @@ from tinygrad.uop.ops import UPat, Ops, UOp

 # NOTE: unlike before base for a realized tensor is always a BUFFER
 realized_pattern = UPat(Ops.BUFFER)
-# after realization, base tensor uops become RESHAPE(BUFFER)
-buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))
 def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
 def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)

@@ -57,13 +55,5 @@ class TestTensorUopRepresentation(unittest.TestCase):
     is_pattern(c, UPat(Ops.ADD))
     for s in c.uop.src: is_pattern_uop(s.base, realized_pattern)

-  def test_empty_buf(self):
-    a = Tensor.empty(3, 3)
-    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
-    vi = UOp.variable("i", 1, 3).bind(1)
-    a = Tensor.empty(3, vi)
-    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.SHRINK, src=(UPat(Ops.BUFFER),))),))
-    self.assertEqual(a.uop.base.buffer.size, 9)
-
 if __name__ == '__main__':
   unittest.main()

@@ -31,10 +31,10 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
   (UPat((Ops.CONTIGUOUS, Ops.FUSE)), lambda ctx: (ctx,)),
   (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
-  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
-  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.shape)) if si!=so)),)),
-  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])),)),
-  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])),)),
+  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
+  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n)), None)),
+  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
+  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.marg),)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),

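Note: each pm_gradient rule returns one gradient per src of the matched uop. RESHAPE and EXPAND now carry an extra index-typed shape src, and PAD/SHRINK two, so the rules grow trailing `None`s for these non-differentiable sources. The EXPAND rule's axis computation is unchanged; a quick standalone check of the expression it uses:

# the backward of EXPAND sums over the dims that were broadcast
src_shape, out_shape = (1, 4, 1), (3, 4, 5)
axes = tuple(i for i, (s, n) in enumerate(zip(src_shape, out_shape)) if s != n)
assert axes == (0, 2)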
@@ -147,6 +147,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
   ending_ranges: dict[UOp, bool] = {}
   for x in tsink_reverse_toposort:
     if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
+    if x.dtype.scalar() == dtypes.index: continue  # TODO: why do I need this?
     ending_ranges[x] = any(ending_ranges[u] for u in consumer_map[x])

     # if this element has weight and it's ending a range, we (force) realize it

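Note: the new index-dtype guard presumably keeps the encoded shape vectors, which are now real nodes in the graph, out of the range-ending/realization bookkeeping; the inline TODO marks the exact reason as still unclear.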
@@ -111,9 +111,10 @@ replace_allreduce = PatternMatcher([
   # MSELECT on MSTACK is replaced with nothing
   (UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
   # move shrink before MSTACK
-  (UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), name="shrink"), mstack_early_shrink),
+  (UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), allow_any_len=True, name="shrink"), mstack_early_shrink),
   # move MSELECT before movement ops
-  (UPat(Ops.MSELECT, src=(UPat(GroupOp.Movement, src=(UPat.var("s"),), name="v"),), name="ms"), lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),))),
+  (UPat(Ops.MSELECT, src=(UPat(GroupOp.Movement, src=(UPat.var("s"),), allow_any_len=True, name="v"),), name="ms"),
+   lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),)+v.src[1:])),
 ])

 # ***** multi functions *****

@@ -203,11 +204,11 @@ def passthrough_multi(root:UOp, multi:UOp):
 multi_pm = PatternMatcher([
   (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
-  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
-  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
-  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
+  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), reshape_multi),
+  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), expand_multi),
+  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), pad_multi),
+  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), UPat(), UPat()), name="root"), shrink_multi),
   (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
-  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
   (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
   (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),

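Note: the bare `UPat()` placeholders stand for the new index-typed sources, one extra src for RESHAPE/EXPAND and two for PAD/SHRINK; without them (or `allow_any_len=True` above) these patterns would no longer match now that movement ops have more than one source.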
@@ -46,10 +46,11 @@ earliest_rewrites = PatternMatcher([
   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),

   # merge adjacent RESHAPES, safe because they are not tagged
-  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, name="x"), lambda x,x2: x.replace(src=(x2.src[0],)) if x.tag is None and x2.tag is None else None),
+  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
+   lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),

   # remove CONTIGUOUS if the BUFFER is already contiguous
-  (UPat(Ops.BUFFER).f(Ops.RESHAPE, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
+  (UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),

   # split_reduceop
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),

@@ -434,6 +435,7 @@ split_kernels = PatternMatcher([

 def tag_uop(ctx:list[UOp], x:UOp):
   if x.tag is not None: return None
+  if x.dtype.scalar() == dtypes.index: return None
   ctx.append(x)
   return x.replace(tag=(len(ctx)-1,))
 add_tags = PatternMatcher([

@@ -509,37 +509,54 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

   # *** uop movement ops ***

-  @functools.cached_property
-  def marg(self):
-    match self.op:
-      # TODO: replace these args with srcs
-      case Ops.RESHAPE | Ops.EXPAND: return tuple([ssimplify(x) for x in self.arg])
-      case Ops.PAD | Ops.SHRINK: return tuple([(ssimplify(x), ssimplify(y)) for x,y in self.arg])
-      case Ops.PERMUTE | Ops.FLIP: return self.arg
-      case _: raise RuntimeError(f"{self.op} is not a MovementOp")
-
   @property
   def base(self) -> UOp:
     if self.op in GroupOp.Movement: return self.src[0].base
     if self.op is Ops.MULTI: return self.src[0].base  # MULTI is really a VIEW
     return self

-  def _mop(self, op:Ops, arg, no_reshape_is_no_op:bool=False) -> UOp:
-    ret = UOp(op, self.dtype, (self,), arg)
+  # like gep, but might return an integer
+  def sgep(self, i:int) -> sint:
+    match self.op:
+      case Ops.CONST: return self.arg
+      case Ops.VCONST: return self.arg[i]
+      case Ops.VECTORIZE: return cast(sint, self.src[i].ssimplify())
+      case _: raise RuntimeError(f"no sgep on {self.op}")
+
+  @functools.cached_property
+  def marg(self):
+    match self.op:
+      case Ops.RESHAPE | Ops.EXPAND: return tuple(self.src[1].sgep(i) for i in range(self.src[1].dtype.count))
+      case Ops.PAD | Ops.SHRINK: return tuple((self.src[1].sgep(i), self.src[2].sgep(i)) for i in range(self.src[1].dtype.count))
+      case Ops.PERMUTE | Ops.FLIP: return self.arg
+      case _: raise RuntimeError(f"{self.op} is not a MovementOp")
+
+  def _mop(self, op:Ops, arg, same_shape_noop:bool=False) -> UOp:
+    match op:
+      case Ops.RESHAPE | Ops.EXPAND: src_args = [arg]
+      case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
+      case Ops.PERMUTE | Ops.FLIP: src_args = []
+      case _: raise RuntimeError(f"{op} is not a MovementOp")
+    usrcs = []
+    for arg in src_args:
+      if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
+      elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
+      else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
+    ret = UOp(op, self.dtype, (self,)+tuple(usrcs), arg if len(usrcs) == 0 else None)
     # for all movement ops, we check shape property
-    if ret.shape == self.shape and no_reshape_is_no_op: return self
+    if ret.shape == self.shape and same_shape_noop: return self
     return ret

   # in these four, if the shape doesn't change we can return self
-  def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=False)
-  def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=True)
-  def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, no_reshape_is_no_op=True)
-  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, no_reshape_is_no_op=True)
-  def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, no_reshape_is_no_op=True)
+  def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
+  def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
+  def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
+  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
+  def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)

   # in these two, we have custom logic to check if they are a no-op
-  def permute(self, arg:tuple[int, ...]): return UOp(Ops.PERMUTE, self.dtype, (self,), arg) if arg != tuple(range(len(self.shape))) else self
-  def flip(self, arg:tuple[bool, ...]): return UOp(Ops.FLIP, self.dtype, (self,), arg) if any(arg) and len(arg) == len(self.shape) else self
+  def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
+  def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self

   # *** uop UNIQUE ***

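How the old tuple is recovered from the new sources: `sgep` pulls one (possibly symbolic) element out of an index vector, and `marg` rebuilds the tuple from `src[1]`/`src[2]`, so callers like the pm_gradient rules above keep using `ret.marg` unchanged. A self-contained toy of the round trip, under the same illustrative model as before, not the real `UOp`:

# toy sgep/marg round trip (illustration of the mechanism, not tinygrad itself)
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyUOp:
  op: str
  src: tuple = ()
  arg: object = None

def sgep(u, i):
  # like gep, but may return a plain int (mirrors the new UOp.sgep)
  match u.op:
    case "CONST": return u.arg
    case "VCONST": return u.arg[i]
    case "VECTORIZE": return u.src[i].arg if u.src[i].op == "CONST" else u.src[i]
    case _: raise RuntimeError(f"no sgep on {u.op}")

vi = ToyUOp("DEFINE_VAR", arg=("i", 1, 3))                     # a symbolic dim
shape = ToyUOp("VECTORIZE", src=(ToyUOp("CONST", arg=3), vi))  # encodes (3, vi)
reshape = ToyUOp("RESHAPE", src=(ToyUOp("BUFFER"), shape))
marg = tuple(sgep(reshape.src[1], i) for i in range(2))        # recover the tuple
assert marg == (3, vi)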
@@ -82,13 +82,13 @@ assign_spec = PatternMatcher([
 # *** this is the spec of a Tensor in UOp ***

 tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
-  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
-   # naturally correct
-   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
-   # "make things that can't be images not images" can change the buffer dtype
-   # this is fine as long as it's a realized buffer or const and base dtypes match.
-   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base \
-    and x.base.op in {Ops.BUFFER,Ops.ASSIGN,Ops.CONST})),
+  (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
+  (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
+  (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
+
+  # inputs to movement ops
+  (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
+  (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),

   # Tensor variable bindings
   (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),

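Note: the tensor-level spec now checks movement ops structurally, one pattern per arity, and whitelists the index-typed inputs they may carry: VECTORIZE/VCONST vectors plus simple ADD/MUL/IDIV index arithmetic for symbolic shapes. The old image-dtype escape hatch for movement ops is dropped in the process.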
@@ -63,6 +63,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
   for u in (toposort:=x.toposort()):
     # always exclude DEVICE/CONST/UNIQUE
     if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
+    if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
   for u in toposort:
     if u in excluded: continue
     argst = codecs.decode(str(u.arg), "unicode_escape")

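Note: this is the "better viz" item; index-valued VCONSTs (the encoded shape vectors) are now hidden from the rendered graph alongside DEVICE/CONST/UNIQUE, so the new srcs don't clutter every movement op.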