mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Minor edits to HBM bandwidth measurement kernel (#434)
* Change units to GiB/s from GB/s * Run both with and w/o bounds check
This commit is contained in:
@@ -41,14 +41,14 @@ def copy_kernel(
|
||||
# and (2) enqueue the above kernel with appropriate grid/block sizes:
|
||||
|
||||
|
||||
def copy(x: torch.Tensor, wgs=512):
|
||||
def copy(x: torch.Tensor, wgs=512, bounds_check=True):
|
||||
# We need to preallocate the output.
|
||||
output = torch.empty_like(x)
|
||||
assert x.is_cuda
|
||||
vector_size = output.numel()
|
||||
BLOCK_SIZE = 16384
|
||||
grid = (wgs, 1, 1)
|
||||
BOUNDS_CHECK = True
|
||||
BOUNDS_CHECK = bounds_check
|
||||
# Each WG will move these many elements
|
||||
n_elements = triton.cdiv(vector_size, wgs)
|
||||
copy_kernel[grid](
|
||||
@@ -74,7 +74,9 @@ print(
|
||||
|
||||
size = 2 ** 30
|
||||
|
||||
configs = triton.testing.Benchmark(
|
||||
configs = []
|
||||
for bounds_check in [True, False]:
|
||||
configs.append(triton.testing.Benchmark(
|
||||
x_names=['wgs'], # Argument names to use as an x-axis for the plot.
|
||||
x_vals=[
|
||||
(2**i) for i in range (0,12)
|
||||
@@ -84,21 +86,21 @@ configs = triton.testing.Benchmark(
|
||||
line_vals=['triton', 'torch'], # Possible values for `line_arg`.
|
||||
line_names=['Triton', 'Torch'], # Label name for the lines.
|
||||
styles=[('blue', '-'), ('green', '-')], # Line styles.
|
||||
ylabel='GB/s', # Label name for the y-axis.
|
||||
plot_name=f'size={size}', # Name for the plot. Used also as a file name for saving the plot.
|
||||
args={'size':size}, # Values for function arguments not in `x_names` and `y_name`.
|
||||
)
|
||||
ylabel='GiB/s', # Label name for the y-axis.
|
||||
plot_name=f'size={size}-bounds_check={bounds_check}', # Name for the plot. Used also as a file name for saving the plot.
|
||||
args={'size':size, 'bounds_check':bounds_check}, # Values for function arguments not in `x_names` and `y_name`.
|
||||
))
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(size, provider, wgs):
|
||||
def benchmark(size, provider, wgs, bounds_check):
|
||||
x = torch.rand(size, device='cuda', dtype=torch.float32)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.clone(x), quantiles=quantiles)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: copy(x, wgs), quantiles=quantiles)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: copy(x, wgs, bounds_check), quantiles=quantiles)
|
||||
# 8 because 4 bytes from load, 4 from store.
|
||||
gbps = lambda ms: 8 * size / ms * 1e-6
|
||||
gbps = lambda ms: 8 * size / ms * 1e3 / 1024**3
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user