mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] speed up autotuning of small kernel invocations. (#1701)
Right now, `do_bench` estimates the runtime of the kernel and then uses that to run it a number of times approximately equal to 100ms (by default). However, when actually running the kernel, it also issues a `zero_()` call to clear the L2 cache. For small kernels, the `zero_()` kernel can be slower than the actual kernel we're benchmarking, causing us to badly overshoot our target latency. This has the perverse effect that very small invocations may take much longer to autotune than larger ones. By way of concrete example, before this PR, I tested the wall-clock time for the first call to `triton.ops.matmul(A, B.T)` in a process, on two `(N, N)` matrices in float32. I found that a 4k x 4k x 4k matmul warmed up in about 2.5s, but a 64 x 64 x 64 matmul took over 5 seconds! This PR fixes this issue by including the same call to `zero_()` inside our measurement loop. With this change, I find that the 4kx4kx4k and 64x64x64 matmuls warm up in very similar amounts of time, both around 2.5s. I noticed this because we tend to run tests on very small models in CI just to test code paths without regard to numerics, and found those tests were perversely taking longer than "real" models in some cases. It seems plausible that a better solution would be a pragma to disable autotuning entirely for such tests, but I think this change is a clear improvement as-is.
This commit is contained in:
@@ -40,29 +40,33 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
:type fast_flush: bool
|
||||
"""
|
||||
|
||||
# Estimate the runtime of the function
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for _ in range(5):
|
||||
fn()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
estimate_ms = start_event.elapsed_time(end_event) / 5
|
||||
# compute number of warmup and repeat
|
||||
n_warmup = max(1, int(warmup / estimate_ms))
|
||||
n_repeat = max(1, int(rep / estimate_ms))
|
||||
|
||||
# We maintain a buffer of 256 MB that we clear
|
||||
# before each kernel call to make sure that the L2
|
||||
# doesn't contain any input data before the run
|
||||
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
if fast_flush:
|
||||
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
else:
|
||||
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
|
||||
|
||||
# Estimate the runtime of the function
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for _ in range(5):
|
||||
cache.zero_()
|
||||
fn()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
estimate_ms = start_event.elapsed_time(end_event) / 5
|
||||
|
||||
# compute number of warmup and repeat
|
||||
n_warmup = max(1, int(warmup / estimate_ms))
|
||||
n_repeat = max(1, int(rep / estimate_ms))
|
||||
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
# Warm-up
|
||||
for _ in range(n_warmup):
|
||||
fn()
|
||||
|
||||
Reference in New Issue
Block a user