Minor edits to HBM bandwidth measurement kernel (#434)

* Change units to GiB/s from GB/s

* Run both with and w/o bounds check
This commit is contained in:
Vinayak Gokhale
2023-12-21 06:14:31 -06:00
committed by GitHub
parent 16281f02f4
commit 0248bdb29d

View File

@@ -41,14 +41,14 @@ def copy_kernel(
# and (2) enqueue the above kernel with appropriate grid/block sizes:
def copy(x: torch.Tensor, wgs=512):
def copy(x: torch.Tensor, wgs=512, bounds_check=True):
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cuda
vector_size = output.numel()
BLOCK_SIZE = 16384
grid = (wgs, 1, 1)
BOUNDS_CHECK = True
BOUNDS_CHECK = bounds_check
# Each WG will move these many elements
n_elements = triton.cdiv(vector_size, wgs)
copy_kernel[grid](
@@ -74,7 +74,9 @@ print(
size = 2 ** 30
configs = triton.testing.Benchmark(
configs = []
for bounds_check in [True, False]:
configs.append(triton.testing.Benchmark(
x_names=['wgs'], # Argument names to use as an x-axis for the plot.
x_vals=[
(2**i) for i in range (0,12)
@@ -84,21 +86,21 @@ configs = triton.testing.Benchmark(
line_vals=['triton', 'torch'], # Possible values for `line_arg`.
line_names=['Triton', 'Torch'], # Label name for the lines.
styles=[('blue', '-'), ('green', '-')], # Line styles.
ylabel='GB/s', # Label name for the y-axis.
plot_name=f'size={size}', # Name for the plot. Used also as a file name for saving the plot.
args={'size':size}, # Values for function arguments not in `x_names` and `y_name`.
)
ylabel='GiB/s', # Label name for the y-axis.
plot_name=f'size={size}-bounds_check={bounds_check}', # Name for the plot. Used also as a file name for saving the plot.
args={'size':size, 'bounds_check':bounds_check}, # Values for function arguments not in `x_names` and `y_name`.
))
@triton.testing.perf_report(configs)
def benchmark(size, provider, wgs):
def benchmark(size, provider, wgs, bounds_check):
x = torch.rand(size, device='cuda', dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.clone(x), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: copy(x, wgs), quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: copy(x, wgs, bounds_check), quantiles=quantiles)
# 8 because 4 bytes from load, 4 from store.
gbps = lambda ms: 8 * size / ms * 1e-6
gbps = lambda ms: 8 * size / ms * 1e3 / 1024**3
return gbps(ms), gbps(max_ms), gbps(min_ms)