mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTS] make performance regression testing less strict (#1231)
This commit is contained in:
@@ -93,7 +93,11 @@ def assert_almost_equal(x, y, decimal=2, err_msg=''):
|
||||
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
|
||||
|
||||
|
||||
def allclose(x, y, tol=1e-2):
|
||||
def allclose(x, y, atol=0, rtol=1e-2):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.tensor(x)
|
||||
if not isinstance(y, torch.Tensor):
|
||||
y = torch.tensor(y)
|
||||
if x.dtype != y.dtype:
|
||||
raise RuntimeError(f'{x.dtype} did not match with {x.dtype}')
|
||||
if x.shape != y.shape:
|
||||
@@ -101,12 +105,11 @@ def allclose(x, y, tol=1e-2):
|
||||
if x.dtype == torch.bool:
|
||||
return torch.sum(x ^ y) == 0
|
||||
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
tol = 0
|
||||
rtol = 0
|
||||
diff = abs(x - y)
|
||||
x_max = torch.max(x)
|
||||
y_max = torch.max(y)
|
||||
err = torch.max(diff) / torch.max(x_max, y_max)
|
||||
return err <= tol
|
||||
return torch.max(diff) <= atol + rtol * torch.max(x_max, y_max)
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
|
||||
Reference in New Issue
Block a user