mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
@@ -1,7 +1,7 @@
|
||||
import argparse, os, hashlib
|
||||
from tinygrad.helpers import getenv, DEBUG, round_up, Timing, tqdm, fetch
|
||||
from extra.hevc.hevc import parse_hevc_file_headers, untile_nv12, to_bgr, nv_gpu
|
||||
from tinygrad import Tensor, dtypes, Device, Variable
|
||||
from tinygrad import Tensor, dtypes, Device, Variable, TinyJit
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -39,15 +39,24 @@ if __name__ == "__main__":
|
||||
v_sz = Variable("sz", 0, hevc_tensor.numel())
|
||||
v_i = Variable("i", 0, len(frame_info)-1)
|
||||
|
||||
history = []
|
||||
@TinyJit
|
||||
def decode_jit(pos:Variable, src:Tensor, data:Tensor, *hist:Tensor):
|
||||
return src.decode_hevc_frame(pos, out_image_size, data, hist).realize()
|
||||
|
||||
# warm up
|
||||
history = [Tensor.empty(*out_image_size, dtype=dtypes.uint8, device="NV") for _ in range(max_hist)]
|
||||
for i in range(3):
|
||||
hevc_frame = hevc_tensor.shrink((((bound_offset:=v_offset.bind(frame_info[0][0])), bound_offset+v_sz.bind(frame_info[0][1])),))
|
||||
decode_jit(v_pos.bind(0), hevc_frame, opaque_nv[v_i.bind(0)], *history)
|
||||
|
||||
out_images = []
|
||||
with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")):
|
||||
for i, (offset, sz, frame_pos, history_sz, is_hist) in enumerate(frame_info):
|
||||
history = history[-history_sz:] if history_sz > 0 else []
|
||||
history = history[-max_hist:] if max_hist > 0 else []
|
||||
# TODO: this shrink should work as a slice
|
||||
hevc_frame = hevc_tensor.shrink((((bound_offset:=v_offset.bind(offset)), bound_offset+v_sz.bind(sz)),))
|
||||
# TODO: can this go in the JIT?
|
||||
outimg = hevc_frame.decode_hevc_frame(v_pos.bind(frame_pos), out_image_size, opaque_nv[v_i.bind(i)], history).realize()
|
||||
|
||||
outimg = decode_jit(v_pos.bind(frame_pos), hevc_frame, opaque_nv[v_i.bind(i)], *history).clone()
|
||||
out_images.append(outimg)
|
||||
if is_hist: history.append(outimg)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv,
|
||||
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
|
||||
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.schedule.rangeify import mop_cleanup
|
||||
@@ -143,7 +143,7 @@ class MultiGraphRunner(GraphRunner):
|
||||
|
||||
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
|
||||
if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
|
||||
if isinstance(ei.prg, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
|
||||
if isinstance(ei.prg, (BufferCopy, BufferXfer, EncDec)): return [cast(Buffer, ei.bufs[0])]
|
||||
return []
|
||||
|
||||
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
|
||||
@@ -318,7 +318,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
|
||||
# memory planning (optional)
|
||||
# Exclude buffers involved in transfer ops to preserve parallelism.
|
||||
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy)) for b in ji.bufs}
|
||||
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs}
|
||||
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
|
||||
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
|
||||
item.metadata, item.fixedvars) for item in jit_cache]
|
||||
|
||||
@@ -308,7 +308,7 @@ class NVAllocator(HCQAllocator['NVDevice']):
|
||||
def _map(self, buf:HCQBuffer): return self.dev.iface.map(buf._base if buf._base is not None else buf)
|
||||
|
||||
def _encode_decode(self, bufout:HCQBuffer, bufin:HCQBuffer, desc_buf:HCQBuffer, hist:list[HCQBuffer], shape:tuple[int,...], frame_pos:int):
|
||||
assert all(h.va_addr % 0x100 == 0 for h in hist + [bufin, bufout]), "all buffers must be 0x100 aligned"
|
||||
assert all(h.va_addr % 0x100 == 0 for h in hist + [bufin, bufout, desc_buf]), "all buffers must be 0x100 aligned"
|
||||
|
||||
h, w = ((2 * shape[0]) // 3 if shape[0] % 3 == 0 else (2 * shape[0] - 1) // 3), shape[1]
|
||||
self.dev._ensure_has_vid_hw(w, h)
|
||||
|
||||
Reference in New Issue
Block a user