From cdc48da9cdf63a11d8b28ebb47b06895fae97dcc Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:01:02 +0300 Subject: [PATCH] hevc: assert and speed (#15122) * hevc: assert and speed * simpler --- .github/workflows/benchmark.yml | 2 +- extra/hevc/decode.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 399eeae8b9..83a5262320 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -332,7 +332,7 @@ jobs: # - name: Fuzz Padded Tensor Core GEMM (PTX) # run: NV=1 NV_PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py - name: HEVC Decode Benchmark - run: VALIDATE=1 MAX_FRAMES=100 JITBEAM=1 NV=1 PYTHONPATH=. python3 extra/hevc/decode.py + run: VALIDATE=1 MAX_FRAMES=100 ASSERT_FPS=1400 JITBEAM=1 NV=1 PYTHONPATH=. python3 extra/hevc/decode.py - name: Train MNIST run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py - name: Run 10 CIFAR training steps diff --git a/extra/hevc/decode.py b/extra/hevc/decode.py index b635590b0e..70c6df9ee9 100644 --- a/extra/hevc/decode.py +++ b/extra/hevc/decode.py @@ -10,9 +10,9 @@ HEVC_ROUNDUP = getenv("DATA_ROUNDUP", 32) @functools.cache def _hevc_jitted_decoder(out_image_size:tuple[int, int], max_hist:int, inplace:bool): def hevc_decode_frame(pos:Variable, hevc_tensor:Tensor, offset:Variable, sz:Variable, opaque:Tensor, i:Variable, *hist:Tensor, outbuf:Tensor|None=None): - x = hevc_tensor[offset:offset+sz*HEVC_ROUNDUP].decode_hevc_frame(pos, out_image_size, opaque[i], hist) + x = hevc_tensor[offset:offset+sz*HEVC_ROUNDUP].decode_hevc_frame(pos, out_image_size, opaque[i], hist).realize() if outbuf is not None: outbuf.assign(x).realize() - return x.realize() + return x return TinyJit(hevc_decode_frame) def hevc_decode(hevc_tensor:Tensor, opaque:Tensor, frame_info:list, luma_h:int, luma_w:int, @@ -74,10 +74,14 @@ if __name__ == "__main__": Device.default.synchronize() # decode all frames using the iterator - with Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")): + tm = Timing("decoding whole file: ", on_exit=(lambda et: f", {len(frame_info)} frames, {len(frame_info)/(et/1e9):.2f} fps")) + with tm: images = list(hevc_decode(hevc_tensor, opaque_nv, frame_info, luma_h, luma_w, history=hist, preallocated_outputs=out_images)) Device.default.synchronize() + fps = len(frame_info)/(tm.et/1e9) + assert fps >= getenv("ASSERT_FPS", 0), f"HEVC decode too slow: {fps:.2f} fps" + # validation if getenv("VALIDATE", 0): import pickle