mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add aggregate statistics to microbenchmark (#1871)
Print averaged results at the end of all iterations. Increase the default number of iterations to 5. Example: ``` Number of iterations: 5 Prefill: avg. 0.03 s, stddev 0.00 Decode: avg. 43.34 tokens/s, stdev 0.13 ``` Also remove the -2 in the number of generated tokens -- I did not find any evidence we need it.
This commit is contained in:
@@ -4,6 +4,7 @@ import re
|
||||
import gc
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from statistics import mean, stdev
|
||||
from tqdm import tqdm
|
||||
from typing import List, Tuple
|
||||
import subprocess
|
||||
@@ -141,8 +142,8 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--microbenchmark_iterations",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of microbenchmark iterations. Default: 2.",
|
||||
default=5,
|
||||
help="Number of microbenchmark iterations. Default: 5.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microbenchmark_num_tokens",
|
||||
@@ -1458,7 +1459,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
|
||||
return
|
||||
|
||||
print(f"[DEBUG] vmfb not found")
|
||||
print(f"[DEBUG] vmfb not found (search path: {self.vicuna_vmfb_path})")
|
||||
mlir_generated = False
|
||||
for suffix in ["mlirbc", "mlir"]:
|
||||
self.vicuna_mlir_path = self.get_model_path(suffix)
|
||||
@@ -1754,7 +1755,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
if cli:
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
for idx in range(self.max_num_tokens - 2):
|
||||
for idx in range(self.max_num_tokens):
|
||||
params = {
|
||||
"token": token,
|
||||
"is_first": False,
|
||||
@@ -1950,6 +1951,9 @@ if __name__ == "__main__":
|
||||
|
||||
iteration = 0
|
||||
|
||||
prefill_times = []
|
||||
avg_decode_speed = []
|
||||
|
||||
while True:
|
||||
# TODO: Add break condition from user input
|
||||
iteration += 1
|
||||
@@ -1979,11 +1983,20 @@ if __name__ == "__main__":
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
prefill_times.append(prefill_time)
|
||||
avg_decode_speed.append(tokens_per_sec)
|
||||
|
||||
print("\nResponse:", text.strip())
|
||||
print(f"\nNum tokens: {token_count}")
|
||||
print(f"Prefill: {prefill_time:.2f} seconds")
|
||||
print(f"Decode: {tokens_per_sec:.2f} tokens/sec")
|
||||
print(f"Decode: {tokens_per_sec:.2f} tokens/s")
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
if args.enable_microbenchmark:
|
||||
print("\n### Final Statistics ###")
|
||||
print("Number of iterations:", iteration - 1)
|
||||
print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
|
||||
print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")
|
||||
|
||||
Reference in New Issue
Block a user