hevc: fast decoder (#14057)

This commit is contained in:
nimlgen
2026-01-08 15:20:37 +03:00
committed by GitHub
parent 309197bca5
commit f3aceaa08b

View File

@@ -39,21 +39,24 @@ if __name__ == "__main__":
v_i = Variable("i", 0, len(frame_info)-1)
@TinyJit
def decode_jit(pos:Variable, hevc_tensor:Tensor, offset:Variable, sz:Variable, opaque_nv:Tensor, i:Variable, *hist:Tensor):
return hevc_tensor[offset:offset+sz].decode_hevc_frame(pos, out_image_size, opaque_nv[i], hist).realize()
def decode_jit(pos:Variable, hevc_tensor:Tensor, offset:Variable, sz:Variable, opaque_nv:Tensor, i:Variable, outbuf:Tensor, *hist:Tensor):
x = hevc_tensor[offset:offset+sz].decode_hevc_frame(pos, out_image_size, opaque_nv[i], hist)
outbuf.assign(x).realize()
return x
# preallocate output buffers
out_images = [Tensor.zeros(*out_image_size, dtype=dtypes.uint8, device="NV").contiguous().realize() for _ in range(len(frame_info))]
# warm up
history = [Tensor.empty(*out_image_size, dtype=dtypes.uint8, device="NV") for _ in range(max_hist)]
for i in range(3):
decode_jit(v_pos.bind(0), hevc_tensor, v_offset.bind(frame_info[0][0]), v_sz.bind(frame_info[0][1]), opaque_nv, v_i.bind(0), *history)
decode_jit(v_pos.bind(0), hevc_tensor, v_offset.bind(frame_info[0][0]), v_sz.bind(frame_info[0][1]), opaque_nv, v_i.bind(0), out_images[i], *history)
out_images = []
with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")):
for i, (offset, sz, frame_pos, history_sz, is_hist) in enumerate(frame_info):
history = history[-max_hist:] if max_hist > 0 else []
outimg = decode_jit(v_pos.bind(frame_pos), hevc_tensor, v_offset.bind(offset), v_sz.bind(sz), opaque_nv, v_i.bind(i), *history).clone().realize()
out_images.append(outimg)
if is_hist: history.append(outimg)
decode_jit(v_pos.bind(frame_pos), hevc_tensor, v_offset.bind(offset), v_sz.bind(sz), opaque_nv, v_i.bind(i), out_images[i], *history)
if is_hist: history.append(out_images[i])
Device.default.synchronize()