[FRONTEND] Removed torch dependency and cleaned up testing (#1394)

`assert triton.testing.allclose` -> `torch.testing.assert_allclose`
`triton.testing.assert_almost_equal` -> `torch.testing.assert_allclose`
This commit is contained in:
Philippe Tillet
2023-03-23 22:37:21 -07:00
committed by GitHub
parent ff1d0377e0
commit fc7c0b0e43
14 changed files with 152 additions and 188 deletions

View File

@@ -6,6 +6,7 @@ import torch
import triton
import triton.language as tl
import triton.ops
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]]
@@ -96,7 +97,7 @@ def test_matmul(M, N, K, dtype_str):
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
#######################
@@ -152,7 +153,7 @@ def test_elementwise(N):
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
#######################
# Flash-Attention
@@ -200,4 +201,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)

View File

@@ -783,7 +783,7 @@ def test_atomic_cas():
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
ref = torch.full((128,), 64.0)
serialized_add[(64,)](data, Lock)
triton.testing.assert_almost_equal(data, ref)
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
# ---------------
@@ -1214,8 +1214,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# numpy result
z_ref = x.transpose(*perm)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
@@ -1477,7 +1477,7 @@ def test_arange(start, device='cuda'):
tl.store(z + off, val)
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
triton.testing.assert_almost_equal(z_tri, z_ref)
np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref))
# ---------------
# test load
@@ -1513,7 +1513,8 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
kernel[(1,)](input, output, input_size, output_size)
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# print((output - reference_out).nonzero())
torch.testing.assert_allclose(output, reference_out)
# Testing masked loads with an intermediate copy to shared memory run.
@@ -1544,15 +1545,15 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
# Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K)
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N)
# Without a dot product the memory doesn't get promoted to shared.
o = tl.dot(x, w, out_dtype=tl.float32)
# Store output
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N)
pgm = _kernel[(1,)](in1, in2, out,
in1.stride()[0],
@@ -1564,7 +1565,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
M=M, N=N, K=K)
reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out)
torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
@@ -1607,7 +1608,7 @@ def test_vectorization(N):
assert "ld.global.v4.b32" in ptx
else:
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
# np.testing.assert_allclose(dst, src[:N])
@pytest.mark.parametrize("has_hints", [False, True])

View File

@@ -2,6 +2,34 @@ import pytest
import torch
import triton
import triton.ops
def sparsify_tensor(x, mask, block):
    """Gather the nonzero tiles of a dense tensor into a compact layout.

    x: dense tensor of shape (batch, heads, rows*block, cols*block).
    mask: integer layout of shape (heads, rows, cols); nonzero entries mark
        which (block x block) tiles are kept.
    Returns a tensor of shape (batch, nnz, block, block) holding the kept
    tiles in mask.nonzero() order.
    """
    out = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    coords = zip(*mask.nonzero(as_tuple=True))
    for slot, (h, r, c) in enumerate(coords):
        rows = slice(r * block, (r + 1) * block)
        cols = slice(c * block, (c + 1) * block)
        out[:, slot, :, :] = x[:, h, rows, cols]
    return out
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
    """Build a pair of identical leaf tensors (reference, triton) for tests.

    When *data* is None a random fp32 tensor of *shape* is drawn on *device*.
    The data is scaled by *alpha*, shifted by *beta*, rounded through fp16,
    cast to *dtype*, and optionally transposed.  Both returned tensors are
    detached leaves with requires_grad=True; the second is an independent
    clone of the first.
    """
    base = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device) if data is None else data
    # round-trip through fp16 so reference and triton kernels see the
    # same (reduced-precision) values
    ref = (base * alpha + beta).half().to(dtype)
    if trans:
        ref = ref.t().requires_grad_()
    ref = ref.detach().requires_grad_()
    tri = ref.clone().detach().requires_grad_()
    return ref, tri
