diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index f57e909171..1b14ab6565 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -93,8 +93,8 @@ class BufferXfer(BufferCopy): def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev) class EncDec(Runner): - def __init__(self, encdec:UOp, total_sz:int, device:str): - self.shape, self.pos_var = encdec.arg[0], encdec.variables()[0].expr + def __init__(self, cf:UOp, total_sz:int, device:str): + self.shape, self.pos_var = tuple(s.arg for s in cf.src if s.op is Ops.CONST), cf.variables()[0].expr name = f"enc/dec {total_sz/1e6:7.2f}M, HEVC" if total_sz >= 1e6 else f"enc/dec {total_sz:8d}, HEVC" super().__init__(colored(name, "yellow"), device, Estimates(lds=total_sz, mem=total_sz)) def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False): @@ -130,7 +130,7 @@ si_lowerer = PatternMatcher([ (UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \ if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \ else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))), - (UPat(Ops.ENCDEC, name="encdec"), lambda ctx,encdec: EncDec(encdec, ctx[0].nbytes, ctx[1].device)), + (UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)), ]) @dataclass diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index ed9ef96f9b..3a0866bc48 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -71,7 +71,7 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]: base = buf_uops[1].buffer assert isinstance(base, Buffer), "base can't be MultiBuffer" buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize) - ubufs = [b.buffer for b in buf_uops] + ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND] metadata = si.arg.metadata if any(isinstance(x, MultiBuffer) for x in ubufs): assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 56a103ea78..323cbfffab 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM, - Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC} + Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL} def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None @@ -18,8 +18,8 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None: if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp): - # don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output - if x.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx \ + # don't realize COPY/BUFFER_VIEW when they are the direct source of ASSIGN — the ASSIGN target buffer is the output + if x.op in {Ops.COPY, Ops.BUFFER_VIEW} and x in ctx \ and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD): del ctx[x] # you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce @@ -29,9 +29,9 @@ pm_generate_realize_map = PatternMatcher([ # always realize SINK src (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), # always realize - (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize), + (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN}, name="tr"), realize), # realize srcs of these - (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs), + (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs), # sometimes realize src of assign (UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src), ]) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index be48764bb4..9a4dbd39ba 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -174,7 +174,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ # ***************** # 3.5 cleanups -ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP} +ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.NOOP} # you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left def cleanup_dead_axes(b:UOp): @@ -518,12 +518,11 @@ def split_store(x:UOp) -> UOp|None: lctx = LocalAddBufferContext() ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True) - # SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops + # SINK requires all buffers on the same device, but COPY/BUFFER_VIEW are cross-device or special hardware ops if ret.op is Ops.STORE: stored = ret.src[1] elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1] else: raise RuntimeError(f"unknown kernel type {ret.op}") if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges) - elif stored.op is Ops.ENCDEC: ret = stored else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts)) kernel = ret.call(*lctx.map.values(), *lctx.vars.keys()) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5079fe383d..b2c5461051 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3174,8 +3174,10 @@ class Tensor(OpMixin): the reference frames (`ref_frames`). """ ref_frames = [x.contiguous() for x in ref_frames or []] - assert isinstance(frame_pos, Variable), "frame_pos must be a Variable" - return self.contiguous()._apply_uop(UOp.encdec, state.contiguous(), *ref_frames, extra_args=(frame_pos,), arg=(shape,)) + assert frame_pos.op is Ops.BIND, "frame_pos must be a bound Variable" + srcs = (out:=Tensor.empty(*shape, device=self.device, dtype=self.dtype), self.contiguous(), state.contiguous(), *ref_frames) + fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec") + return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)), device=self.device) # ***** functional nn ops ***** diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 2f8d618076..7831da0f29 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -91,7 +91,7 @@ class Ops(FastEnum): CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto() # buffer ops - BUFFERIZE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); ENCDEC = auto() + BUFFERIZE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); CUSTOM_FUNCTION = auto() # the core 6 movement ops! these only exist in the tensor graph RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 0d1542d722..ca582f01a6 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -229,7 +229,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return () case Ops.BUFFER: return (self.arg,) case Ops.BUFFER_VIEW: return (self.arg[0],) - case Ops.ENCDEC: return self.arg[0] + case Ops.CUSTOM_FUNCTION: return None case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]]) case Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) case Ops.PARAM: @@ -567,7 +567,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): def mstack(self, *srcs: UOp) -> UOp: return UOp(Ops.MSTACK, self.dtype, (self,)+srcs) @property def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None) - def encdec(self, *src, arg=None): return UOp(Ops.ENCDEC, self.dtype, src=(self,)+src, arg=arg) # *** uop movement ops *** @@ -903,8 +902,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return p def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None, precompile:bool=False) -> UOp: - # TODO: reenable this after ENCDEC is fixed - #assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}" + assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}" return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile)) def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]: contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs) @@ -936,7 +934,7 @@ class CallInfo: def should_resolve_call(c:UOp) -> bool: # don't resolve real kernel calls, sink or program if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False - if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY}: return False + if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.CUSTOM_FUNCTION}: return False if c.arg.precompile: return False return True @@ -1513,7 +1511,7 @@ pm_pyrender_extra = PatternMatcher([ (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d: f"UOp.new_buffer({repr(d.arg)}, {x.size}, {x.dtype}, {u.arg})"), (UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), lambda ctx,x,d: f"{ctx[x]}.copy_to_device({repr(d.arg)})"), - (UPat(Ops.ENCDEC, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.encdec({''.join([str(ctx[s])+', ' for s in x.src[1:]])}arg={x.arg!r})"), + (UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda ctx,x: f"UOp(Ops.CUSTOM_FUNCTION, {x.dtype}, src={srcs(ctx, x.src)}, arg={x.arg!r})"), (UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"), # NOTE: range has srcs sometimes after control flow (UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c: diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 460e1385b3..9e7dae0997 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -120,11 +120,10 @@ _tensor_spec = PatternMatcher([ (UPat(Ops.CONTIGUOUS, name="root", src=(UPat.var("x"),), allow_any_len=True, arg=None), lambda root,x: root.dtype == x.dtype and all(u.op is Ops.RANGE for u in root.src[1:])), - # COPY/ALLREDUCE/MULTI/ENCDEC + # COPY/ALLREDUCE/MULTI (UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype), (UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)), (UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)), - (UPat(Ops.ENCDEC, name="x"), lambda x: len(x.src) >= 2), # state + inbuffer # REDUCE_AXIS is the reduce in the tensor graph (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}), @@ -132,9 +131,10 @@ _tensor_spec = PatternMatcher([ # AFTER if things were kernelized (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True), - # allow CALL/PARAM + # allow CALL/PARAM/CUSTOM_FUNCTION (UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype), (UPat(Ops.PARAM), lambda: True), + (UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)), # ** for custom kernels ** diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index d485bb7ca2..7943a2f836 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -48,7 +48,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", - Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6", + Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6", Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}