assert benchmark times (#12042)

* assert jitted times in openpilot

* better error

* better error

* add ASSERT_MIN_STEP_TIME to more models

* t is step_times

* update benchmark times

* update times
This commit is contained in:
Sieds Lykles
2025-09-09 23:40:02 +02:00
committed by GitHub
parent 58d13a6e3e
commit 5b73076e48
7 changed files with 67 additions and 31 deletions

View File

@@ -52,14 +52,14 @@ jobs:
- name: reset process replay
run: python3.11 test/external/process_replay/reset.py
- name: Run Stable Diffusion
run: BENCHMARK_LOG=stable_diffusion JIT=1 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
run: BENCHMARK_LOG=stable_diffusion JIT=1 ASSERT_MIN_STEP_TIME=500 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
- name: Run Stable Diffusion without fp16
run: BENCHMARK_LOG=stable_diffusion_fp32 JIT=1 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd_no_fp16.txt
run: BENCHMARK_LOG=stable_diffusion_fp32 JIT=1 ASSERT_MIN_STEP_TIME=700 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd_no_fp16.txt
- name: Run Stable Diffusion v2
run: BENCHMARK_LOG=stable_diffusion_v2 JIT=1 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing | tee sdv2.txt
run: BENCHMARK_LOG=stable_diffusion_v2 JIT=1 ASSERT_MIN_STEP_TIME=1600 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing | tee sdv2.txt
# process replay can't capture this, the graph is too large
- name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3000 CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
- name: Run model inference benchmark
run: METAL=1 python3.11 test/external/external_model_benchmark.py
- name: Test speed vs torch
@@ -99,7 +99,7 @@ jobs:
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
BENCHMARK_LOG=gpt2 JIT=1 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=8 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
@@ -109,13 +109,13 @@ jobs:
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
run: BENCHMARK_LOG=cifar_10steps JIT=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=320 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
- name: Run 10 CIFAR training steps w HALF
run: BENCHMARK_LOG=cifar_10steps_half JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
run: BENCHMARK_LOG=cifar_10steps_half JIT=2 ASSERT_MIN_STEP_TIME=385 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
#- name: Run 10 CIFAR training steps w BF16
# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
- name: Run 10 CIFAR training steps w winograd
run: BENCHMARK_LOG=cifar_10steps_wino JIT=1 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt
run: BENCHMARK_LOG=cifar_10steps_wino JIT=1 ASSERT_MIN_STEP_TIME=150 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt
- name: UsbGPU boot time
run: sudo -E PYTHONPATH=. DEBUG=2 AM_RESET=1 AMD=1 AMD_IFACE=USB time python3.11 test/test_tiny.py TestTiny.test_plus
- name: UsbGPU tiny tests
@@ -214,7 +214,7 @@ jobs:
- name: Run Stable Diffusion
run: BENCHMARK_LOG=stable_diffusion NV=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
- name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl CAPTURE_PROCESS_REPLAY=0 NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -238,9 +238,9 @@ jobs:
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit NV=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
BENCHMARK_LOG=gpt2 NV=1 JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
BENCHMARK_LOG=gpt2 NV=1 JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half NV=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
run: BENCHMARK_LOG=gpt2_half NV=1 HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- uses: actions/upload-artifact@v4
@@ -306,13 +306,13 @@ jobs:
- name: Train MNIST
run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
run: BENCHMARK_LOG=cifar_10steps NV=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=85 NV=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
- name: Run 10 CIFAR training steps w HALF
run: BENCHMARK_LOG=cifar_10steps_half NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=68 NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
- name: Run 10 CIFAR training steps w BF16
run: BENCHMARK_LOG=cifar_10steps_bf16 NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=75 NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
- name: Run 10 CIFAR training steps w winograd
run: BENCHMARK_LOG=cifar_10steps_half_wino NV=1 CAPTURE_PROCESS_REPLAY=0 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=35 NV=1 CAPTURE_PROCESS_REPLAY=0 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
- name: Run full CIFAR training w 1 GPU
run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.2 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
- name: Run full CIFAR training steps w 6 GPUS
@@ -415,9 +415,9 @@ jobs:
- name: Test AM warm start time
run: time AMD=1 python3 test/test_tiny.py TestTiny.test_plus
- name: Run Stable Diffusion
run: BENCHMARK_LOG=stable_diffusion AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=450 AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
- name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=1400 CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
- name: Run LLaMA 7B
run: |
BENCHMARK_LOG=llama_nojit AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -443,9 +443,9 @@ jobs:
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit AMD=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
BENCHMARK_LOG=gpt2 AMD=1 JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
BENCHMARK_LOG=gpt2 AMD=1 JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half AMD=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
run: BENCHMARK_LOG=gpt2_half AMD=1 HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- uses: actions/upload-artifact@v4
@@ -508,13 +508,13 @@ jobs:
- name: Train MNIST
run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
run: BENCHMARK_LOG=cifar_10steps AMD=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=85 AMD=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
- name: Run 10 CIFAR training steps w HALF
run: BENCHMARK_LOG=cifar_10steps_half AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=188 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
- name: Run 10 CIFAR training steps w BF16
run: BENCHMARK_LOG=cifar_10steps_bf16 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=288 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
- name: Run 10 CIFAR training steps w winograd
run: BENCHMARK_LOG=cifar_10steps_half_wino AMD=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=66 AMD=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
- name: Run full CIFAR training w 1 GPU
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.2 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
#- name: Run full CIFAR training steps w 6 GPUS
@@ -606,11 +606,11 @@ jobs:
- name: reset process replay
run: test/external/process_replay/reset.py
- name: benchmark openpilot 0.9.9 driving_vision
run: BENCHMARK_LOG=openpilot_0_9_9_vision PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_vision.onnx
run: BENCHMARK_LOG=openpilot_0_9_9_vision ASSERT_MIN_STEP_TIME=30 PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_vision.onnx
- name: benchmark openpilot 0.9.9 driving_policy
run: BENCHMARK_LOG=openpilot_0_9_9_policy PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_policy.onnx
run: BENCHMARK_LOG=openpilot_0_9_9_policy ASSERT_MIN_STEP_TIME=45 PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_policy.onnx
- name: benchmark openpilot 0.9.9 dmonitoring
run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx
run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring ASSERT_MIN_STEP_TIME=70 PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx
- name: openpilot compile3 0.9.9 driving_vision
run: PYTHONPATH="." QCOM=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.9.9 driving_policy

View File

@@ -181,6 +181,7 @@ class GPT2:
self.tokenizer = tokenizer
def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
step_times = []
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
toks = [prompt_tokens[:] for _ in range(batch_size)]
start_pos = 0
@@ -197,8 +198,13 @@ class GPT2:
else:
tokens = Tensor([x[start_pos:] for x in toks])
tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT-1).bind(start_pos), temperature).tolist()
step_times.append((GlobalCounters.time_sum_s-st)*1e3)
start_pos = len(toks[0])
for i,t in enumerate(tok): toks[i].append(t)
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
return [self.tokenizer.decode(x) for x in toks]
# **** main code ****