def mask_tensor(x, mask, block, value=0):
    """Return a copy of *x* with every (block x block) tile whose layout
    entry in *mask* is zero overwritten by *value*.  *x* is not modified."""
    out = x.clone()
    zero_coords = (mask == 0).nonzero(as_tuple=True)
    for h, r, c in zip(*zero_coords):
        rows = slice(r * block, (r + 1) * block)
        cols = slice(c * block, (c + 1) * block)
        out[:, h, rows, cols] = value
    return out
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -16,8 +44,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
is_sdd = MODE == "sdd"
is_dsd = MODE == "dsd"
is_dds = MODE == "dds"
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: mask_tensor(x, layout, BLOCK)
# create inputs
# create op
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
@@ -32,9 +60,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
@@ -59,9 +87,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
da_tri = a_tri.grad
db_tri = b_tri.grad
# compare
triton.testing.assert_almost_equal(c_ref, c_tri)
triton.testing.assert_almost_equal(da_ref, da_tri)
triton.testing.assert_almost_equal(db_ref, db_tri)
torch.testing.assert_allclose(c_ref, c_tri)
torch.testing.assert_allclose(da_ref, da_tri)
torch.testing.assert_allclose(db_ref, db_tri)
configs = [
@@ -88,10 +116,10 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
layout[1, :, 1] = 0
# initialize data
a_shape = (Z, H, M, N)
a_ref, a_tri = triton.testing.make_pair(a_shape)
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
a_ref, a_tri = make_pair(a_shape)
dout_ref, dout_tri = make_pair(a_shape)
# compute [torch]
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref.retain_grad()
at_mask = torch.ones((M, N), device="cuda")
if is_causal:
@@ -100,19 +128,19 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
a_ref[M == 0] = float("-inf")
out_ref = torch.softmax(a_ref * scale, -1)
out_ref.backward(dout_ref)
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
out_ref = sparsify_tensor(out_ref, layout, BLOCK)
da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK)
# compute [triton]
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
a_tri = sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
dout_tri = sparsify_tensor(dout_tri, layout, BLOCK)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
# compare
triton.testing.assert_almost_equal(out_tri, out_ref)
triton.testing.assert_almost_equal(da_tri, da_ref)
torch.testing.assert_allclose(out_tri, out_ref)
torch.testing.assert_allclose(da_tri, da_ref)
@pytest.mark.parametrize("block", [16, 32, 64])
@@ -168,9 +196,9 @@ def test_attention_fwd_bwd(
# comparison
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
triton.testing.assert_almost_equal(loss, torch_loss)
torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0)
for g1, g2 in zip(grads, torch_grads):
triton.testing.assert_almost_equal(g1, g2)
torch.testing.assert_allclose(g1, g2)
@pytest.mark.parametrize("block", [16, 32, 64])

View File

@@ -2,6 +2,7 @@ import pytest
import torch
import triton
import triton.ops
@pytest.mark.parametrize("M, N, dtype, mode",
@@ -24,7 +25,7 @@ def test_op(M, N, dtype, mode):
tt_y = triton.ops.cross_entropy(x, idx)
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
if mode == 'forward':
triton.testing.assert_almost_equal(th_y, tt_y)
torch.testing.assert_allclose(th_y, tt_y)
# backward pass
elif mode == 'backward':
dy = torch.randn_like(tt_y)
@@ -35,4 +36,4 @@ def test_op(M, N, dtype, mode):
x.grad.zero_()
th_y.backward(dy)
th_dx = x.grad.clone()
triton.testing.assert_almost_equal(th_dx, tt_dx)
torch.testing.assert_allclose(th_dx, tt_dx)

View File

@@ -2,6 +2,7 @@ import pytest
import torch
import triton
import triton.ops
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
@@ -38,8 +39,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype):
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
decimal = 1 if dtype == torch.bfloat16 else 2
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)

View File

@@ -52,4 +52,4 @@ def test_normalization_with_remat():
arg8_1 = torch.rand(64, device="cuda")
arg9_1 = torch.rand(64, device="cuda")
triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
triton.testing.allclose(buf16.mean(), buf14.mean().item(), atol=1e-7, rtol=0)
torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)

View File

@@ -4,6 +4,7 @@ import pytest
import torch
import triton
import triton.ops
@pytest.mark.parametrize(
@@ -95,4 +96,4 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
# run test
th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c)
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)

View File

