mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTING] Now using numpy instead of pytorch in triton.assert_close
More memory-efficient than pytorch
This commit is contained in:
@@ -90,31 +90,37 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
|
||||
|
||||
def assert_close(x, y, atol=None, rtol=None, err_msg=''):
    """Assert that ``x`` and ``y`` are element-wise close.

    The comparison is carried out with NumPy rather than PyTorch: it seems
    more memory efficient, and PyTorch tends to run out of memory on large
    tensors.

    Args:
        x: Actual value. A ``torch.Tensor``, or anything ``np.asarray``
            accepts (NumPy array, Python scalar, nested list).
        y: Expected value, same kinds as ``x``.
        atol: Absolute tolerance. Either a number or a callable that takes
            the dtype of ``x`` and returns a number. Defaults to ``1e-2``.
        rtol: Relative tolerance. Either a number or a callable that takes
            the dtype of ``x`` and returns a number. Defaults to ``0.``.
        err_msg: Text included in the failure message.

    Raises:
        AssertionError: If the values differ by more than
            ``atol + rtol * abs(y)``.
    """
    import numpy as np
    import torch

    def _resolve_tolerance(tol, default):
        # A tolerance may be omitted, given as a number, or given as a hook
        # that is called with the dtype of the (not yet converted) input.
        if tol is None:
            return default
        if callable(tol):
            dtype = x.dtype if hasattr(x, 'dtype') else torch.as_tensor(x).dtype
            return tol(dtype)
        return tol

    def _to_numpy(value):
        # Bring any supported input onto the host as a NumPy array.
        if isinstance(value, torch.Tensor):
            # NumPy has no native bfloat16, so upcast before converting.
            if value.dtype == torch.bfloat16:
                value = value.float()
            return value.cpu().detach().numpy()
        return np.asarray(value)

    # Tolerances are resolved before conversion so hooks see the input dtype.
    atol = _resolve_tolerance(atol, 1e-2)
    rtol = _resolve_tolerance(rtol, 0.)

    x = _to_numpy(x)
    y = _to_numpy(y)

    if x.size > 1 or y.size > 1:
        # Multi-element path: NumPy reports the mismatching elements itself.
        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True, err_msg=err_msg)
        return
    if not np.allclose(x, y, atol=atol, rtol=rtol):
        raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user