[vicuna.py] Rework benchmark statistics calculation (#1992)

- Move statistics out of the main loop
- Add 'end-to-end' numbers
- Switch the main display unit from s to ms
- Start measuring time at 0

The new print format looks like this:
```
Number of iterations: 5
Num tokens: 1 (prompt), 512 (generated), 513 (total)
Prefill: avg. 0.01 ms (stdev 0.00), avg. 97.99 tokens/s
Decode: avg. 4840.44 ms (stdev 28.80), avg. 97.99 tokens/s
Decode end-2-end: avg. 85.78 tokens/s (w/o prompt), avg. 95.98 (w/ prompt)
```
This commit is contained in:
Jakub Kuderski
2023-11-23 12:04:03 -05:00
committed by GitHub
parent da50a16242
commit 2da31c4109

View File

@@ -1,4 +1,5 @@
import argparse
from dataclasses import dataclass
import json
import re
import gc
@@ -1937,6 +1938,87 @@ def create_prompt(model_name, history):
return msg
def miliseconds_to_seconds(ms: float) -> float:
return ms / 1000.0
@dataclass
class BenchmarkRunInfo:
num_prompt_tokens : int
prefill_time_ms : float
token_times_ms : list[float]
def get_prefill_speed(self) -> float:
seconds = miliseconds_to_seconds(self.prefill_time_ms)
if seconds == 0.0:
return float('inf')
return self.num_prompt_tokens / seconds
def num_generated_tokens(self) -> int:
return len(self.token_times_ms)
def get_decode_time_ms(self) -> float:
return sum(self.token_times_ms)
def get_decode_speed(self) -> float:
seconds = miliseconds_to_seconds(self.get_decode_time_ms())
if seconds == 0.0:
return float('inf')
return self.num_generated_tokens() / seconds
def get_e2e_time_ms(self) -> float:
return self.prefill_time_ms + self.get_decode_time_ms()
def get_e2e_decode_speed(self) -> float:
seconds = miliseconds_to_seconds(self.get_e2e_time_ms())
if seconds == 0.0:
return float('inf')
return self.num_generated_tokens() / seconds
def get_e2e_token_processing_speed(self) -> float:
seconds = miliseconds_to_seconds(self.get_e2e_time_ms())
if seconds == 0.0:
return float('inf')
return (self.num_prompt_tokens + self.num_generated_tokens()) / seconds
def print(self) -> None:
total_tokens = self.num_prompt_tokens + self.num_generated_tokens()
print(f"Num tokens: {self.num_prompt_tokens:} (prompt), {self.num_generated_tokens()} (generated), {total_tokens} (total)")
print(f"Prefill: {self.prefill_time_ms:.2f} ms, {self.get_prefill_speed():.2f} tokens/s")
print(f"Decode: {self.get_decode_time_ms():.2f} ms, {self.get_decode_speed():.2f} tokens/s")
print(f"Decode end-2-end: {self.get_e2e_decode_speed():.2f} tokens/s (w/o prompt), {self.get_e2e_token_processing_speed():.2f} tokens/s (w/ prompt)")
def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None:
num_iterations = len(run_infos)
print(f'Number of iterations: {num_iterations}')
if num_iterations == 0:
return
if len(run_infos) == 1:
run_infos[0].print()
return
total_tokens = run_infos[0].num_prompt_tokens + run_infos[0].num_generated_tokens()
print(f"Num tokens: {run_infos[0].num_prompt_tokens} (prompt), {run_infos[0].num_generated_tokens()} (generated), {total_tokens} (total)")
def avg_and_stdev(data):
x = list(data)
return mean(x), stdev(x)
avg_prefill_ms, stdev_prefill = avg_and_stdev(x.prefill_time_ms for x in run_infos)
avg_prefill_speed = mean(x.get_prefill_speed() for x in run_infos)
print(f"Prefill: avg. {avg_prefill_ms:.2f} ms (stdev {stdev_prefill:.2f}), avg. {avg_prefill_speed:.2f} tokens/s")
avg_decode_ms, stdev_decode = avg_and_stdev(x.get_decode_time_ms() for x in run_infos)
avg_decode_speed = mean(x.get_decode_speed() for x in run_infos)
print(f"Decode: avg. {avg_decode_ms:.2f} ms (stdev {stdev_decode:.2f}), avg. {avg_decode_speed:.2f} tokens/s")
avg_e2e_decode_speed = mean(x.get_e2e_decode_speed() for x in run_infos)
avg_e2e_processing_speed = mean(x.get_e2e_token_processing_speed() for x in run_infos)
print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)")
if __name__ == "__main__":
args, unknown = parser.parse_known_args()
@@ -2035,8 +2117,7 @@ if __name__ == "__main__":
iteration = 0
prefill_times = []
avg_decode_speed = []
benchmark_run_infos = []
while True:
# TODO: Add break condition from user input
@@ -2052,28 +2133,27 @@ if __name__ == "__main__":
prompt = args.system_prompt + user_prompt
history = [[user_prompt, ""]]
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
prompt_token_count = len(vic.tokenizer(prompt).input_ids)
total_time_ms = 0.0 # In order to avoid divide by zero error
prefill_time_ms = 0
is_first = True
token_times_ms = []
for text, msg, exec_time in vic.generate(prompt, cli=True):
if msg is None:
if is_first:
prefill_time = exec_time
# Note that the prefill time is in seconds, and all the decoded tokens in ms.
prefill_time_ms = exec_time * 1000
is_first = False
else:
total_time_ms += exec_time
token_count += 1
token_times_ms.append(exec_time)
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
prefill_times.append(prefill_time)
avg_decode_speed.append(tokens_per_sec)
print(f"\nResponse:\n{text.strip()}\n")
run_info = BenchmarkRunInfo(prompt_token_count, prefill_time_ms, token_times_ms)
run_info.print()
benchmark_run_infos.append(run_info)
print("\nResponse:", text.strip())
print(f"\nNum tokens: {token_count}")
print(f"Prefill: {prefill_time:.2f} seconds")
print(f"Decode: {tokens_per_sec:.2f} tokens/s")
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
@@ -2081,6 +2161,4 @@ if __name__ == "__main__":
if args.enable_microbenchmark:
print("\n### Final Statistics ###")
print("Number of iterations:", iteration - 1)
print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")
print_aggregate_stats(benchmark_run_infos)