@@ -28,7 +28,6 @@ from .runtime.jit import jit
from .compiler import compile, CompilationError
from . import language
from . import testing
from . import ops
__all__ = [
"autotune",

View File

@@ -19,7 +19,6 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
import setuptools
import torch
from filelock import FileLock
import triton
@@ -1669,8 +1668,8 @@ def _get_jsonable_constants(constants):
def compile(fn, **kwargs):
capability = kwargs.get("cc", None)
if capability is None:
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability(device)
device = triton.runtime.jit.get_current_device()
capability = triton.runtime.jit.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
# we get the kernel, i.e. the first function generated in the module
# if fn is not a JITFunction, then it
@@ -1805,7 +1804,7 @@ class CompiledKernel:
def _init_handles(self):
if self.cu_module is not None:
return
device = torch.cuda.current_device()
device = triton.runtime.jit.get_current_device()
global cuda_utils
init_cuda_utils()
max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
@@ -1827,7 +1826,7 @@ class CompiledKernel:
def runner(*args, stream=None):
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
stream = triton.runtime.jit.get_cuda_stream()
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
return runner

View File

@@ -10,15 +10,34 @@ import textwrap
from collections import defaultdict, namedtuple
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
import torch
import triton
from triton.utils import MockTensor
try:
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
def get_cuda_stream(idx=None):
    """Return the raw CUDA stream handle for device *idx* (the current
    device when *idx* is None).

    Prefers the private fast-path accessor in torch._C and falls back to
    the public torch.cuda API when it is unavailable.  torch is imported
    lazily so importing this module does not require torch.
    """
    if idx is None:
        idx = get_current_device()
    try:
        from torch._C import _cuda_getCurrentRawStream
        return _cuda_getCurrentRawStream(idx)
    except ImportError:
        import torch
        return torch.cuda.current_stream(idx).cuda_stream
def get_current_device():
    """Return the index of the active CUDA device (lazy torch import)."""
    import torch as _torch
    return _torch.cuda.current_device()
def set_current_device(idx):
    """Make CUDA device *idx* the active device (lazy torch import)."""
    import torch as _torch
    _torch.cuda.set_device(idx)
def get_device_capability(idx):
    """Return the (major, minor) compute capability of CUDA device *idx*
    (lazy torch import)."""
    import torch as _torch
    return _torch.cuda.get_device_capability(idx)
T = TypeVar('T')
@@ -160,34 +179,31 @@ class JITFunction(KernelInterface[T]):
@staticmethod
def _type_of(key):
if isinstance(key, (torch.dtype, triton.language.dtype)):
ty = {
torch.bool: 'i1',
torch.float16: 'fp16',
torch.bfloat16: 'bf16',
torch.float32: 'fp32',
torch.float64: 'fp64',
torch.uint8: 'u8',
torch.int8: 'i8',
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
triton.language.uint8: 'u8',
triton.language.uint16: 'u16',
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
triton.language.float8e5: 'fp8e5',
triton.language.float8e4: 'fp8e4',
triton.language.float16: 'fp16',
triton.language.bfloat16: 'bf16',
triton.language.float32: 'fp32',
}[key]
return f'*{ty}'
# None are nullptr -- implicitly converted to *i8
if key is None:
return '*i8'
assert isinstance(key, str)
return key
dtype_str = str(key).split(".")[-1]
tys = {
"bool": "i1",
"float8e5": "fp8e5",
"float8e4": "fp8e4",
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"float64": "fp64",
"int8": "i8",
"int16": "i16",
"int32": "i32",
"int64": "i64",
"uint8": "u8",
"uint16": "u16",
"uint32": "u32",
"uint64": "u64",
}
# reinterpret can create triton type
for v in list(tys.values()):
tys[v] = v
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
def _make_signature(self, sig_key):
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
@@ -252,8 +268,8 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
device = torch.cuda.current_device()
torch.cuda.set_device(device)
device = get_current_device()
set_current_device(device)
if stream is None and not warmup:
stream = get_cuda_stream(device)
try:
@@ -286,7 +302,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
"""
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
"cache": self.cache, "triton": triton, "torch": torch}
"cache": self.cache, "triton": triton,
"get_current_device": get_current_device,
"set_current_device": set_current_device}
exec(src, scope)
return scope[self.fn.__name__]
@@ -397,8 +415,9 @@ def jit(
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: When a jit'd function is called, arguments are
implicitly converted to pointers if they have a :code:`.data_ptr()` method
and a `.dtype` attribute.
:note: This function will be compiled and run on the GPU. It will only have access to:
@@ -449,8 +468,8 @@ def reinterpret(tensor, dtype):
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
elif hasattr(tensor, "data_ptr"):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
raise TypeError(f'Cannot reinterpret a {type(tensor)}. Does not contain `data_ptr` method.')

View File

@@ -4,21 +4,10 @@ import subprocess
import sys
from contextlib import contextmanager
import torch
import triton
import triton._C.libtriton.triton as _triton
from .compiler import OutOfResources
try:
import triton._C.libtriton.cutlass as _cutlass
has_cutlass = True
except ImportError:
_cutlass = None
has_cutlass = False
# TODO: move to separate module
import triton
def catch_oor(kernel, pytest_handle=None):
try:
@@ -30,86 +19,6 @@ def catch_oor(kernel, pytest_handle=None):
return res
def sparsify_tensor(x, mask, block):
    """Pack the tiles selected by *mask* into a dense (batch, nnz, block,
    block) tensor, in mask.nonzero() order."""
    out = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for slot, (h, r, c) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        out[:, slot, :, :] = x[:, h, r * block:(r + 1) * block, c * block:(c + 1) * block]
    return out
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
    """Create two identical leaf tensors for reference/triton comparison.

    *data* (or a fresh random tensor of *shape*) is scaled by *alpha*,
    shifted by *beta*, quantized through fp16, cast to *dtype* and
    optionally transposed; both results are returned as detached leaves
    with requires_grad=True.
    """
    src = data if data is not None else torch.randn(
        shape, dtype=torch.float32, requires_grad=True, device=device)
    ref = (src * alpha + beta).half().to(dtype)
    if trans:
        ref = ref.t().requires_grad_()
    ref = ref.detach().requires_grad_()
    tri = ref.clone().detach().requires_grad_()
    return ref, tri
def cutlass_matmul(a, b):
    """Compute a @ b through the optional CUTLASS native extension.

    Raises RuntimeError when the extension was not built.  The output is
    allocated with strides (1, M), i.e. column-major, and the kernel runs
    on the current CUDA stream of a's device.
    """
    if _cutlass is None:
        raise RuntimeError("Cannot find cutlass library")
    rows, cols = a.shape[0], b.shape[1]
    inner_a, inner_b = a.shape[1], b.shape[0]
    assert inner_a == inner_b
    assert a.dtype == b.dtype
    assert a.device == b.device
    # column-major output buffer
    c = torch.empty_strided((rows, cols), (1, rows), dtype=a.dtype, device=a.device)
    # e.g. "torch.float16" -> "float16"
    dtype = str(a.dtype).split('.')[-1]
    _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
                    rows, cols, inner_a,
                    a.stride(0), a.stride(1),
                    b.stride(0), b.stride(1),
                    c.stride(0), c.stride(1),
                    dtype, dtype, dtype,
                    a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
    return c
def mask_tensor(x, mask, block, value=0):
    """Copy *x* and fill every (block x block) tile whose *mask* entry is
    zero with *value*; the input tensor is left untouched."""
    out = x.clone()
    for h, r, c in zip(*(mask == 0).nonzero(as_tuple=True)):
        out[:, h, r * block:(r + 1) * block, c * block:(c + 1) * block] = value
    return out
def assert_almost_equal(x, y, decimal=2, err_msg=''):
    """Assert x and y agree to *decimal* places using numpy.testing.

    torch tensors are detached and moved to CPU numpy first; bfloat16 is
    upcast to float32 because numpy has no bfloat16 dtype.
    """
    import numpy.testing as npt

    def _to_numpy(t):
        if isinstance(t, torch.Tensor):
            if t.dtype == torch.bfloat16:
                t = t.float()
            t = t.cpu().detach().numpy()
        return t

    npt.assert_array_almost_equal(_to_numpy(x), _to_numpy(y),
                                  err_msg=err_msg, decimal=decimal)
def allclose(x, y, atol=0, rtol=1e-2):
    """Tolerant elementwise comparison of two values.

    Non-tensor inputs are wrapped with ``torch.tensor``.  The dtypes and
    shapes must match exactly.  Boolean tensors are compared bit-for-bit;
    signed-integer tensors are compared exactly (tolerances forced to 0);
    everything else goes through ``torch.allclose`` with the given
    tolerances.

    :raises RuntimeError: on dtype or shape mismatch.
    :returns: truthy when the inputs are close enough.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y)
    if x.dtype != y.dtype:
        # bug fix: the original message interpolated x.dtype twice and
        # never showed the mismatching dtype of y
        raise RuntimeError(f'{x.dtype} did not match with {y.dtype}')
    if x.shape != y.shape:
        raise RuntimeError(f'{x.shape} did not match with {y.shape}')
    if x.dtype == torch.bool:
        # exact comparison: XOR is nonzero wherever the tensors differ
        return torch.sum(x ^ y) == 0
    if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        rtol = 0
        atol = 0
    return torch.allclose(x, y, rtol=rtol, atol=atol)
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -122,6 +31,7 @@ def nvsmi(attrs):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
percentiles=(0.5, 0.2, 0.8),
record_clocks=False, fast_flush=False):
import torch
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -335,7 +245,7 @@ def perf_report(benchmarks):
def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s '''
# assert backend == CUDA
import torch
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
@@ -346,7 +256,8 @@ def get_dram_gbps(backend=None, device=None):
return bw_gbps
def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None):
import torch
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
@@ -447,7 +358,8 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
def get_max_simd_tflops(dtype, backend=None, device=None):
import torch
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import torch
def cdiv(x, y):
return (x + y - 1) // y
@@ -26,7 +24,8 @@ class MockTensor:
"""
@staticmethod
def wrap_dtype(arg):
if isinstance(arg, torch.dtype):
if arg.__class__.__name__ == "dtype" and\
arg.__module__ == "torch":
return MockTensor(arg)
return arg
@@ -60,7 +59,7 @@ def reinterpret(tensor, dtype):
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
elif hasattr(tensor, "data_ptr"):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else: