Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
assert benchmark times (#12042)
* assert jitted times in openpilot
* better error
* better error
* add ASSERT_MIN_STEP_TIME to more models
* t is step_times
* update benchmark times
* update times
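All of the changes below follow one pattern: each benchmark loop appends its per-step wall time (in milliseconds) to a step_times list, and when the ASSERT_MIN_STEP_TIME environment variable is set (as it now is in benchmark.yml), the script asserts that the fastest step beat that bound, so a speed regression fails CI instead of only showing up in the logs. A minimal self-contained sketch of the pattern, assuming a placeholder benchmark_step and a simplified getenv stand-in that are illustrative and not part of the diff:

import os, time

def getenv(key, default=0):
  # simplified stand-in for tinygrad.helpers.getenv: read an int-valued env var
  return type(default)(os.getenv(key, default))

def benchmark_step():
  # placeholder for the real work (a jitted model step, a training step, ...)
  time.sleep(0.002)

def run_benchmark(n_steps=10):
  step_times = []
  for _ in range(n_steps):
    st = time.perf_counter_ns()
    benchmark_step()
    step_times.append((time.perf_counter_ns() - st)*1e-6)  # milliseconds, as in the diff
  # only enforced when ASSERT_MIN_STEP_TIME is set, e.g. ASSERT_MIN_STEP_TIME=500 in benchmark.yml
  if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
    min_time = min(step_times)
    assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
  return step_times

if __name__ == "__main__":
  run_benchmark()

Checking the minimum rather than the mean keeps warm-up and JIT-compile steps from tripping the assertion.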
.github/workflows/benchmark.yml (52 changed lines)
@@ -52,14 +52,14 @@ jobs:
     - name: reset process replay
       run: python3.11 test/external/process_replay/reset.py
     - name: Run Stable Diffusion
-      run: BENCHMARK_LOG=stable_diffusion JIT=1 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
+      run: BENCHMARK_LOG=stable_diffusion JIT=1 ASSERT_MIN_STEP_TIME=500 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
     - name: Run Stable Diffusion without fp16
-      run: BENCHMARK_LOG=stable_diffusion_fp32 JIT=1 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd_no_fp16.txt
+      run: BENCHMARK_LOG=stable_diffusion_fp32 JIT=1 ASSERT_MIN_STEP_TIME=700 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd_no_fp16.txt
     - name: Run Stable Diffusion v2
-      run: BENCHMARK_LOG=stable_diffusion_v2 JIT=1 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing | tee sdv2.txt
+      run: BENCHMARK_LOG=stable_diffusion_v2 JIT=1 ASSERT_MIN_STEP_TIME=1600 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing | tee sdv2.txt
     # process replay can't capture this, the graph is too large
     - name: Run SDXL
-      run: BENCHMARK_LOG=stable_diffusion_xl CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3000 CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run model inference benchmark
       run: METAL=1 python3.11 test/external/external_model_benchmark.py
     - name: Test speed vs torch
@@ -99,7 +99,7 @@ jobs:
     - name: Run GPT2
       run: |
         BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
-        BENCHMARK_LOG=gpt2 JIT=1 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+        BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=8 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
     - name: Run GPT2 w HALF
       run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
     - name: Run GPT2 w HALF/BEAM
@@ -109,13 +109,13 @@ jobs:
     - name: Train MNIST
       run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
     - name: Run 10 CIFAR training steps
-      run: BENCHMARK_LOG=cifar_10steps JIT=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
+      run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=320 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
     - name: Run 10 CIFAR training steps w HALF
-      run: BENCHMARK_LOG=cifar_10steps_half JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
+      run: BENCHMARK_LOG=cifar_10steps_half JIT=2 ASSERT_MIN_STEP_TIME=385 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
     #- name: Run 10 CIFAR training steps w BF16
     #  run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
     - name: Run 10 CIFAR training steps w winograd
-      run: BENCHMARK_LOG=cifar_10steps_wino JIT=1 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+      run: BENCHMARK_LOG=cifar_10steps_wino JIT=1 ASSERT_MIN_STEP_TIME=150 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt
     - name: UsbGPU boot time
       run: sudo -E PYTHONPATH=. DEBUG=2 AM_RESET=1 AMD=1 AMD_IFACE=USB time python3.11 test/test_tiny.py TestTiny.test_plus
     - name: UsbGPU tiny tests
@@ -214,7 +214,7 @@ jobs:
     - name: Run Stable Diffusion
       run: BENCHMARK_LOG=stable_diffusion NV=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
     - name: Run SDXL
-      run: BENCHMARK_LOG=stable_diffusion_xl CAPTURE_PROCESS_REPLAY=0 NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run LLaMA
       run: |
         BENCHMARK_LOG=llama_nojit NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -238,9 +238,9 @@ jobs:
     - name: Run GPT2
       run: |
         BENCHMARK_LOG=gpt2_nojit NV=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
-        BENCHMARK_LOG=gpt2 NV=1 JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+        BENCHMARK_LOG=gpt2 NV=1 JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
     - name: Run GPT2 w HALF
-      run: BENCHMARK_LOG=gpt2_half NV=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
+      run: BENCHMARK_LOG=gpt2_half NV=1 HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
     - name: Run GPT2 w HALF/BEAM
       run: BENCHMARK_LOG=gpt2_half_beam NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
     - uses: actions/upload-artifact@v4
@@ -306,13 +306,13 @@ jobs:
     - name: Train MNIST
      run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
     - name: Run 10 CIFAR training steps
-      run: BENCHMARK_LOG=cifar_10steps NV=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+      run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=85 NV=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
     - name: Run 10 CIFAR training steps w HALF
-      run: BENCHMARK_LOG=cifar_10steps_half NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+      run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=68 NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
     - name: Run 10 CIFAR training steps w BF16
-      run: BENCHMARK_LOG=cifar_10steps_bf16 NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+      run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=75 NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
     - name: Run 10 CIFAR training steps w winograd
-      run: BENCHMARK_LOG=cifar_10steps_half_wino NV=1 CAPTURE_PROCESS_REPLAY=0 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+      run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=35 NV=1 CAPTURE_PROCESS_REPLAY=0 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
     - name: Run full CIFAR training w 1 GPU
       run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.2 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
     - name: Run full CIFAR training steps w 6 GPUS
@@ -415,9 +415,9 @@ jobs:
     - name: Test AM warm start time
       run: time AMD=1 python3 test/test_tiny.py TestTiny.test_plus
     - name: Run Stable Diffusion
-      run: BENCHMARK_LOG=stable_diffusion AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
+      run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=450 AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
     - name: Run SDXL
-      run: BENCHMARK_LOG=stable_diffusion_xl CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=1400 CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run LLaMA 7B
       run: |
         BENCHMARK_LOG=llama_nojit AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -443,9 +443,9 @@ jobs:
     - name: Run GPT2
       run: |
         BENCHMARK_LOG=gpt2_nojit AMD=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
-        BENCHMARK_LOG=gpt2 AMD=1 JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+        BENCHMARK_LOG=gpt2 AMD=1 JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
     - name: Run GPT2 w HALF
-      run: BENCHMARK_LOG=gpt2_half AMD=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
+      run: BENCHMARK_LOG=gpt2_half AMD=1 HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
     - name: Run GPT2 w HALF/BEAM
       run: BENCHMARK_LOG=gpt2_half_beam AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
     - uses: actions/upload-artifact@v4
@@ -508,13 +508,13 @@ jobs:
     - name: Train MNIST
       run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
     - name: Run 10 CIFAR training steps
-      run: BENCHMARK_LOG=cifar_10steps AMD=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+      run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=85 AMD=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
     - name: Run 10 CIFAR training steps w HALF
-      run: BENCHMARK_LOG=cifar_10steps_half AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+      run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=188 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
     - name: Run 10 CIFAR training steps w BF16
-      run: BENCHMARK_LOG=cifar_10steps_bf16 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+      run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=288 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
     - name: Run 10 CIFAR training steps w winograd
-      run: BENCHMARK_LOG=cifar_10steps_half_wino AMD=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+      run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=66 AMD=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
     - name: Run full CIFAR training w 1 GPU
       run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.2 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
     #- name: Run full CIFAR training steps w 6 GPUS
@@ -606,11 +606,11 @@ jobs:
     - name: reset process replay
       run: test/external/process_replay/reset.py
     - name: benchmark openpilot 0.9.9 driving_vision
-      run: BENCHMARK_LOG=openpilot_0_9_9_vision PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_vision.onnx
+      run: BENCHMARK_LOG=openpilot_0_9_9_vision ASSERT_MIN_STEP_TIME=30 PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_vision.onnx
     - name: benchmark openpilot 0.9.9 driving_policy
-      run: BENCHMARK_LOG=openpilot_0_9_9_policy PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_policy.onnx
+      run: BENCHMARK_LOG=openpilot_0_9_9_policy ASSERT_MIN_STEP_TIME=45 PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_policy.onnx
     - name: benchmark openpilot 0.9.9 dmonitoring
-      run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx
+      run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring ASSERT_MIN_STEP_TIME=70 PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx
     - name: openpilot compile3 0.9.9 driving_vision
       run: PYTHONPATH="." QCOM=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_vision.onnx
     - name: openpilot compile3 0.9.9 driving_policy

@@ -181,6 +181,7 @@ class GPT2:
     self.tokenizer = tokenizer
 
   def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
+    step_times = []
     prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
     toks = [prompt_tokens[:] for _ in range(batch_size)]
     start_pos = 0
@@ -197,8 +198,13 @@ class GPT2:
           else:
             tokens = Tensor([x[start_pos:] for x in toks])
           tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT-1).bind(start_pos), temperature).tolist()
+      step_times.append((GlobalCounters.time_sum_s-st)*1e3)
       start_pos = len(toks[0])
       for i,t in enumerate(tok): toks[i].append(t)
+
+    if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
+      min_time = min(step_times)
+      assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
     return [self.tokenizer.decode(x) for x in toks]
 
 # **** main code ****

@@ -355,7 +355,8 @@ def train_cifar():
 
   # https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
   # 136 TFLOPS is the theoretical max w float16 on 3080 Ti
+  step_times = []
   model_ema: Optional[modelEMA] = None
   projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
   i = 0
@@ -413,12 +413,17 @@ def train_cifar():
       model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
 
     cl = time.monotonic()
+    step_times.append((cl-st)*1000.0)
     device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
     # 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS
     print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {device_str}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
     st = cl
     i += 1
 
+  if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
+    min_time = min(step_times)
+    assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
+
   # verify eval acc
   if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
     if eval_acc_pct >= target:

@@ -252,6 +252,10 @@ def train_resnet():
       print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
             f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
       # if we are doing beam search, run the first eval too
+      if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
+        min_time = min(step_times)
+        assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
+
       if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
       return
     if MLLOGGER and RUNMLPERF:
@@ -344,6 +348,8 @@ def train_resnet():
         print(f"saving ckpt to {fn}")
         safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
 
+
+
 def train_retinanet():
   from contextlib import redirect_stdout
   from examples.mlperf.dataloader import batch_load_retinanet

@@ -6,7 +6,7 @@
 from tinygrad import Tensor, TinyJit, dtypes, GlobalCounters
 from tinygrad.nn import Conv2d, GroupNorm
 from tinygrad.nn.state import safe_load, load_state_dict
-from tinygrad.helpers import fetch, trange, colored, Timing
+from tinygrad.helpers import fetch, trange, colored, Timing, getenv
 from extra.models.clip import Embedder, FrozenClosedClipEmbedder, FrozenOpenClipEmbedder
 from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding
 from extra.bench_log import BenchEvent, WallTimeEvent
@@ -14,7 +14,7 @@ from examples.stable_diffusion import ResnetBlock, Mid
 import numpy as np
 
 from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type
-import argparse, tempfile
+import argparse, tempfile, time
 from abc import ABC, abstractmethod
 from pathlib import Path
 from PIL import Image
@@ -342,11 +342,13 @@ class DPMPP2MSampler:
     sigmas = self.discretization(num_steps).to(x.device)
     x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
     num_sigmas = len(sigmas)
+    step_times = []
 
     old_denoised = None
     for i in trange(num_sigmas - 1):
       with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
         GlobalCounters.reset()
+        st = time.perf_counter_ns()
         with WallTimeEvent(BenchEvent.STEP):
           x, old_denoised = self.sampler_step(
             old_denoised=old_denoised,
@@ -358,8 +360,13 @@ class DPMPP2MSampler:
             c=c,
             uc=uc,
           )
+        step_times.append(t:=(time.perf_counter_ns() - st)*1e-6)
       x.realize(old_denoised)
 
+    if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
+      min_time = min(step_times)
+      assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
+
     return x
 
 

@@ -2,7 +2,7 @@
 # https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
 import tempfile
 from pathlib import Path
-import argparse
+import argparse, time
 from collections import namedtuple
 from typing import Dict, Any
 
@@ -266,17 +266,23 @@ if __name__ == "__main__":
   def run(model, *x): return model(*x).realize()
 
   # this is diffusion
+  step_times = []
   with Context(BEAM=getenv("LATEBEAM")):
     for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
       GlobalCounters.reset()
+      st = time.perf_counter_ns()
       t.set_description("%3d %3d" % (index, timestep))
       with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
         with WallTimeEvent(BenchEvent.STEP):
           tid = Tensor([index])
           latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
           if args.timing: Device[Device.DEFAULT].synchronize()
+          step_times.append((time.perf_counter_ns() - st)*1e-6)
   del run
 
+  if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
+    min_time = min(step_times)
+    assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
   # upsample latent space to image with autoencoder
   x = model.decode(latent)
   print(x.shape)

@@ -27,6 +27,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# NOTE: the inputs to a JIT must be first level arguments
|
# NOTE: the inputs to a JIT must be first level arguments
|
||||||
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
|
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
|
||||||
|
step_times = []
|
||||||
for _ in range(20):
|
for _ in range(20):
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
st = time.perf_counter_ns()
|
st = time.perf_counter_ns()
|
||||||
@@ -35,7 +36,12 @@ if __name__ == "__main__":
|
|||||||
inputs = {**{k:v for k,v in new_inputs_junk.items() if 'img' in k},
|
inputs = {**{k:v for k,v in new_inputs_junk.items() if 'img' in k},
|
||||||
**{k:Tensor(v) for k,v in new_inputs_junk_numpy.items() if 'img' not in k}}
|
**{k:Tensor(v) for k,v in new_inputs_junk_numpy.items() if 'img' not in k}}
|
||||||
ret = next(iter(run_onnx_jit(**inputs).values())).cast(dtypes.float32).numpy()
|
ret = next(iter(run_onnx_jit(**inputs).values())).cast(dtypes.float32).numpy()
|
||||||
print(f"jitted: {(time.perf_counter_ns() - st)*1e-6:7.4f} ms")
|
step_times.append(t:=(time.perf_counter_ns() - st)*1e-6)
|
||||||
|
print(f"jitted: {t:7.4f} ms")
|
||||||
|
|
||||||
|
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
|
||||||
|
min_time = min(step_times)
|
||||||
|
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
|
||||||
|
|
||||||
suffix = ""
|
suffix = ""
|
||||||
if IMAGE.value < 2: suffix += f"_image{IMAGE.value}" # image=2 has no suffix for compatibility
|
if IMAGE.value < 2: suffix += f"_image{IMAGE.value}" # image=2 has no suffix for compatibility
|
||||||
|
|||||||