From 877a7fdd613af2986fc40d5f6362cdfbda4ba90a Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Thu, 4 Dec 2025 11:58:34 +0300
Subject: [PATCH] jit: support encdec (#13563)

* jit: support encdec

* fix
---
NOTE(review): the line breaks and indentation of this patch were
reconstructed from a whitespace-collapsed copy. Every hunk matches the
line counts in its @@ header and the diffstat below, but the position of
blank context lines (ops_nv.py hunk, last jit.py hunk) and the wrap point
of the `jit_cache = [...]` context line are best guesses. If the patch
does not apply cleanly, regenerate it from the commit with
`git format-patch -1 877a7fdd613af2986fc40d5f6362cdfbda4ba90a`.

 extra/hevc/decode.py       | 19 ++++++++++++++-----
 tinygrad/engine/jit.py     |  6 +++---
 tinygrad/runtime/ops_nv.py |  2 +-
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/extra/hevc/decode.py b/extra/hevc/decode.py
index 663ba92862..a26ae520a2 100644
--- a/extra/hevc/decode.py
+++ b/extra/hevc/decode.py
@@ -1,7 +1,7 @@
 import argparse, os, hashlib
 from tinygrad.helpers import getenv, DEBUG, round_up, Timing, tqdm, fetch
 from extra.hevc.hevc import parse_hevc_file_headers, untile_nv12, to_bgr, nv_gpu
-from tinygrad import Tensor, dtypes, Device, Variable
+from tinygrad import Tensor, dtypes, Device, Variable, TinyJit
 
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
@@ -39,15 +39,24 @@ if __name__ == "__main__":
   v_sz = Variable("sz", 0, hevc_tensor.numel())
   v_i = Variable("i", 0, len(frame_info)-1)
 
-  history = []
+  @TinyJit
+  def decode_jit(pos:Variable, src:Tensor, data:Tensor, *hist:Tensor):
+    return src.decode_hevc_frame(pos, out_image_size, data, hist).realize()
+
+  # warm up
+  history = [Tensor.empty(*out_image_size, dtype=dtypes.uint8, device="NV") for _ in range(max_hist)]
+  for i in range(3):
+    hevc_frame = hevc_tensor.shrink((((bound_offset:=v_offset.bind(frame_info[0][0])), bound_offset+v_sz.bind(frame_info[0][1])),))
+    decode_jit(v_pos.bind(0), hevc_frame, opaque_nv[v_i.bind(0)], *history)
+
   out_images = []
   with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")):
     for i, (offset, sz, frame_pos, history_sz, is_hist) in enumerate(frame_info):
-      history = history[-history_sz:] if history_sz > 0 else []
+      history = history[-max_hist:] if max_hist > 0 else []
       # TODO: this shrink should work as a slice
       hevc_frame = hevc_tensor.shrink((((bound_offset:=v_offset.bind(offset)), bound_offset+v_sz.bind(sz)),))
-      # TODO: can this go in the JIT?
-      outimg = hevc_frame.decode_hevc_frame(v_pos.bind(frame_pos), out_image_size, opaque_nv[v_i.bind(i)], history).realize()
+
+      outimg = decode_jit(v_pos.bind(frame_pos), hevc_frame, opaque_nv[v_i.bind(i)], *history).clone()
       out_images.append(outimg)
       if is_hist: history.append(outimg)
 
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 834a401d0a..f728eb7559 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -5,7 +5,7 @@ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv,
 from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
 from tinygrad.dtype import DType
 from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
-from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
+from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
 from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
 from tinygrad.schedule.rangeify import mop_cleanup
@@ -143,7 +143,7 @@ class MultiGraphRunner(GraphRunner):
 
 def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
   if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
-  if isinstance(ei.prg, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
+  if isinstance(ei.prg, (BufferCopy, BufferXfer, EncDec)): return [cast(Buffer, ei.bufs[0])]
   return []
 
 def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
@@ -318,7 +318,7 @@ class TinyJit(Generic[ReturnType]):
 
       # memory planning (optional)
       # Exclude buffers involved in transfer ops to preserve parallelism.
-      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy)) for b in ji.bufs}
+      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs}
       assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
       jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
                             item.metadata, item.fixedvars) for item in jit_cache]
diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py
index 7ec7594665..4604537129 100644
--- a/tinygrad/runtime/ops_nv.py
+++ b/tinygrad/runtime/ops_nv.py
@@ -308,7 +308,7 @@ class NVAllocator(HCQAllocator['NVDevice']):
   def _map(self, buf:HCQBuffer): return self.dev.iface.map(buf._base if buf._base is not None else buf)
 
   def _encode_decode(self, bufout:HCQBuffer, bufin:HCQBuffer, desc_buf:HCQBuffer, hist:list[HCQBuffer], shape:tuple[int,...], frame_pos:int):
-    assert all(h.va_addr % 0x100 == 0 for h in hist + [bufin, bufout]), "all buffers must be 0x100 aligned"
+    assert all(h.va_addr % 0x100 == 0 for h in hist + [bufin, bufout, desc_buf]), "all buffers must be 0x100 aligned"
     h, w = ((2 * shape[0]) // 3 if shape[0] % 3 == 0 else (2 * shape[0] - 1) // 3), shape[1]
     self.dev._ensure_has_vid_hw(w, h)
 