mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Removed torch dependency and cleaned up testing (#1394)
`assert triton.testing.allclose` -> `torch.testing.assert_allclose` `triton.testing.assert_almost_equal` -> `torch.testing.assert_allclose`
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.ops
|
||||
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
|
||||
|
||||
DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]]
|
||||
@@ -96,7 +97,7 @@ def test_matmul(M, N, K, dtype_str):
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
|
||||
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
|
||||
#######################
|
||||
@@ -152,7 +153,7 @@ def test_elementwise(N):
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
|
||||
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
#######################
|
||||
# Flash-Attention
|
||||
@@ -200,4 +201,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
|
||||
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
@@ -783,7 +783,7 @@ def test_atomic_cas():
|
||||
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
ref = torch.full((128,), 64.0)
|
||||
serialized_add[(64,)](data, Lock)
|
||||
triton.testing.assert_almost_equal(data, ref)
|
||||
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
|
||||
|
||||
|
||||
# ---------------
|
||||
@@ -1214,8 +1214,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# numpy result
|
||||
z_ref = x.transpose(*perm)
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
|
||||
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
|
||||
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)
|
||||
# parse ptx to make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
@@ -1477,7 +1477,7 @@ def test_arange(start, device='cuda'):
|
||||
tl.store(z + off, val)
|
||||
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
|
||||
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref))
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
@@ -1513,7 +1513,8 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
||||
kernel[(1,)](input, output, input_size, output_size)
|
||||
|
||||
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
|
||||
triton.testing.allclose(output, reference_out)
|
||||
# print((output - reference_out).nonzero())
|
||||
torch.testing.assert_allclose(output, reference_out)
|
||||
|
||||
# Testing masked loads with an intermate copy to shared memory run.
|
||||
|
||||
@@ -1544,15 +1545,15 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
|
||||
|
||||
# Load inputs.
|
||||
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
|
||||
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
|
||||
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K)
|
||||
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N)
|
||||
|
||||
# Without a dot product the memory doesn't get promoted to shared.
|
||||
o = tl.dot(x, w, out_dtype=tl.float32)
|
||||
|
||||
# Store output
|
||||
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
|
||||
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
|
||||
tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N)
|
||||
|
||||
pgm = _kernel[(1,)](in1, in2, out,
|
||||
in1.stride()[0],
|
||||
@@ -1564,7 +1565,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
M=M, N=N, K=K)
|
||||
|
||||
reference_out = torch.matmul(in1, in2)
|
||||
triton.testing.allclose(out, reference_out)
|
||||
torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
|
||||
@@ -1607,7 +1608,7 @@ def test_vectorization(N):
|
||||
assert "ld.global.v4.b32" in ptx
|
||||
else:
|
||||
assert "ld.global.b32" in ptx
|
||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||
# np.testing.assert_allclose(dst, src[:N])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_hints", [False, True])
|
||||
|
||||
@@ -2,6 +2,34 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
|
||||
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
|
||||
return ret
|
||||
|
||||
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
|
||||
if data is None:
|
||||
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
|
||||
ref_ret = data
|
||||
ref_ret = ref_ret * alpha + beta
|
||||
ref_ret = ref_ret.half().to(dtype)
|
||||
if trans:
|
||||
ref_ret = ref_ret.t().requires_grad_()
|
||||
ref_ret = ref_ret.detach().requires_grad_()
|
||||
tri_ret = ref_ret.clone().detach().requires_grad_()
|
||||
return ref_ret, tri_ret
|
||||
|
||||
|
||||
def mask_tensor(x, mask, block, value=0):
|
||||
ret = x.clone()
|
||||
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
||||
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
|
||||
return ret
|
||||
|
||||
|
||||
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
|
||||
@@ -16,8 +44,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
is_sdd = MODE == "sdd"
|
||||
is_dsd = MODE == "dsd"
|
||||
is_dds = MODE == "dds"
|
||||
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
|
||||
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
|
||||
do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK)
|
||||
do_mask = lambda x: mask_tensor(x, layout, BLOCK)
|
||||
# create inputs
|
||||
# create op
|
||||
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
|
||||
@@ -32,9 +60,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# create data
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
|
||||
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
|
||||
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
|
||||
a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE)
|
||||
b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE)
|
||||
dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE)
|
||||
# compute [torch]
|
||||
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
|
||||
a_ref = do_mask(a_ref) if is_dsd else a_ref
|
||||
@@ -59,9 +87,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
da_tri = a_tri.grad
|
||||
db_tri = b_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(c_ref, c_tri)
|
||||
triton.testing.assert_almost_equal(da_ref, da_tri)
|
||||
triton.testing.assert_almost_equal(db_ref, db_tri)
|
||||
torch.testing.assert_allclose(c_ref, c_tri)
|
||||
torch.testing.assert_allclose(da_ref, da_tri)
|
||||
torch.testing.assert_allclose(db_ref, db_tri)
|
||||
|
||||
|
||||
configs = [
|
||||
@@ -88,10 +116,10 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
|
||||
layout[1, :, 1] = 0
|
||||
# initialize data
|
||||
a_shape = (Z, H, M, N)
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape)
|
||||
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
|
||||
a_ref, a_tri = make_pair(a_shape)
|
||||
dout_ref, dout_tri = make_pair(a_shape)
|
||||
# compute [torch]
|
||||
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
|
||||
a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
|
||||
a_ref.retain_grad()
|
||||
at_mask = torch.ones((M, N), device="cuda")
|
||||
if is_causal:
|
||||
@@ -100,19 +128,19 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
|
||||
a_ref[M == 0] = float("-inf")
|
||||
out_ref = torch.softmax(a_ref * scale, -1)
|
||||
out_ref.backward(dout_ref)
|
||||
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
|
||||
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
|
||||
out_ref = sparsify_tensor(out_ref, layout, BLOCK)
|
||||
da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK)
|
||||
# compute [triton]
|
||||
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
|
||||
a_tri = sparsify_tensor(a_tri, layout, BLOCK)
|
||||
a_tri.retain_grad()
|
||||
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
|
||||
dout_tri = sparsify_tensor(dout_tri, layout, BLOCK)
|
||||
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
|
||||
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
|
||||
out_tri.backward(dout_tri)
|
||||
da_tri = a_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(out_tri, out_ref)
|
||||
triton.testing.assert_almost_equal(da_tri, da_ref)
|
||||
torch.testing.assert_allclose(out_tri, out_ref)
|
||||
torch.testing.assert_allclose(da_tri, da_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
@@ -168,9 +196,9 @@ def test_attention_fwd_bwd(
|
||||
|
||||
# comparison
|
||||
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
|
||||
triton.testing.assert_almost_equal(loss, torch_loss)
|
||||
torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0)
|
||||
for g1, g2 in zip(grads, torch_grads):
|
||||
triton.testing.assert_almost_equal(g1, g2)
|
||||
torch.testing.assert_allclose(g1, g2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
|
||||
@@ -2,6 +2,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, dtype, mode",
|
||||
@@ -24,7 +25,7 @@ def test_op(M, N, dtype, mode):
|
||||
tt_y = triton.ops.cross_entropy(x, idx)
|
||||
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
|
||||
if mode == 'forward':
|
||||
triton.testing.assert_almost_equal(th_y, tt_y)
|
||||
torch.testing.assert_allclose(th_y, tt_y)
|
||||
# backward pass
|
||||
elif mode == 'backward':
|
||||
dy = torch.randn_like(tt_y)
|
||||
@@ -35,4 +36,4 @@ def test_op(M, N, dtype, mode):
|
||||
x.grad.zero_()
|
||||
th_y.backward(dy)
|
||||
th_dx = x.grad.clone()
|
||||
triton.testing.assert_almost_equal(th_dx, tt_dx)
|
||||
torch.testing.assert_allclose(th_dx, tt_dx)
|
||||
|
||||
@@ -2,6 +2,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
|
||||
@@ -38,8 +39,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype):
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
decimal = 1 if dtype == torch.bfloat16 else 2
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
|
||||
torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
|
||||
@@ -52,4 +52,4 @@ def test_normalization_with_remat():
|
||||
arg8_1 = torch.rand(64, device="cuda")
|
||||
arg9_1 = torch.rand(64, device="cuda")
|
||||
triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
|
||||
triton.testing.allclose(buf16.mean(), buf14.mean().item(), atol=1e-7, rtol=0)
|
||||
torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -95,4 +96,4 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
# run test
|
||||
th_c = torch.matmul(a, b)
|
||||
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
|
||||
triton.testing.assert_almost_equal(th_c, tt_c)
|
||||
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
|
||||
|
||||
@@ -28,7 +28,6 @@ from .runtime.jit import jit
|
||||
from .compiler import compile, CompilationError
|
||||
from . import language
|
||||
from . import testing
|
||||
from . import ops
|
||||
|
||||
__all__ = [
|
||||
"autotune",
|
||||
|
||||
@@ -19,7 +19,6 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import setuptools
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
|
||||
import triton
|
||||
@@ -1669,8 +1668,8 @@ def _get_jsonable_constants(constants):
|
||||
def compile(fn, **kwargs):
|
||||
capability = kwargs.get("cc", None)
|
||||
if capability is None:
|
||||
device = torch.cuda.current_device()
|
||||
capability = torch.cuda.get_device_capability(device)
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
capability = triton.runtime.jit.get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
# we get the kernel, i.e. the first function generated in the module
|
||||
# if fn is not a JITFunction, then it
|
||||
@@ -1805,7 +1804,7 @@ class CompiledKernel:
|
||||
def _init_handles(self):
|
||||
if self.cu_module is not None:
|
||||
return
|
||||
device = torch.cuda.current_device()
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
global cuda_utils
|
||||
init_cuda_utils()
|
||||
max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
|
||||
@@ -1827,7 +1826,7 @@ class CompiledKernel:
|
||||
|
||||
def runner(*args, stream=None):
|
||||
if stream is None:
|
||||
stream = torch.cuda.current_stream().cuda_stream
|
||||
stream = triton.runtime.jit.get_cuda_stream()
|
||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
|
||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
|
||||
return runner
|
||||
|
||||
@@ -10,15 +10,34 @@ import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
from triton.utils import MockTensor
|
||||
|
||||
try:
|
||||
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
||||
except ImportError:
|
||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||
|
||||
def get_cuda_stream(idx=None):
|
||||
if idx is None:
|
||||
idx = get_current_device()
|
||||
try:
|
||||
from torch._C import _cuda_getCurrentRawStream
|
||||
return _cuda_getCurrentRawStream(idx)
|
||||
except ImportError:
|
||||
import torch
|
||||
return torch.cuda.current_stream(idx).cuda_stream
|
||||
|
||||
|
||||
def get_current_device():
|
||||
import torch
|
||||
return torch.cuda.current_device()
|
||||
|
||||
|
||||
def set_current_device(idx):
|
||||
import torch
|
||||
torch.cuda.set_device(idx)
|
||||
|
||||
|
||||
def get_device_capability(idx):
|
||||
import torch
|
||||
return torch.cuda.get_device_capability(idx)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
@@ -160,34 +179,31 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
@staticmethod
|
||||
def _type_of(key):
|
||||
if isinstance(key, (torch.dtype, triton.language.dtype)):
|
||||
ty = {
|
||||
torch.bool: 'i1',
|
||||
torch.float16: 'fp16',
|
||||
torch.bfloat16: 'bf16',
|
||||
torch.float32: 'fp32',
|
||||
torch.float64: 'fp64',
|
||||
torch.uint8: 'u8',
|
||||
torch.int8: 'i8',
|
||||
torch.int16: 'i16',
|
||||
torch.int32: 'i32',
|
||||
torch.int64: 'i64',
|
||||
|
||||
triton.language.uint8: 'u8',
|
||||
triton.language.uint16: 'u16',
|
||||
triton.language.uint32: 'u32',
|
||||
triton.language.uint64: 'u64',
|
||||
triton.language.float8e5: 'fp8e5',
|
||||
triton.language.float8e4: 'fp8e4',
|
||||
triton.language.float16: 'fp16',
|
||||
triton.language.bfloat16: 'bf16',
|
||||
triton.language.float32: 'fp32',
|
||||
}[key]
|
||||
return f'*{ty}'
|
||||
# None are nullptr -- implicitly converted to *i8
|
||||
if key is None:
|
||||
return '*i8'
|
||||
assert isinstance(key, str)
|
||||
return key
|
||||
dtype_str = str(key).split(".")[-1]
|
||||
tys = {
|
||||
"bool": "i1",
|
||||
"float8e5": "fp8e5",
|
||||
"float8e4": "fp8e4",
|
||||
"float16": "fp16",
|
||||
"bfloat16": "bf16",
|
||||
"float32": "fp32",
|
||||
"float64": "fp64",
|
||||
"int8": "i8",
|
||||
"int16": "i16",
|
||||
"int32": "i32",
|
||||
"int64": "i64",
|
||||
"uint8": "u8",
|
||||
"uint16": "u16",
|
||||
"uint32": "u32",
|
||||
"uint64": "u64",
|
||||
}
|
||||
# reinterpret can create triton type
|
||||
for v in list(tys.values()):
|
||||
tys[v] = v
|
||||
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
|
||||
|
||||
def _make_signature(self, sig_key):
|
||||
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
|
||||
@@ -252,8 +268,8 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
device = torch.cuda.current_device()
|
||||
torch.cuda.set_device(device)
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
stream = get_cuda_stream(device)
|
||||
try:
|
||||
@@ -286,7 +302,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
"""
|
||||
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
|
||||
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
|
||||
"cache": self.cache, "triton": triton, "torch": torch}
|
||||
"cache": self.cache, "triton": triton,
|
||||
"get_current_device": get_current_device,
|
||||
"set_current_device": set_current_device}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
|
||||
@@ -397,8 +415,9 @@ def jit(
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
|
||||
implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
:note: When a jit'd function is called, arguments are
|
||||
implicitly converted to pointers if they have a :code:`.data_ptr()` method
|
||||
and a `.dtype` attribute.
|
||||
|
||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||
|
||||
@@ -449,8 +468,8 @@ def reinterpret(tensor, dtype):
|
||||
else:
|
||||
# Reinterpreting a wrapped tensor to a different type.
|
||||
return TensorWrapper(tensor.base, dtype)
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
elif hasattr(tensor, "data_ptr"):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}. Does not contain `data_ptr` method.')
|
||||
|
||||
@@ -4,21 +4,10 @@ import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .compiler import OutOfResources
|
||||
|
||||
try:
|
||||
import triton._C.libtriton.cutlass as _cutlass
|
||||
has_cutlass = True
|
||||
except ImportError:
|
||||
_cutlass = None
|
||||
has_cutlass = False
|
||||
|
||||
# TODO: move to separate module
|
||||
import triton
|
||||
|
||||
|
||||
def catch_oor(kernel, pytest_handle=None):
|
||||
try:
|
||||
@@ -30,86 +19,6 @@ def catch_oor(kernel, pytest_handle=None):
|
||||
return res
|
||||
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
|
||||
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
|
||||
return ret
|
||||
|
||||
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
|
||||
if data is None:
|
||||
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
|
||||
ref_ret = data
|
||||
ref_ret = ref_ret * alpha + beta
|
||||
ref_ret = ref_ret.half().to(dtype)
|
||||
if trans:
|
||||
ref_ret = ref_ret.t().requires_grad_()
|
||||
ref_ret = ref_ret.detach().requires_grad_()
|
||||
tri_ret = ref_ret.clone().detach().requires_grad_()
|
||||
return ref_ret, tri_ret
|
||||
|
||||
|
||||
def cutlass_matmul(a, b):
|
||||
if _cutlass is None:
|
||||
raise RuntimeError("Cannot find cutlass library")
|
||||
M, N = a.shape[0], b.shape[1]
|
||||
Ka, Kb = a.shape[1], b.shape[0]
|
||||
assert Ka == Kb
|
||||
assert a.dtype == b.dtype
|
||||
assert a.device == b.device
|
||||
# allocate output
|
||||
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
|
||||
# run function
|
||||
dtype = str(a.dtype).split('.')[-1]
|
||||
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
|
||||
M, N, Ka,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
dtype, dtype, dtype,
|
||||
a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
|
||||
|
||||
return c
|
||||
|
||||
|
||||
def mask_tensor(x, mask, block, value=0):
|
||||
ret = x.clone()
|
||||
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
||||
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
|
||||
return ret
|
||||
|
||||
|
||||
def assert_almost_equal(x, y, decimal=2, err_msg=''):
|
||||
import numpy.testing as npt
|
||||
if isinstance(x, torch.Tensor):
|
||||
if x.dtype == torch.bfloat16:
|
||||
x = x.float()
|
||||
x = x.cpu().detach().numpy()
|
||||
if isinstance(y, torch.Tensor):
|
||||
if y.dtype == torch.bfloat16:
|
||||
y = y.float()
|
||||
y = y.cpu().detach().numpy()
|
||||
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
|
||||
|
||||
|
||||
def allclose(x, y, atol=0, rtol=1e-2):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.tensor(x)
|
||||
if not isinstance(y, torch.Tensor):
|
||||
y = torch.tensor(y)
|
||||
if x.dtype != y.dtype:
|
||||
raise RuntimeError(f'{x.dtype} did not match with {x.dtype}')
|
||||
if x.shape != y.shape:
|
||||
raise RuntimeError(f'{x.shape} did not match with {y.shape}')
|
||||
if x.dtype == torch.bool:
|
||||
return torch.sum(x ^ y) == 0
|
||||
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
rtol = 0
|
||||
atol = 0
|
||||
return torch.allclose(x, y, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
attrs = ','.join(attrs)
|
||||
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
|
||||
@@ -122,6 +31,7 @@ def nvsmi(attrs):
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
percentiles=(0.5, 0.2, 0.8),
|
||||
record_clocks=False, fast_flush=False):
|
||||
import torch
|
||||
"""
|
||||
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
|
||||
the 20-th and 80-th performance percentile.
|
||||
@@ -335,7 +245,7 @@ def perf_report(benchmarks):
|
||||
|
||||
def get_dram_gbps(backend=None, device=None):
|
||||
''' return DRAM bandwidth in GB/s '''
|
||||
# assert backend == CUDA
|
||||
import torch
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
@@ -346,7 +256,8 @@ def get_dram_gbps(backend=None, device=None):
|
||||
return bw_gbps
|
||||
|
||||
|
||||
def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
|
||||
def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None):
|
||||
import torch
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
@@ -447,7 +358,8 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
|
||||
subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
|
||||
|
||||
|
||||
def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
|
||||
def get_max_simd_tflops(dtype, backend=None, device=None):
|
||||
import torch
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def cdiv(x, y):
|
||||
return (x + y - 1) // y
|
||||
@@ -26,7 +24,8 @@ class MockTensor:
|
||||
"""
|
||||
@staticmethod
|
||||
def wrap_dtype(arg):
|
||||
if isinstance(arg, torch.dtype):
|
||||
if arg.__class__.__name__ == "dtype" and\
|
||||
arg.__module__ == "torch":
|
||||
return MockTensor(arg)
|
||||
return arg
|
||||
|
||||
@@ -60,7 +59,7 @@ def reinterpret(tensor, dtype):
|
||||
else:
|
||||
# Reinterpreting a wrapped tensor to a different type.
|
||||
return TensorWrapper(tensor.base, dtype)
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
elif hasattr(tensor, "data_ptr"):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user