mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTS] Enhance benchmark flexibility (#2239)
User can pass custom arguments to benchmarks. For example, user can pass `dtype` which will be used to create tensors in a benchmark. Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -266,7 +266,7 @@ class Mark:
|
||||
self.fn = fn
|
||||
self.benchmarks = benchmarks
|
||||
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool):
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -287,7 +287,7 @@ class Mark:
|
||||
|
||||
row_mean, row_min, row_max = [], [], []
|
||||
for y in bench.line_vals:
|
||||
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
|
||||
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
|
||||
try:
|
||||
y_mean, y_min, y_max = ret
|
||||
except TypeError:
|
||||
@@ -328,14 +328,14 @@ class Mark:
|
||||
if save_path:
|
||||
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
|
||||
|
||||
def run(self, show_plots=False, print_data=False, save_path=''):
|
||||
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
|
||||
has_single_bench = isinstance(self.benchmarks, Benchmark)
|
||||
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
|
||||
if save_path:
|
||||
html = open(os.path.join(save_path, "results.html"), "w")
|
||||
html.write("<html><body>\n")
|
||||
for bench in benchmarks:
|
||||
self._run(bench, save_path, show_plots, print_data)
|
||||
self._run(bench, save_path, show_plots, print_data, **kwargs)
|
||||
if save_path:
|
||||
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
|
||||
if save_path:
|
||||
|
||||
Reference in New Issue
Block a user