[vicuna.py] Rework benchmark statistics calculation (#1992)
- Move statistics out of the main loop
- Add 'end-to-end' numbers
- Switch the main display unit from s to ms
- Start measuring time at 0

The new print format looks like this:

```
Number of iterations: 5
Num tokens: 1 (prompt), 512 (generated), 513 (total)
Prefill: avg. 0.01 ms (stdev 0.00), avg. 97.99 tokens/s
Decode: avg. 4840.44 ms (stdev 28.80), avg. 97.99 tokens/s
Decode end-2-end: avg. 85.78 tokens/s (w/o prompt), avg. 95.98 (w/ prompt)
```
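For orientation, the tokens/s figures in this format follow directly from the millisecond timings. A standalone sketch of the arithmetic, with made-up timings (not taken from a real run):

```python
# Made-up timings, for illustration only.
prefill_time_ms = 10.2              # time to process the prompt
token_times_ms = [9.4, 9.6, 9.5]    # per-token decode times
num_prompt_tokens = 1

prefill_speed = num_prompt_tokens / (prefill_time_ms / 1000.0)
decode_speed = len(token_times_ms) / (sum(token_times_ms) / 1000.0)

# The "end-2-end" numbers divide generated (or all) tokens by prefill + decode time.
e2e_ms = prefill_time_ms + sum(token_times_ms)
e2e_wo_prompt = len(token_times_ms) / (e2e_ms / 1000.0)
e2e_w_prompt = (num_prompt_tokens + len(token_times_ms)) / (e2e_ms / 1000.0)

print(f"Prefill: {prefill_speed:.2f} tokens/s, Decode: {decode_speed:.2f} tokens/s")
print(f"End-2-end: {e2e_wo_prompt:.2f} (w/o prompt), {e2e_w_prompt:.2f} (w/ prompt)")
```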
@@ -1,4 +1,5 @@
 import argparse
+from dataclasses import dataclass
 import json
 import re
 import gc
@@ -1937,6 +1938,87 @@ def create_prompt(model_name, history):
     return msg
 
 
+def miliseconds_to_seconds(ms: float) -> float:
+    return ms / 1000.0
+
+
+@dataclass
+class BenchmarkRunInfo:
+    num_prompt_tokens : int
+    prefill_time_ms : float
+    token_times_ms : list[float]
+
+    def get_prefill_speed(self) -> float:
+        seconds = miliseconds_to_seconds(self.prefill_time_ms)
+        if seconds == 0.0:
+            return float('inf')
+        return self.num_prompt_tokens / seconds
+
+    def num_generated_tokens(self) -> int:
+        return len(self.token_times_ms)
+
+    def get_decode_time_ms(self) -> float:
+        return sum(self.token_times_ms)
+
+    def get_decode_speed(self) -> float:
+        seconds = miliseconds_to_seconds(self.get_decode_time_ms())
+        if seconds == 0.0:
+            return float('inf')
+        return self.num_generated_tokens() / seconds
+
+    def get_e2e_time_ms(self) -> float:
+        return self.prefill_time_ms + self.get_decode_time_ms()
+
+    def get_e2e_decode_speed(self) -> float:
+        seconds = miliseconds_to_seconds(self.get_e2e_time_ms())
+        if seconds == 0.0:
+            return float('inf')
+        return self.num_generated_tokens() / seconds
+
+    def get_e2e_token_processing_speed(self) -> float:
+        seconds = miliseconds_to_seconds(self.get_e2e_time_ms())
+        if seconds == 0.0:
+            return float('inf')
+        return (self.num_prompt_tokens + self.num_generated_tokens()) / seconds
+
+    def print(self) -> None:
+        total_tokens = self.num_prompt_tokens + self.num_generated_tokens()
+        print(f"Num tokens: {self.num_prompt_tokens} (prompt), {self.num_generated_tokens()} (generated), {total_tokens} (total)")
+        print(f"Prefill: {self.prefill_time_ms:.2f} ms, {self.get_prefill_speed():.2f} tokens/s")
+        print(f"Decode: {self.get_decode_time_ms():.2f} ms, {self.get_decode_speed():.2f} tokens/s")
+        print(f"Decode end-2-end: {self.get_e2e_decode_speed():.2f} tokens/s (w/o prompt), {self.get_e2e_token_processing_speed():.2f} tokens/s (w/ prompt)")
+
+
+def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None:
+    num_iterations = len(run_infos)
+    print(f'Number of iterations: {num_iterations}')
+    if num_iterations == 0:
+        return
+
+    if len(run_infos) == 1:
+        run_infos[0].print()
+        return
+
+    total_tokens = run_infos[0].num_prompt_tokens + run_infos[0].num_generated_tokens()
+    print(f"Num tokens: {run_infos[0].num_prompt_tokens} (prompt), {run_infos[0].num_generated_tokens()} (generated), {total_tokens} (total)")
+
+    def avg_and_stdev(data):
+        x = list(data)
+        return mean(x), stdev(x)
+
+    avg_prefill_ms, stdev_prefill = avg_and_stdev(x.prefill_time_ms for x in run_infos)
+    avg_prefill_speed = mean(x.get_prefill_speed() for x in run_infos)
+    print(f"Prefill: avg. {avg_prefill_ms:.2f} ms (stdev {stdev_prefill:.2f}), avg. {avg_prefill_speed:.2f} tokens/s")
+
+    avg_decode_ms, stdev_decode = avg_and_stdev(x.get_decode_time_ms() for x in run_infos)
+    avg_decode_speed = mean(x.get_decode_speed() for x in run_infos)
+    print(f"Decode: avg. {avg_decode_ms:.2f} ms (stdev {stdev_decode:.2f}), avg. {avg_decode_speed:.2f} tokens/s")
+
+    avg_e2e_decode_speed = mean(x.get_e2e_decode_speed() for x in run_infos)
+    avg_e2e_processing_speed = mean(x.get_e2e_token_processing_speed() for x in run_infos)
+    print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)")
+
+
 if __name__ == "__main__":
     args, unknown = parser.parse_known_args()
 
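A minimal usage sketch of the new helpers (run data invented; `mean` and `stdev` come from the `statistics` module, which vicuna.py already uses):

```python
from statistics import mean, stdev  # already imported in vicuna.py

# Two invented runs: 1 prompt token, 3 generated tokens each.
runs = [
    BenchmarkRunInfo(num_prompt_tokens=1, prefill_time_ms=10.2, token_times_ms=[9.4, 9.6, 9.5]),
    BenchmarkRunInfo(num_prompt_tokens=1, prefill_time_ms=10.8, token_times_ms=[9.7, 9.3, 9.6]),
]
runs[0].print()              # per-run breakdown
print_aggregate_stats(runs)  # averages and stdevs across runs
```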
@@ -2035,8 +2117,7 @@ if __name__ == "__main__":
 
     iteration = 0
 
-    prefill_times = []
-    avg_decode_speed = []
+    benchmark_run_infos = []
 
     while True:
         # TODO: Add break condition from user input
@@ -2052,28 +2133,27 @@ if __name__ == "__main__":
         prompt = args.system_prompt + user_prompt
         history = [[user_prompt, ""]]
 
-        token_count = 0
-        total_time_ms = 0.001  # In order to avoid divide by zero error
-        prefill_time = 0
+        prompt_token_count = len(vic.tokenizer(prompt).input_ids)
+        total_time_ms = 0.0  # In order to avoid divide by zero error
+        prefill_time_ms = 0
         is_first = True
+        token_times_ms = []
 
         for text, msg, exec_time in vic.generate(prompt, cli=True):
             if msg is None:
                 if is_first:
-                    prefill_time = exec_time
+                    # Note that the prefill time is in seconds, and all the decoded tokens in ms.
+                    prefill_time_ms = exec_time * 1000
                     is_first = False
                 else:
                     total_time_ms += exec_time
-                    token_count += 1
+                    token_times_ms.append(exec_time)
             elif "formatted" in msg:
                 history[-1][1] = text
-                tokens_per_sec = (token_count / total_time_ms) * 1000
-                prefill_times.append(prefill_time)
-                avg_decode_speed.append(tokens_per_sec)
-                print("\nResponse:", text.strip())
-                print(f"\nNum tokens: {token_count}")
-                print(f"Prefill: {prefill_time:.2f} seconds")
-                print(f"Decode: {tokens_per_sec:.2f} tokens/s")
+                print(f"\nResponse:\n{text.strip()}\n")
+                run_info = BenchmarkRunInfo(prompt_token_count, prefill_time_ms, token_times_ms)
+                run_info.print()
+                benchmark_run_infos.append(run_info)
             else:
                 sys.exit(
                     "unexpected message from the vicuna generate call, exiting."
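Worth flagging for readers of this hunk: per the new comment, `exec_time` is reported in seconds for the prefill step but in milliseconds for each decoded token, so only the first measurement is scaled. A stub illustrating the pattern (the generator here is invented):

```python
def fake_generate():
    # Yields (text, msg, exec_time) in the shape vic.generate uses:
    # prefill time in seconds, per-token decode times in milliseconds.
    yield "", None, 0.0102  # prefill: 0.0102 s = 10.2 ms
    yield "a", None, 9.4    # decode token, already in ms
    yield "b", None, 9.6    # decode token, already in ms

prefill_time_ms, token_times_ms, is_first = 0, [], True
for _, msg, exec_time in fake_generate():
    if msg is None:
        if is_first:
            prefill_time_ms = exec_time * 1000  # seconds -> ms
            is_first = False
        else:
            token_times_ms.append(exec_time)    # ms, no scaling
```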
@@ -2081,6 +2161,4 @@ if __name__ == "__main__":
 
     if args.enable_microbenchmark:
         print("\n### Final Statistics ###")
-        print("Number of iterations:", iteration - 1)
-        print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
-        print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")
+        print_aggregate_stats(benchmark_run_infos)
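A side note on why `print_aggregate_stats` special-cases a single run before computing averages: `statistics.stdev` needs at least two samples, as a quick check shows:

```python
from statistics import stdev, StatisticsError

try:
    stdev([4840.44])  # only one sample
except StatisticsError as e:
    print(e)  # e.g. "stdev requires at least two data points"
```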