[TESTS] Enhance benchmark flexibility (#2239)

User can pass custom arguments to benchmarks. For example, user can pass
`dtype` which will be used to create tensors in a benchmark.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
danny.jang
2023-09-12 04:31:30 +09:00
committed by GitHub
parent 5231d57c71
commit ec4a968d44

View File

@@ -266,7 +266,7 @@ class Mark:
self.fn = fn
self.benchmarks = benchmarks
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool):
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags):
import os
import matplotlib.pyplot as plt
@@ -287,7 +287,7 @@ class Mark:
row_mean, row_min, row_max = [], [], []
for y in bench.line_vals:
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
try:
y_mean, y_min, y_max = ret
except TypeError:
@@ -328,14 +328,14 @@ class Mark:
if save_path:
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
def run(self, show_plots=False, print_data=False, save_path=''):
def run(self, show_plots=False, print_data=False, save_path='', **kwargs):
has_single_bench = isinstance(self.benchmarks, Benchmark)
benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
if save_path:
html = open(os.path.join(save_path, "results.html"), "w")
html.write("<html><body>\n")
for bench in benchmarks:
self._run(bench, save_path, show_plots, print_data)
self._run(bench, save_path, show_plots, print_data, **kwargs)
if save_path:
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
if save_path: