[TESTING] better cudagraph-based benchmarking (#2394)

This commit is contained in:
Philippe Tillet
2023-09-25 21:41:26 -07:00
committed by GitHub
parent 80adbbb87b
commit eea0718445
2 changed files with 62 additions and 63 deletions

View File

@@ -32,8 +32,11 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
"""
if torch.cuda.current_stream() == torch.cuda.default_stream():
raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.")
# record CUDAGraph
# warmup
fn()
# step 1 - we estimate the amount of time the kernel call takes
# NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
# but it is probably good enough
if grad_to_none is not None:
for x in grad_to_none:
x.detach_()
@@ -43,39 +46,35 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
with torch.cuda.graph(g):
fn()
torch.cuda.synchronize()
fn = lambda: g.replay()
# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
fn()
g.replay()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
# compute number of repetition to last `rep` ms
n_repeat = max(1, int(rep / estimate_ms))
# compute the number of repetitions needed to last `rep` ms
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
ret = []
n_retries = 50
for _ in range(n_retries):
# Benchmark
torch.cuda.synchronize()
# step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
# host overhead
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# record time of `fn`
start_event[i].record()
fn()
end_event[i].record()
torch.cuda.synchronize()
# measure time and return
ret = []
n_retries = 10
for i in range(n_retries):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
g.replay()
end_event.record()
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
ret.append(torch.min(times))
ret += [start_event.elapsed_time(end_event) / n_repeat]
return torch.mean(torch.tensor(ret)).item()