View File

@@ -355,7 +355,7 @@ def train_cifar():
# https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
# 136 TFLOPS is the theoretical max w float16 on 3080 Ti
step_times = []
model_ema: Optional[modelEMA] = None
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
i = 0
@@ -413,12 +413,17 @@ def train_cifar():
model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
cl = time.monotonic()
step_times.append((cl-st)*1000.0)
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
# 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {device_str}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
st = cl
i += 1
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
# verify eval acc
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
if eval_acc_pct >= target:

View File

@@ -252,6 +252,10 @@ def train_resnet():
print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
# if we are doing beam search, run the first eval too
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
return
if MLLOGGER and RUNMLPERF:
@@ -344,6 +348,8 @@ def train_resnet():
print(f"saving ckpt to {fn}")
safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
def train_retinanet():
from contextlib import redirect_stdout
from examples.mlperf.dataloader import batch_load_retinanet

View File

@@ -6,7 +6,7 @@
from tinygrad import Tensor, TinyJit, dtypes, GlobalCounters
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, trange, colored, Timing
from tinygrad.helpers import fetch, trange, colored, Timing, getenv
from extra.models.clip import Embedder, FrozenClosedClipEmbedder, FrozenOpenClipEmbedder
from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding
from extra.bench_log import BenchEvent, WallTimeEvent
@@ -14,7 +14,7 @@ from examples.stable_diffusion import ResnetBlock, Mid
import numpy as np
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type
import argparse, tempfile
import argparse, tempfile, time
from abc import ABC, abstractmethod
from pathlib import Path
from PIL import Image
@@ -342,11 +342,13 @@ class DPMPP2MSampler:
sigmas = self.discretization(num_steps).to(x.device)
x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
num_sigmas = len(sigmas)
step_times = []
old_denoised = None
for i in trange(num_sigmas - 1):
with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
GlobalCounters.reset()
st = time.perf_counter_ns()
with WallTimeEvent(BenchEvent.STEP):
x, old_denoised = self.sampler_step(
old_denoised=old_denoised,
@@ -358,8 +360,13 @@ class DPMPP2MSampler:
c=c,
uc=uc,
)
step_times.append(t:=(time.perf_counter_ns() - st)*1e-6)
x.realize(old_denoised)
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
return x

View File

@@ -2,7 +2,7 @@
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
import tempfile
from pathlib import Path
import argparse
import argparse, time
from collections import namedtuple
from typing import Dict, Any
@@ -266,17 +266,23 @@ if __name__ == "__main__":
def run(model, *x): return model(*x).realize()
# this is diffusion
step_times = []
with Context(BEAM=getenv("LATEBEAM")):
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
GlobalCounters.reset()
st = time.perf_counter_ns()
t.set_description("%3d %3d" % (index, timestep))
with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
with WallTimeEvent(BenchEvent.STEP):
tid = Tensor([index])
latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
if args.timing: Device[Device.DEFAULT].synchronize()
step_times.append((time.perf_counter_ns() - st)*1e-6)
del run
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
# upsample latent space to image with autoencoder
x = model.decode(latent)
print(x.shape)

View File

@@ -27,6 +27,7 @@ if __name__ == "__main__":
# NOTE: the inputs to a JIT must be first level arguments
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
step_times = []
for _ in range(20):
GlobalCounters.reset()
st = time.perf_counter_ns()
@@ -35,7 +36,12 @@ if __name__ == "__main__":
inputs = {**{k:v for k,v in new_inputs_junk.items() if 'img' in k},
**{k:Tensor(v) for k,v in new_inputs_junk_numpy.items() if 'img' not in k}}
ret = next(iter(run_onnx_jit(**inputs).values())).cast(dtypes.float32).numpy()
print(f"jitted: {(time.perf_counter_ns() - st)*1e-6:7.4f} ms")
step_times.append(t:=(time.perf_counter_ns() - st)*1e-6)
print(f"jitted: {t:7.4f} ms")
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
suffix = ""
if IMAGE.value < 2: suffix += f"_image{IMAGE.value}" # image=2 has no suffix for compatibility