mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -93,8 +93,8 @@ class BufferXfer(BufferCopy):
|
||||
def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)
|
||||
|
||||
class EncDec(Runner):
|
||||
def __init__(self, encdec:UOp, total_sz:int, device:str):
|
||||
self.shape, self.pos_var = encdec.arg[0], encdec.variables()[0].expr
|
||||
def __init__(self, cf:UOp, total_sz:int, device:str):
|
||||
self.shape, self.pos_var = tuple(s.arg for s in cf.src if s.op is Ops.CONST), cf.variables()[0].expr
|
||||
name = f"enc/dec {total_sz/1e6:7.2f}M, HEVC" if total_sz >= 1e6 else f"enc/dec {total_sz:8d}, HEVC"
|
||||
super().__init__(colored(name, "yellow"), device, Estimates(lds=total_sz, mem=total_sz))
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
|
||||
@@ -130,7 +130,7 @@ si_lowerer = PatternMatcher([
|
||||
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
||||
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
|
||||
(UPat(Ops.ENCDEC, name="encdec"), lambda ctx,encdec: EncDec(encdec, ctx[0].nbytes, ctx[1].device)),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)),
|
||||
])
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -71,7 +71,7 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
|
||||
base = buf_uops[1].buffer
|
||||
assert isinstance(base, Buffer), "base can't be MultiBuffer"
|
||||
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
ubufs = [b.buffer for b in buf_uops]
|
||||
ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
|
||||
metadata = si.arg.metadata
|
||||
if any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL}
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
@@ -18,8 +18,8 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
||||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
|
||||
# don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
|
||||
if x.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx \
|
||||
# don't realize COPY/BUFFER_VIEW when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
|
||||
if x.op in {Ops.COPY, Ops.BUFFER_VIEW} and x in ctx \
|
||||
and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
|
||||
del ctx[x]
|
||||
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
|
||||
@@ -29,9 +29,9 @@ pm_generate_realize_map = PatternMatcher([
|
||||
# always realize SINK src
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||
# always realize
|
||||
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
|
||||
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN}, name="tr"), realize),
|
||||
# realize srcs of these
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
||||
# sometimes realize src of assign
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
|
||||
])
|
||||
|
||||
@@ -174,7 +174,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# *****************
|
||||
# 3.5 cleanups
|
||||
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.NOOP}
|
||||
|
||||
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
|
||||
def cleanup_dead_axes(b:UOp):
|
||||
@@ -518,12 +518,11 @@ def split_store(x:UOp) -> UOp|None:
|
||||
lctx = LocalAddBufferContext()
|
||||
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)
|
||||
|
||||
# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
|
||||
# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW are cross-device or special hardware ops
|
||||
if ret.op is Ops.STORE: stored = ret.src[1]
|
||||
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
|
||||
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
||||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
|
||||
elif stored.op is Ops.ENCDEC: ret = stored
|
||||
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
||||
|
||||
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
|
||||
|
||||
@@ -3174,8 +3174,10 @@ class Tensor(OpMixin):
|
||||
the reference frames (`ref_frames`).
|
||||
"""
|
||||
ref_frames = [x.contiguous() for x in ref_frames or []]
|
||||
assert isinstance(frame_pos, Variable), "frame_pos must be a Variable"
|
||||
return self.contiguous()._apply_uop(UOp.encdec, state.contiguous(), *ref_frames, extra_args=(frame_pos,), arg=(shape,))
|
||||
assert frame_pos.op is Ops.BIND, "frame_pos must be a bound Variable"
|
||||
srcs = (out:=Tensor.empty(*shape, device=self.device, dtype=self.dtype), self.contiguous(), state.contiguous(), *ref_frames)
|
||||
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
|
||||
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)), device=self.device)
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class Ops(FastEnum):
|
||||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto()
|
||||
|
||||
# buffer ops
|
||||
BUFFERIZE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); ENCDEC = auto()
|
||||
BUFFERIZE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); CUSTOM_FUNCTION = auto()
|
||||
|
||||
# the core 6 movement ops! these only exist in the tensor graph
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()
|
||||
|
||||
@@ -229,7 +229,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return ()
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||
case Ops.ENCDEC: return self.arg[0]
|
||||
case Ops.CUSTOM_FUNCTION: return None
|
||||
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
||||
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
case Ops.PARAM:
|
||||
@@ -567,7 +567,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def mstack(self, *srcs: UOp) -> UOp: return UOp(Ops.MSTACK, self.dtype, (self,)+srcs)
|
||||
@property
|
||||
def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None)
|
||||
def encdec(self, *src, arg=None): return UOp(Ops.ENCDEC, self.dtype, src=(self,)+src, arg=arg)
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
@@ -903,8 +902,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return p
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None, precompile:bool=False) -> UOp:
|
||||
# TODO: reenable this after ENCDEC is fixed
|
||||
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
@@ -936,7 +934,7 @@ class CallInfo:
|
||||
def should_resolve_call(c:UOp) -> bool:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
|
||||
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY}: return False
|
||||
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.CUSTOM_FUNCTION}: return False
|
||||
if c.arg.precompile: return False
|
||||
return True
|
||||
|
||||
@@ -1513,7 +1511,7 @@ pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
|
||||
f"UOp.new_buffer({repr(d.arg)}, {x.size}, {x.dtype}, {u.arg})"),
|
||||
(UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), lambda ctx,x,d: f"{ctx[x]}.copy_to_device({repr(d.arg)})"),
|
||||
(UPat(Ops.ENCDEC, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.encdec({''.join([str(ctx[s])+', ' for s in x.src[1:]])}arg={x.arg!r})"),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda ctx,x: f"UOp(Ops.CUSTOM_FUNCTION, {x.dtype}, src={srcs(ctx, x.src)}, arg={x.arg!r})"),
|
||||
(UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"),
|
||||
# NOTE: range has srcs sometimes after control flow
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
|
||||
|
||||
@@ -120,11 +120,10 @@ _tensor_spec = PatternMatcher([
|
||||
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat.var("x"),), allow_any_len=True, arg=None),
|
||||
lambda root,x: root.dtype == x.dtype and all(u.op is Ops.RANGE for u in root.src[1:])),
|
||||
|
||||
# COPY/ALLREDUCE/MULTI/ENCDEC
|
||||
# COPY/ALLREDUCE/MULTI
|
||||
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype),
|
||||
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
|
||||
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
|
||||
(UPat(Ops.ENCDEC, name="x"), lambda x: len(x.src) >= 2), # state + inbuffer
|
||||
|
||||
# REDUCE_AXIS is the reduce in the tensor graph
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
|
||||
@@ -132,9 +131,10 @@ _tensor_spec = PatternMatcher([
|
||||
# AFTER if things were kernelized
|
||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
|
||||
|
||||
# allow CALL/PARAM
|
||||
# allow CALL/PARAM/CUSTOM_FUNCTION
|
||||
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
|
||||
(UPat(Ops.PARAM), lambda: True),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
|
||||
|
||||
# ** for custom kernels **
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
||||
|
||||
Reference in New Issue
Block a user