[FRONTEND] Fix benchmark plotting (#2177)

This commit is contained in:
Ethan Pronovost
2023-08-24 20:34:04 -07:00
committed by GitHub
parent f6cdcf1d77
commit 56fee37a0d

View File

@@ -3,6 +3,7 @@ import os
import subprocess
import sys
from contextlib import contextmanager
from typing import Any, Dict, List
from . import language as tl
from ._C.libtriton.triton import runtime
@@ -201,37 +202,41 @@ class Benchmark:
def __init__(
self,
x_names,
x_vals,
line_arg,
line_vals,
line_names,
plot_name,
args,
xlabel='',
ylabel='',
x_log=False,
y_log=False,
x_names: List[str],
x_vals: List[Any],
line_arg: str,
line_vals: List[Any],
line_names: List[str],
plot_name: str,
args: Dict[str, Any],
xlabel: str = '',
ylabel: str = '',
x_log: bool = False,
y_log: bool = False,
color=None,
styles=None,
):
"""
Constructor
Constructor.
x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list
of scalars and there are multiple x_names, all arguments will have the same value.
If x_vals is a list of tuples/lists, each element should have the same length as
x_names.
:param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
:param x_names: Name of the arguments that should appear on the x axis of the plot.
:type x_names: List[str]
:param x_vals: List of values to use for the arguments in :code:`x_names`.
:type x_vals: List[Any]
:param line_arg: Argument name for which different values correspond to different lines in the plot.
:type line_arg: str
:param line_vals: List of values to use for the arguments in :code:`line_arg`.
:type line_vals: List[str]
:type line_vals: List[Any]
:param line_names: Label names for the different lines.
:type line_names: List[str]
:param plot_name: Name of the plot.
:type plot_name: str
:param args: List of arguments to remain fixed throughout the benchmark.
:type args: List[str]
:param args: Dictionary of keyword arguments to remain fixed throughout the benchmark.
:type args: Dict[str, Any]
:param xlabel: Label for the x axis of the plot.
:type xlabel: str, optional
:param ylabel: Label for the y axis of the plot.
@@ -261,7 +266,7 @@ class Mark:
self.fn = fn
self.benchmarks = benchmarks
def _run(self, bench, save_path, show_plots, print_data):
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool):
import os
import matplotlib.pyplot as plt
@@ -269,15 +274,17 @@ class Mark:
y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.line_names]
x_names_str = str(bench.x_names)
df = pd.DataFrame(columns=[x_names_str] + y_mean + y_min + y_max)
x_names = list(bench.x_names)
df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max)
for x in bench.x_vals:
if not isinstance(x, list):
x = [x]
if len(x) == 1:
x = x * len(bench.x_names)
x_str = str(x)
x_args = {x_name: x_in for x_name, x_in in zip(bench.x_names, x)}
# x can be a single value or a sequence of values.
if not isinstance(x, (list, tuple)):
x = [x for _ in x_names]
if len(x) != len(x_names):
raise ValueError(f"Expected {len(x_names)} values, got {x}")
x_args = dict(zip(x_names, x))
row_mean, row_min, row_max = [], [], []
for y in bench.line_vals:
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
@@ -288,23 +295,24 @@ class Mark:
row_mean += [y_mean]
row_min += [y_min]
row_max += [y_max]
df.loc[len(df)] = [x_str] + row_mean + row_min + row_max
df.loc[len(df)] = list(x) + row_mean + row_min + row_max
if bench.plot_name:
plt.figure()
ax = plt.subplot()
x = x_names_str
# Plot first x value on x axis if there are multiple.
first_x = x_names[0]
for i, y in enumerate(bench.line_names):
y_min, y_max = df[y + '-min'], df[y + '-max']
col = bench.styles[i][0] if bench.styles else None
sty = bench.styles[i][1] if bench.styles else None
ax.plot(df[x], df[y], label=y, color=col, ls=sty)
ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
if not y_min.isnull().all() and not y_max.isnull().all():
y_min = y_min.astype(float)
y_max = y_max.astype(float)
ax.fill_between(df[x], y_min, y_max, alpha=0.15, color=col)
ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
ax.legend()
xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
ax.set_xlabel(xlabel)
ax.set_xlabel(bench.xlabel or first_x)
ax.set_ylabel(bench.ylabel)
# ax.set_title(bench.plot_name)
ax.set_xscale("log" if bench.x_log else "linear")
@@ -313,7 +321,7 @@ class Mark:
plt.show()
if save_path:
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
df = df[[x_names_str] + bench.line_names]
df = df[x_names + bench.line_names]
if print_data:
print(bench.plot_name + ':')
print(df)