mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] force quantile tensors to be float; prevents accidents (#1741)
In particular, sometimes this was failing with: ``` RuntimeError: quantile() input tensor must be either float or double dtype ``` Fixes https://github.com/pytorch/pytorch/issues/103054 Signed-off-by: Edward Z. Yang <ezyang@meta.com> --------- Signed-off-by: Edward Z. Yang <ezyang@meta.com>
This commit is contained in:
@@ -86,9 +86,9 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
end_event[i].record()
|
||||
# Record clocks
|
||||
torch.cuda.synchronize()
|
||||
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
|
||||
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
|
||||
if quantiles is not None:
|
||||
ret = torch.quantile(times, torch.tensor(quantiles)).tolist()
|
||||
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
|
||||
if len(ret) == 1:
|
||||
ret = ret[0]
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user