Mirror of https://github.com/ROCm/ROCm.git, last synced 2026-04-05 03:01:17 -04:00.
[TESTING] better cudagraph-based benchmarking (#2394)
This commit is contained in:
def do_bench_cudagraph(fn, rep=20, grad_to_none=None):
    """Benchmark the runtime of ``fn`` by replaying it inside a CUDA graph.

    Capturing the kernel launches into a graph removes per-launch host
    overhead, so the measurement reflects device time more closely than
    naive event timing around ``fn()``.

    :param fn: zero-argument callable that launches the kernels to benchmark.
    :param rep: target total measurement time in milliseconds; the number of
        unrolled calls per replay is derived from an initial runtime estimate.
    :param grad_to_none: optional iterable of tensors whose ``.grad`` is reset
        between calls so a backward pass inside ``fn`` does not accumulate.
    :raises RuntimeError: if invoked on the default CUDA stream — graph
        capture is not permitted there; callers must use a side stream.
    :return: mean time per call of ``fn`` in milliseconds, as a Python float.
    """
    if torch.cuda.current_stream() == torch.cuda.default_stream():
        raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.")
    # warmup
    fn()
    # step 1 - we estimate the amount of time the kernel call takes
    # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
    #       but it is probably good enough
    if grad_to_none is not None:
        for x in grad_to_none:
            # detach so the graph capture below does not record autograd state
            x.detach_()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        fn()
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    g.replay()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event)
    # compute number of repetitions so one replay lasts roughly `rep` ms
    n_repeat = max(1, int(rep / estimate_ms))
    # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
    #          host overhead
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for i in range(n_repeat):
            # we don't want `fn` to accumulate gradient values
            # if it contains a backward pass. So we clear the
            # provided gradients
            if grad_to_none is not None:
                for x in grad_to_none:
                    x.grad = None
            fn()
    torch.cuda.synchronize()
    # measure time and return
    ret = []
    n_retries = 10
    for i in range(n_retries):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        g.replay()
        end_event.record()
        torch.cuda.synchronize()
        # one replay runs `n_repeat` unrolled calls; normalize to per-call time
        ret += [start_event.elapsed_time(end_event) / n_repeat]
    return torch.mean(torch.tensor(ret)).item()
Reference in New Issue
Block a user