[FRONTEND] speed up autotuning of small kernel invocations. (#1701)

Right now, `do_bench` estimates the runtime of the kernel and then uses
that estimate to pick a repeat count whose total runtime is
approximately 100ms (by default).
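In sketch form, the sizing step works like this (a simplified stand-in for `do_bench`'s logic, not the real function: wall-clock timing replaces the CUDA-event timing, and `pick_n_repeat` is a hypothetical name):

```python
import time

def pick_n_repeat(fn, rep=100):
    """Pick how many times to run `fn` so the measurement takes ~`rep` ms.

    Simplified sketch of do_bench's sizing step; time.perf_counter
    stands in for the CUDA events used in the real function.
    """
    start = time.perf_counter()
    for _ in range(5):
        fn()
    estimate_ms = (time.perf_counter() - start) * 1000 / 5
    # Run enough repetitions to fill the ~100 ms measurement budget.
    return max(1, int(rep / estimate_ms))
```

The faster the estimate says the kernel is, the more repetitions get scheduled, which is exactly where the problem below comes from.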

However, when actually running the kernel, it also issues a `zero_()`
call to clear the L2 cache. For small kernels, the `zero_()` kernel can
be slower than the actual kernel we're benchmarking, causing us to badly
overshoot our target latency.
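Some hypothetical round numbers (mine, for illustration; not measurements from this PR) make the overshoot concrete. Say the kernel takes 10µs but zeroing the 256MB flush buffer takes 1ms:

```python
# Hypothetical timings, for illustration only.
kernel_ms = 0.01  # tiny kernel under benchmark
flush_ms = 1.0    # the zero_() call that clears the L2-flush buffer
rep = 100         # target measurement budget, in ms

# The estimation loop times only fn(), so it sees just the kernel...
n_repeat = max(1, int(rep / kernel_ms))        # 10000 iterations
# ...but every measured iteration also pays for the flush.
actual_ms = n_repeat * (kernel_ms + flush_ms)  # ~10 s instead of ~100 ms
```

A roughly 100x overshoot of the 100ms budget, driven entirely by the cost the estimate never saw.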

This has the perverse effect that very small invocations may take much
longer to autotune than larger ones. By way of concrete example, before
this PR, I tested the wall-clock time for the first call to
`triton.ops.matmul(A, B.T)` in a process, on two `(N, N)` matrices in
float32. I found that a 4k x 4k x 4k matmul warmed up in about 2.5s, but
a 64 x 64 x 64 matmul took over 5 seconds!

This PR fixes this issue by including the same call to `zero_()` inside
our measurement loop.
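With the flush inside the timed estimation loop, the estimate covers the true per-iteration cost, and the repeat count shrinks to fit the budget. Using hypothetical round numbers (a 10µs kernel plus a 1ms `zero_()` flush; not measurements from this PR):

```python
# Hypothetical timings, for illustration only.
kernel_ms = 0.01  # tiny kernel under benchmark
flush_ms = 1.0    # the zero_() flush, now included in the timed loop

rep = 100  # target measurement budget, in ms
# The estimate now reflects zero_() + fn() together.
estimate_ms = flush_ms + kernel_ms
n_repeat = max(1, int(rep / estimate_ms))  # 99 iterations
total_ms = n_repeat * estimate_ms          # ~100 ms, on budget
```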

With this change, I find that the 4kx4kx4k and 64x64x64 matmuls warm up
in very similar amounts of time, both around 2.5s.

I noticed this because we tend to run tests on very small models in CI
just to test code paths without regard to numerics, and found those
tests were perversely taking longer than "real" models in some cases. It
seems plausible that a better solution would be a pragma to disable
autotuning entirely for such tests, but I think this change is a clear
improvement as-is.
Commit 0274446b3a (parent 0341953466), authored by Nelson Elhage and
committed via GitHub on 2023-05-26 09:42:31 -07:00.


```diff
@@ -40,29 +40,33 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
     :type fast_flush: bool
     """
-    # Estimate the runtime of the function
     fn()
     torch.cuda.synchronize()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    start_event.record()
-    for _ in range(5):
-        fn()
-    end_event.record()
-    torch.cuda.synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
-    # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
     if fast_flush:
         cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
     else:
         cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
+    # Estimate the runtime of the function
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    torch.cuda.synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup / estimate_ms))
+    n_repeat = max(1, int(rep / estimate_ms))
+    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
     # Warm-up
     for _ in range(n_warmup):
         fn()
```