mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
@@ -1,7 +1,7 @@
|
|||||||
import argparse, os, hashlib
|
import argparse, os, hashlib
|
||||||
from tinygrad.helpers import getenv, DEBUG, round_up, Timing, tqdm, fetch
|
from tinygrad.helpers import getenv, DEBUG, round_up, Timing, tqdm, fetch
|
||||||
from extra.hevc.hevc import parse_hevc_file_headers, untile_nv12, to_bgr, nv_gpu
|
from extra.hevc.hevc import parse_hevc_file_headers, untile_nv12, to_bgr, nv_gpu
|
||||||
from tinygrad import Tensor, dtypes, Device, Variable
|
from tinygrad import Tensor, dtypes, Device, Variable, TinyJit
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -39,15 +39,24 @@ if __name__ == "__main__":
|
|||||||
v_sz = Variable("sz", 0, hevc_tensor.numel())
|
v_sz = Variable("sz", 0, hevc_tensor.numel())
|
||||||
v_i = Variable("i", 0, len(frame_info)-1)
|
v_i = Variable("i", 0, len(frame_info)-1)
|
||||||
|
|
||||||
history = []
|
@TinyJit
|
||||||
|
def decode_jit(pos:Variable, src:Tensor, data:Tensor, *hist:Tensor):
|
||||||
|
return src.decode_hevc_frame(pos, out_image_size, data, hist).realize()
|
||||||
|
|
||||||
|
# warm up
|
||||||
|
history = [Tensor.empty(*out_image_size, dtype=dtypes.uint8, device="NV") for _ in range(max_hist)]
|
||||||
|
for i in range(3):
|
||||||
|
hevc_frame = hevc_tensor.shrink((((bound_offset:=v_offset.bind(frame_info[0][0])), bound_offset+v_sz.bind(frame_info[0][1])),))
|
||||||
|
decode_jit(v_pos.bind(0), hevc_frame, opaque_nv[v_i.bind(0)], *history)
|
||||||
|
|
||||||
out_images = []
|
out_images = []
|
||||||
with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")):
|
with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")):
|
||||||
for i, (offset, sz, frame_pos, history_sz, is_hist) in enumerate(frame_info):
|
for i, (offset, sz, frame_pos, history_sz, is_hist) in enumerate(frame_info):
|
||||||
history = history[-history_sz:] if history_sz > 0 else []
|
history = history[-max_hist:] if max_hist > 0 else []
|
||||||
# TODO: this shrink should work as a slice
|
# TODO: this shrink should work as a slice
|
||||||
hevc_frame = hevc_tensor.shrink((((bound_offset:=v_offset.bind(offset)), bound_offset+v_sz.bind(sz)),))
|
hevc_frame = hevc_tensor.shrink((((bound_offset:=v_offset.bind(offset)), bound_offset+v_sz.bind(sz)),))
|
||||||
# TODO: can this go in the JIT?
|
|
||||||
outimg = hevc_frame.decode_hevc_frame(v_pos.bind(frame_pos), out_image_size, opaque_nv[v_i.bind(i)], history).realize()
|
outimg = decode_jit(v_pos.bind(frame_pos), hevc_frame, opaque_nv[v_i.bind(i)], *history).clone()
|
||||||
out_images.append(outimg)
|
out_images.append(outimg)
|
||||||
if is_hist: history.append(outimg)
|
if is_hist: history.append(outimg)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv,
|
|||||||
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||||
from tinygrad.dtype import DType
|
from tinygrad.dtype import DType
|
||||||
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
|
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
|
||||||
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
|
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
||||||
from tinygrad.engine.memory import _internal_memory_planner
|
from tinygrad.engine.memory import _internal_memory_planner
|
||||||
from tinygrad.nn.state import get_parameters
|
from tinygrad.nn.state import get_parameters
|
||||||
from tinygrad.schedule.rangeify import mop_cleanup
|
from tinygrad.schedule.rangeify import mop_cleanup
|
||||||
@@ -143,7 +143,7 @@ class MultiGraphRunner(GraphRunner):
|
|||||||
|
|
||||||
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
|
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
|
||||||
if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
|
if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
|
||||||
if isinstance(ei.prg, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
|
if isinstance(ei.prg, (BufferCopy, BufferXfer, EncDec)): return [cast(Buffer, ei.bufs[0])]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
|
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
|
||||||
@@ -318,7 +318,7 @@ class TinyJit(Generic[ReturnType]):
|
|||||||
|
|
||||||
# memory planning (optional)
|
# memory planning (optional)
|
||||||
# Exclude buffers involved in transfer ops to preserve parallelism.
|
# Exclude buffers involved in transfer ops to preserve parallelism.
|
||||||
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy)) for b in ji.bufs}
|
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs}
|
||||||
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
|
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
|
||||||
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
|
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
|
||||||
item.metadata, item.fixedvars) for item in jit_cache]
|
item.metadata, item.fixedvars) for item in jit_cache]
|
||||||
|
|||||||
@@ -308,7 +308,7 @@ class NVAllocator(HCQAllocator['NVDevice']):
|
|||||||
def _map(self, buf:HCQBuffer): return self.dev.iface.map(buf._base if buf._base is not None else buf)
|
def _map(self, buf:HCQBuffer): return self.dev.iface.map(buf._base if buf._base is not None else buf)
|
||||||
|
|
||||||
def _encode_decode(self, bufout:HCQBuffer, bufin:HCQBuffer, desc_buf:HCQBuffer, hist:list[HCQBuffer], shape:tuple[int,...], frame_pos:int):
|
def _encode_decode(self, bufout:HCQBuffer, bufin:HCQBuffer, desc_buf:HCQBuffer, hist:list[HCQBuffer], shape:tuple[int,...], frame_pos:int):
|
||||||
assert all(h.va_addr % 0x100 == 0 for h in hist + [bufin, bufout]), "all buffers must be 0x100 aligned"
|
assert all(h.va_addr % 0x100 == 0 for h in hist + [bufin, bufout, desc_buf]), "all buffers must be 0x100 aligned"
|
||||||
|
|
||||||
h, w = ((2 * shape[0]) // 3 if shape[0] % 3 == 0 else (2 * shape[0] - 1) // 3), shape[1]
|
h, w = ((2 * shape[0]) // 3 if shape[0] % 3 == 0 else (2 * shape[0] - 1) // 3), shape[1]
|
||||||
self.dev._ensure_has_vid_hw(w, h)
|
self.dev._ensure_has_vid_hw(w, h)
|
||||||
|
|||||||
Reference in New Issue
Block a user