[TESTING] Now using numpy instead of pytorch in triton.assert_close

This is more memory-efficient than the PyTorch-based comparison.
This commit is contained in:
Phil Tillet
2023-04-04 23:56:23 -07:00
parent 577cafff0a
commit 4c1d001ae4

View File

@@ -90,31 +90,37 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
def assert_close(x, y, atol=None, rtol=None, err_msg=''):
    """Assert that ``x`` and ``y`` are element-wise close within tolerances.

    ``atol`` and ``rtol`` may be plain numbers or callables taking the
    (torch) dtype of ``x`` and returning a number, so tolerances can be
    made dtype-dependent by the caller.

    The actual comparison is done with numpy rather than pytorch: it is
    more memory efficient, and pytorch tends to OOM on large tensors.

    :param x: value, sequence, or ``torch.Tensor`` to compare.
    :param y: value, sequence, or ``torch.Tensor`` to compare against.
    :param atol: absolute tolerance (number or callable of dtype); defaults to 1e-2.
    :param rtol: relative tolerance (number or callable of dtype); defaults to 0.
    :param err_msg: prefix for the error message raised on mismatch.
    :raises AssertionError: if ``x`` and ``y`` are not close.
    """
    import numpy as np
    import torch

    # absolute tolerance hook: default is a flat 1e-2 regardless of dtype
    def default_atol(dtype):
        return 1e-2

    # relative tolerance hook: default disables relative comparison
    def default_rtol(dtype):
        return 0.

    if atol is None:
        atol = default_atol
    if rtol is None:
        rtol = default_rtol
    # canonicalize inputs to tensors so `.dtype` is always available
    # for the tolerance hooks below
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y)
    atol = atol(x.dtype) if callable(atol) else atol
    rtol = rtol(x.dtype) if callable(rtol) else rtol
    # we use numpy instead of pytorch
    # as it seems more memory efficient
    # pytorch tends to oom on large tensors
    if x.dtype == torch.bfloat16:
        # numpy has no bfloat16; upcast before conversion
        x = x.float()
    x = x.cpu().detach().numpy()
    if y.dtype == torch.bfloat16:
        y = y.float()
    y = y.cpu().detach().numpy()
    # multi-element comparison: delegate to numpy's testing helper,
    # which produces a detailed mismatch report
    if x.size > 1 or y.size > 1:
        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
        return
    # scalar comparison: raise with the caller-supplied message prefix
    if not np.allclose(x, y, atol=atol, rtol=rtol):
        raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')