mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
[vicuna.py] Add option to enable tracing (#1993)
This makes the program wait for the Tracy profiler to connect before exiting and flushes profiling data after each token. I don't know how to select the tracy iree-runtime variant programmatically -- instead, the script prints an error and exits when the IREE_PY_RUNTIME environment variable is not set to tracy.
This commit is contained in:
@@ -4,6 +4,7 @@ import json
|
||||
import re
|
||||
import gc
|
||||
from io import BytesIO
|
||||
from os import environ
|
||||
from pathlib import Path
|
||||
from statistics import mean, stdev
|
||||
from tqdm import tqdm
|
||||
@@ -144,6 +145,12 @@ parser.add_argument(
|
||||
default=[],
|
||||
help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_tracing",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enable profiling with Tracy. The script will wait for Tracy to connect and flush the profiling data after each token."
|
||||
)
|
||||
|
||||
# Microbenchmarking options.
|
||||
parser.add_argument(
|
||||
@@ -2019,12 +2026,26 @@ def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None:
|
||||
print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)")
|
||||
|
||||
|
||||
def enable_tracy_tracing():
    """Prepare the process environment for profiling with Tracy.

    Sets ``TRACY_NO_EXIT=1`` in the environment and then verifies that the
    ``IREE_PY_RUNTIME`` environment variable is set to ``tracy``.

    Raises:
        SystemExit: with status 1, after printing an error to stderr, when
            ``IREE_PY_RUNTIME`` is unset or has any value other than
            ``tracy``.
    """
    # Make tracy wait for a capture to be collected before exiting.
    environ["TRACY_NO_EXIT"] = "1"

    # The runtime variant is not switched from here; the caller has to
    # request it through the environment, so only validate and bail out.
    # A missing variable yields None from .get(), which also fails the check.
    if environ.get("IREE_PY_RUNTIME") != "tracy":
        print("ERROR: Tracing enabled but tracy iree runtime not used.", file=sys.stderr)
        print("Set the IREE_PY_RUNTIME=tracy environment variable.", file=sys.stderr)
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
_extra_args = list(args.Xiree_compile)
|
||||
|
||||
device_id = None
|
||||
|
||||
if args.enable_tracing:
|
||||
enable_tracy_tracing()
|
||||
|
||||
# Process vulkan target triple.
|
||||
# TODO: This feature should just be in a common utils for other LLMs and in general
|
||||
# any model run via SHARK for Vulkan backend.
|
||||
@@ -2140,6 +2161,9 @@ if __name__ == "__main__":
|
||||
token_times_ms = []
|
||||
|
||||
for text, msg, exec_time in vic.generate(prompt, cli=True):
|
||||
if args.enable_tracing:
|
||||
vic.shark_model.shark_runner.iree_config.device.flush_profiling()
|
||||
|
||||
if msg is None:
|
||||
if is_first:
|
||||
# Note that the prefill time is in seconds, and all the decoded tokens in ms.
|
||||
|
||||
Reference in New Issue
Block a user