mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Fix benchmark plotting (#2177)
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
||||
import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from . import language as tl
|
||||
from ._C.libtriton.triton import runtime
|
||||
@@ -201,37 +202,41 @@ class Benchmark:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_names,
|
||||
x_vals,
|
||||
line_arg,
|
||||
line_vals,
|
||||
line_names,
|
||||
plot_name,
|
||||
args,
|
||||
xlabel='',
|
||||
ylabel='',
|
||||
x_log=False,
|
||||
y_log=False,
|
||||
x_names: List[str],
|
||||
x_vals: List[Any],
|
||||
line_arg: str,
|
||||
line_vals: List[Any],
|
||||
line_names: List[str],
|
||||
plot_name: str,
|
||||
args: Dict[str, Any],
|
||||
xlabel: str = '',
|
||||
ylabel: str = '',
|
||||
x_log: bool = False,
|
||||
y_log: bool = False,
|
||||
color=None,
|
||||
styles=None,
|
||||
):
|
||||
"""
|
||||
Constructor
|
||||
Constructor.
|
||||
x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list
|
||||
of scalars and there are multiple x_names, all arguments will have the same value.
|
||||
If x_vals is a list of tuples/lists, each element should have the same length as
|
||||
x_names.
|
||||
|
||||
:param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
|
||||
:param x_names: Name of the arguments that should appear on the x axis of the plot.
|
||||
:type x_names: List[str]
|
||||
:param x_vals: List of values to use for the arguments in :code:`x_names`.
|
||||
:type x_vals: List[Any]
|
||||
:param line_arg: Argument name for which different values correspond to different lines in the plot.
|
||||
:type line_arg: str
|
||||
:param line_vals: List of values to use for the arguments in :code:`line_arg`.
|
||||
:type line_vals: List[str]
|
||||
:type line_vals: List[Any]
|
||||
:param line_names: Label names for the different lines.
|
||||
:type line_names: List[str]
|
||||
:param plot_name: Name of the plot.
|
||||
:type plot_name: str
|
||||
:param args: List of arguments to remain fixed throughout the benchmark.
|
||||
:type args: List[str]
|
||||
:param args: Dictionary of keyword arguments to remain fixed throughout the benchmark.
|
||||
:type args: Dict[str, Any]
|
||||
:param xlabel: Label for the x axis of the plot.
|
||||
:type xlabel: str, optional
|
||||
:param ylabel: Label for the y axis of the plot.
|
||||
@@ -261,7 +266,7 @@ class Mark:
|
||||
self.fn = fn
|
||||
self.benchmarks = benchmarks
|
||||
|
||||
def _run(self, bench, save_path, show_plots, print_data):
|
||||
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool):
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -269,15 +274,17 @@ class Mark:
|
||||
y_mean = bench.line_names
|
||||
y_min = [f'{x}-min' for x in bench.line_names]
|
||||
y_max = [f'{x}-max' for x in bench.line_names]
|
||||
x_names_str = str(bench.x_names)
|
||||
df = pd.DataFrame(columns=[x_names_str] + y_mean + y_min + y_max)
|
||||
x_names = list(bench.x_names)
|
||||
df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max)
|
||||
for x in bench.x_vals:
|
||||
if not isinstance(x, list):
|
||||
x = [x]
|
||||
if len(x) == 1:
|
||||
x = x * len(bench.x_names)
|
||||
x_str = str(x)
|
||||
x_args = {x_name: x_in for x_name, x_in in zip(bench.x_names, x)}
|
||||
# x can be a single value or a sequence of values.
|
||||
if not isinstance(x, (list, tuple)):
|
||||
x = [x for _ in x_names]
|
||||
|
||||
if len(x) != len(x_names):
|
||||
raise ValueError(f"Expected {len(x_names)} values, got {x}")
|
||||
x_args = dict(zip(x_names, x))
|
||||
|
||||
row_mean, row_min, row_max = [], [], []
|
||||
for y in bench.line_vals:
|
||||
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
|
||||
@@ -288,23 +295,24 @@ class Mark:
|
||||
row_mean += [y_mean]
|
||||
row_min += [y_min]
|
||||
row_max += [y_max]
|
||||
df.loc[len(df)] = [x_str] + row_mean + row_min + row_max
|
||||
df.loc[len(df)] = list(x) + row_mean + row_min + row_max
|
||||
|
||||
if bench.plot_name:
|
||||
plt.figure()
|
||||
ax = plt.subplot()
|
||||
x = x_names_str
|
||||
# Plot first x value on x axis if there are multiple.
|
||||
first_x = x_names[0]
|
||||
for i, y in enumerate(bench.line_names):
|
||||
y_min, y_max = df[y + '-min'], df[y + '-max']
|
||||
col = bench.styles[i][0] if bench.styles else None
|
||||
sty = bench.styles[i][1] if bench.styles else None
|
||||
ax.plot(df[x], df[y], label=y, color=col, ls=sty)
|
||||
ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
|
||||
if not y_min.isnull().all() and not y_max.isnull().all():
|
||||
y_min = y_min.astype(float)
|
||||
y_max = y_max.astype(float)
|
||||
ax.fill_between(df[x], y_min, y_max, alpha=0.15, color=col)
|
||||
ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
|
||||
ax.legend()
|
||||
xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_xlabel(bench.xlabel or first_x)
|
||||
ax.set_ylabel(bench.ylabel)
|
||||
# ax.set_title(bench.plot_name)
|
||||
ax.set_xscale("log" if bench.x_log else "linear")
|
||||
@@ -313,7 +321,7 @@ class Mark:
|
||||
plt.show()
|
||||
if save_path:
|
||||
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
|
||||
df = df[[x_names_str] + bench.line_names]
|
||||
df = df[x_names + bench.line_names]
|
||||
if print_data:
|
||||
print(bench.plot_name + ':')
|
||||
print(df)
|
||||
|
||||
Reference in New Issue
Block a user