[vicuna.py] Add option to enable tracing (#1993)

This makes the program wait for tracy profiler to connect before exiting
and flush profiling data after each token.

I don't know how to select the tracy iree-runtime variant
programmatically -- instead, the script prints an error and exits
when that variant is not in use.
This commit is contained in:
Jakub Kuderski
2023-11-24 15:25:03 -05:00
committed by GitHub
parent 2da31c4109
commit 1f5b39f56e

View File

@@ -4,6 +4,7 @@ import json
import re
import gc
from io import BytesIO
from os import environ
from pathlib import Path
from statistics import mean, stdev
from tqdm import tqdm
@@ -144,6 +145,12 @@ parser.add_argument(
default=[],
help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
)
# Opt-in Tracy profiling switch; off unless explicitly requested.
# BooleanOptionalAction also registers the negated `--no-enable_tracing` form.
parser.add_argument(
    "--enable_tracing",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Enable profiling with Tracy. The script will wait for Tracy to connect and flush the profiling data after each token.",
)
# Microbenchmarking options.
parser.add_argument(
@@ -2019,12 +2026,26 @@ def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None:
print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)")
def enable_tracy_tracing():
    """Prepare the process environment for profiling with Tracy.

    Sets ``TRACY_NO_EXIT=1`` so that Tracy waits for a capture to be
    collected before the process exits, then checks that the
    ``IREE_PY_RUNTIME`` environment variable selects the ``tracy`` runtime.

    Raises:
        SystemExit: With exit code 1, after printing an error to stderr,
            when ``IREE_PY_RUNTIME`` is unset or is not ``"tracy"``.
    """
    # Make tracy wait for a capture to be collected before exiting.
    environ["TRACY_NO_EXIT"] = "1"

    # A missing variable yields None, which fails the comparison just like
    # any non-"tracy" value does.
    if environ.get("IREE_PY_RUNTIME") != "tracy":
        print("ERROR: Tracing enabled but tracy iree runtime not used.", file=sys.stderr)
        print("Set the IREE_PY_RUNTIME=tracy environment variable.", file=sys.stderr)
        sys.exit(1)
if __name__ == "__main__":
args, unknown = parser.parse_known_args()
_extra_args = list(args.Xiree_compile)
device_id = None
if args.enable_tracing:
enable_tracy_tracing()
# Process vulkan target triple.
# TODO: This feature should just be in a common utils for other LLMs and in general
# any model run via SHARK for Vulkan backend.
@@ -2140,6 +2161,9 @@ if __name__ == "__main__":
token_times_ms = []
for text, msg, exec_time in vic.generate(prompt, cli=True):
if args.enable_tracing:
vic.shark_model.shark_runner.iree_config.device.flush_profiling()
if msg is None:
if is_first:
# Note that the prefill time is in seconds, and all the decoded tokens in ms.