mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] interpreter rewrite (#2321)
This is a new interpreter mode that shares semantic analysis with the JIT'ed codepath and that the Triton core team is committed to maintaining
This commit is contained in:
@@ -59,8 +59,8 @@ class Package(NamedTuple):
|
||||
|
||||
|
||||
def get_pybind11_package_info():
    """Return the Package descriptor for the pinned pybind11 release.

    The version is pinned (2.11.1 here) for reproducible builds; the tarball
    name and download URL must be bumped together when upgrading.
    """
    # Diff residue in the scraped source showed both 2.10.0 and 2.11.1
    # assignments stacked; the commit's final state is the 2.11.1 pair.
    name = "pybind11-2.11.1"
    url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz"
    return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
|
||||
|
||||
# llvm
|
||||
@@ -296,7 +296,6 @@ setup(
|
||||
"triton/_C",
|
||||
"triton/common",
|
||||
"triton/compiler",
|
||||
"triton/interpreter",
|
||||
"triton/language",
|
||||
"triton/language/extra",
|
||||
"triton/ops",
|
||||
|
||||
@@ -64,6 +64,7 @@
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include <pybind11/numpy.h>
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy);
|
||||
@@ -1961,11 +1962,52 @@ void init_triton_translation(py::module &m) {
|
||||
});
|
||||
}
|
||||
|
||||
// Bindings used by the Python interpreter mode: host-side gather/scatter
// helpers that emulate tl.load / tl.store on raw addresses.
void init_triton_interpreter(py::module &&m) {
  using ret = py::return_value_policy;

  // load(ptrs, masks, other, ret_dtype) -> array shaped like `ptrs`.
  // Per element: copy itemsize bytes from the raw address in `ptrs` when the
  // mask is set, otherwise take the fallback value from `other`.
  m.def("load",
        [](py::array_t<uint64_t> ptrs, py::array_t<bool> masks, py::array other,
           py::dtype ret_dtype) -> py::array {
          int numel = ptrs.size();
          // Remember the original shape so the flat result can be restored.
          auto shape =
              std::vector<ptrdiff_t>(ptrs.shape(), ptrs.shape() + ptrs.ndim());
          py::array ret(ret_dtype, py::array::ShapeContainer{numel});
          // Flattened views let one linear index address every element.
          py::array_t<uint64_t> reshaped_ptrs = ptrs.reshape({numel});
          py::array_t<bool> reshaped_masks = masks.reshape({numel});
          py::array reshaped_others = other.reshape({numel});
          for (size_t i = 0; i < ptrs.size(); ++i) {
            if (reshaped_masks.at(i))
              // Masked-on lane: read from the stored raw pointer.
              memcpy(ret.mutable_data(i),
                     reinterpret_cast<void *>(reshaped_ptrs.at(i)),
                     ret_dtype.itemsize());
            else
              // Masked-off lane: take the caller-supplied default.
              memcpy(ret.mutable_data(i), reshaped_others.data(i),
                     ret_dtype.itemsize());
          }
          return ret.reshape(shape);
        });

  // store(ptrs, values, mask): scatter `values` to the addresses in `ptrs`
  // wherever `mask` is set.  NOTE(review): the mask is re-viewed as int8
  // here (numpy bools are one byte) while `load` uses array_t<bool> —
  // presumably equivalent; confirm.
  m.def("store", [](py::array_t<uint64_t> ptrs, py::array values,
                    py::array_t<bool> mask) {
    int numel = ptrs.size();
    py::array_t<uint64_t> reshaped_ptrs = ptrs.reshape({numel});
    py::array_t<int8_t> reshaped_masks = mask.reshape({numel});
    py::array reshaped_values = values.reshape({numel});
    for (size_t i = 0; i < ptrs.size(); ++i) {
      if (reshaped_masks.at(i)) {
        memcpy(reinterpret_cast<void *>(reshaped_ptrs.mutable_at(i)),
               reshaped_values.data(i), values.dtype().itemsize());
      }
    }
  });
}
|
||||
|
||||
// Assemble the top-level "triton" extension module; each init_* call below
// registers one API area as a submodule (or directly on `subm`).
void init_triton(py::module &m) {
  py::module subm = m.def_submodule("triton");
  init_triton_env_vars(subm);
  // init_triton_codegen(subm.def_submodule("code_gen"));
  init_triton_runtime(subm.def_submodule("runtime"));
  init_triton_ir(subm.def_submodule("ir"));
  init_triton_interpreter(subm.def_submodule("interpreter"));
  init_triton_translation(subm);
}
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.interpreter.interpreter import program_ids_from_grid
|
||||
|
||||
|
||||
def test_addition():
    """Interpreted (interpret=True) vector-add kernel matches eager torch add."""

    @triton.jit(interpret=True)
    def add_kernel(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program instance handles one contiguous BLOCK_SIZE slice.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Guard the tail block when n_elements % BLOCK_SIZE != 0.
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.store(output_ptr + offsets, output, mask=mask)

    a = torch.rand((128,), device="cuda")
    b = torch.rand((128,), device="cuda")
    expected = a + b
    output = torch.empty((128,), device="cuda")

    def grid(meta):
        # One program per BLOCK_SIZE chunk of the 128 elements.
        return (triton.cdiv(128, meta["BLOCK_SIZE"]),)

    add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32)

    assert torch.allclose(expected, output, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
def test_program_ids_from_grid():
    """program_ids_from_grid must visit every grid cell exactly once per
    sweep, and (with this seed) successive sweeps come back shuffled."""
    random.seed(123)
    grid = (3, 4)
    n_cells = 3 * 4
    # One full sweep produces each program id exactly once.
    distinct_ids = set(program_ids_from_grid(grid))
    assert len(distinct_ids) == n_cells

    # Two further sweeps enumerate the same ids in different orders.
    sweep_a = list(program_ids_from_grid(grid))
    sweep_b = list(program_ids_from_grid(grid))
    assert sweep_a != sweep_b
|
||||
|
||||
|
||||
def test_atomic():
    """Exercise the interpreter's atomic ops; every lane must end at 2."""

    @triton.jit(interpret=True)
    def atomic(
        x_ptr,
    ):
        pid = tl.program_id(axis=0)
        tl.atomic_add(x_ptr + pid, 1)  # slot: 0 -> 1
        t = tl.atomic_xchg(x_ptr + pid, 3)  # t = 1, slot -> 3
        t += 1  # 2
        tl.atomic_cas(x_ptr + pid, 3, t)  # match: slot 3 -> 2
        tl.atomic_cas(x_ptr + pid, 40, 9)  # no match: slot stays 2
    nb_dim = 16
    a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda")

    atomic[(nb_dim, )](a)
    assert torch.allclose(a, torch.full_like(a, 2))
|
||||
@@ -2421,7 +2421,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
if epilogue == 'chain-dot':
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
if in_dtype == 'float32':
|
||||
# XXX: Somehow there's a larger difference when we use float32
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
|
||||
@@ -5,10 +5,10 @@ import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
|
||||
(4, 48, 1024, 32),
|
||||
(4, 48, 1024, 64),
|
||||
(4, 48, 1024, 128)])
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
|
||||
(2, 4, 512, 32),
|
||||
(2, 4, 512, 64),
|
||||
(2, 4, 512, 128)])
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize('causal', [True, False])
|
||||
@pytest.mark.parametrize('seq_par', [True, False])
|
||||
@@ -21,7 +21,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
pytest.skip('Segmentation fault')
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
|
||||
if not interpreter and capability[0] < 8:
|
||||
pytest.skip("Flash attention only supported for compute capability < 80")
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from typing import Tuple
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ExecutionContext:
    """Per-program-instance state handed to the interpreter's tl proxy."""
    # Coordinates of the current program instance within the launch grid.
    program_id: Tuple[int]
    # Extent of the launch grid along each axis.
    program_size: Tuple[int]
|
||||
@@ -1,171 +0,0 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
from .. import language as tl
|
||||
# import .language.core as lcore
|
||||
from ..language import core as lcore
|
||||
from . import torch_wrapper
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
|
||||
debugger_constexpr)
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
tl_method_backup = {}
|
||||
|
||||
|
||||
def get_proxy_method(proxy, name):
    """Return a plain function that forwards every call to ``proxy.<name>``."""
    bound = getattr(proxy, name)

    def forward(*args, **kwargs):
        return bound(*args, **kwargs)

    return forward
|
||||
|
||||
|
||||
def attach_triton(module, proxy):
    """Monkey-patch ``module`` so that every public TritonLangProxy name it
    also defines resolves through ``proxy``; originals are saved in
    ``tl_method_backup`` for later restoration by detach_triton."""
    public_names = [n for n in dir(TritonLangProxy) if n[0] != "_"]
    for name in public_names:
        if not hasattr(module, name):
            continue
        original = getattr(module, name)
        tl_method_backup[name] = original
        if callable(original):
            setattr(module, name, get_proxy_method(proxy, name))
        else:
            setattr(module, name, getattr(proxy, name))
|
||||
|
||||
|
||||
def detach_triton(module):
    """Undo attach_triton: restore every backed-up attribute on ``module``."""
    for attr_name, original in tl_method_backup.items():
        setattr(module, attr_name, original)
|
||||
|
||||
|
||||
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
    """Yield every program id of ``grid`` exactly once, in random order.

    Axes are reversed before taking the cartesian product, and the full id
    list is shuffled so kernels cannot rely on any particular launch order.
    """
    axis_ranges = [range(extent) for extent in reversed(grid)]
    all_ids = list(itertools.product(*axis_ranges))
    random.shuffle(all_ids)
    yield from all_ids
|
||||
|
||||
|
||||
class DebuggerFunction:
    """Callable that interprets a Triton kernel eagerly via torch, running
    one program instance at a time instead of JIT-compiling the kernel."""

    def __init__(self, func, grid=(1,)):
        # func: the undecorated kernel; grid: tuple, or callable(meta)->tuple.
        self.func = func
        self.grid = grid

    def _is_constexpr(self, name):
        # A parameter is constexpr iff annotated with triton.language.constexpr.
        return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr

    def _get_constexpr(self):
        # Names of all constexpr-annotated parameters of the kernel.
        result = []
        for name, annotation in self.func.__annotations__.items():
            if annotation is lcore.constexpr:
                result.append(name)
        return result

    def _assert_constexpr(self, **kwargs):
        # constexpr parameters must be supplied by keyword, as on the JIT path.
        constexp = self._get_constexpr()
        missing = [i for i in constexp if i not in kwargs.keys()]
        assert len(missing) == 0, f"You must specify constexpr {missing}"

    def _get_grid(self, **kwargs):
        # The grid may be a callable taking the kwargs (meta) dict.
        if callable(self.grid):
            return self.grid(kwargs)
        else:
            return self.grid

    def __call__(self, *args, **kwargs):
        self._assert_constexpr(**kwargs)

        # Fresh memory map per launch: tensor args are registered so their
        # integer "addresses" can be resolved back to storage on load/store.
        memory = MemoryMap()

        def convert_arg(v):
            name, arg = v
            if torch.is_tensor(arg):
                # Replace the tensor by a 1-element int64 tensor of its address.
                ptr = memory.add_tensor(arg)
                return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
            if self._is_constexpr(name):
                return debugger_constexpr(arg)
            return WrappedTensor(_primitive_to_tensor(arg))

        # Positional args are matched to parameter names via co_varnames;
        # launch-only kwargs (num_warps/num_stages) are dropped.
        new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
        new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}

        grid = self._get_grid(**kwargs)
        # Run each program instance sequentially, with the tl module patched
        # to the proxy for the duration of the kernel body.
        for program_id in program_ids_from_grid(grid):
            proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
            attach_triton(tl, proxy)
            self.func(*new_args, **new_kwargs)
        detach_triton(tl)
|
||||
|
||||
|
||||
class GridSelector:
    """Entry point of the debugger: wraps a kernel so that ``kernel[grid]``
    and plain calls both dispatch to a DebuggerFunction."""

    def __init__(self, func):
        version = torch.__version__
        # The debugger relies on torch 2.x behavior.
        assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}"
        self.func = func

    def __getitem__(self, grid):
        # kernel[grid] -> runnable bound to that launch grid.
        return DebuggerFunction(self.func, grid)

    def __call__(self, *args, **kwargs):
        # Direct call runs with the default (1,) grid.
        return DebuggerFunction(self.func)(*args, **kwargs)
|
||||
|
||||
|
||||
class AutotuneGridSelector:
    """Autotune-aware counterpart of GridSelector: carries the autotune
    parameters alongside the kernel and hands out AutotuneRunners."""

    def __init__(self, func, autotune_params):
        self.func = func
        self.autotune_params = autotune_params

    def __getitem__(self, grid):
        # kernel[grid] -> runner bound to that launch grid.
        return AutotuneRunner(self.func, self.autotune_params, grid)

    def __call__(self, *args, **kwargs):
        # Direct call: runner with no explicit grid.
        return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)
|
||||
|
||||
|
||||
class AutotuneRunner:
    """Runs a debugger kernel once per autotune config — on cloned tensor
    args so configs cannot see each other's writes — then once with the
    first config on the original arguments, whose effects are kept."""

    def __init__(self, func, autotune_params, grid=None):
        self.func = func
        self.autotune_params = autotune_params
        self.grid = grid

    def __call__(self, *args, **kwargs):
        assert len(self.autotune_params["configs"]) >= 1

        # Dry-run every non-primary config against cloned inputs.
        for config in self.autotune_params["configs"][1:]:

            def convert_arg(v):
                # Clone tensors so this config cannot mutate caller data.
                if torch.is_tensor(v):
                    return torch.clone(v)
                return v

            new_args = tuple(map(convert_arg, args))
            new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
            if self.grid:
                self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
            else:
                self.func(*new_args, **new_kwargs, **config.kwargs)

        # Final run: first config, original (uncloned) arguments.
        main_config = self.autotune_params["configs"][0]
        if self.grid:
            self.func[self.grid](*args, **kwargs, **main_config.kwargs)
        else:
            self.func(*args, **kwargs, **main_config.kwargs)
|
||||
|
||||
|
||||
def triton_debug_autotune(**kwargs):
    """Decorator mirroring triton.autotune for the debugger: wraps the
    kernel in an AutotuneGridSelector carrying the autotune parameters."""
    def wrapper(func):
        return AutotuneGridSelector(func, kwargs)

    return wrapper
|
||||
@@ -1,102 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from . import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RegisteredStorage:
    """A torch storage registered with MemoryMap, addressed by [ptr, end_ptr)."""
    # The backing (untyped) storage object.
    storage: torch.Storage
    # Element dtype used when viewing the storage as a tensor.
    dtype: torch.dtype
    # Size recorded at registration time (from storage.size()).
    size: int
    # Base data pointer recorded at registration time.
    ptr: int

    @property
    def end_ptr(self) -> int:
        # One past the last address covered by this storage.
        return self.ptr + self.size

    @property
    def access_tensor(self) -> torch.Tensor:
        # A tensor view over the whole storage, in the registered dtype.
        return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)

    def ensure_immutable(self):
        # Guard against the storage having been resized/reallocated since
        # registration, which would invalidate the recorded ptr/size.
        assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
|
||||
|
||||
|
||||
class MemoryMap:
    """Resolves integer addresses back to registered torch storages so the
    debugger can emulate tl.load / tl.store with plain tensor indexing."""

    # All storages registered for the current launch.
    storages: [RegisteredStorage]

    def __init__(self):
        self.storages = []

    def _get_registered_storage(self, pointer: torch.Tensor):
        # Find the unique storage whose [ptr, end_ptr) range contains every
        # address in `pointer`; pointers spanning tensors are an error.
        max_pointer = torch.max(pointer).item()
        min_pointer = torch.min(pointer).item()

        registered_storage = next(
            filter(
                lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
            ),
            None,
        )
        if registered_storage is None:
            raise Exception("Storage not found or pointers spanning multiple tensors")
        registered_storage.ensure_immutable()
        return registered_storage

    def add_tensor(self, t: torch.Tensor):
        # Register the tensor's backing storage and return its data pointer,
        # which the debugger then circulates as the tensor's "address".
        storage = t.untyped_storage()
        self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
        return t.data_ptr()

    def load(
        self,
        pointer: torch.Tensor,
        mask: torch.Tensor = None,
        other=0.0,
    ):
        """Emulate tl.load: gather the elements addressed by ``pointer``,
        filling ``other`` where ``mask`` is False."""
        assert pointer.is_cuda
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            # No mask means every lane is active.
            mask = torch.ones_like(pointer).bool()
        assert mask.is_cuda
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # Todo: The type is wrong here, we can't determine the correct type
            return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor

        # NOTE(review): addresses are turned into indices by plain
        # subtraction, i.e. treated as element offsets from the storage
        # base — confirm this matches the byte-addressed data_ptr() values
        # recorded in add_tensor.
        index_tensor = pointer - registered_storage.ptr

        block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
        block[mask] = access_tensor[index_tensor[mask]]
        return block

    def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
        """Emulate tl.store: scatter ``value`` to the addresses in
        ``pointer`` wherever ``mask`` is True."""
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # Fully masked-off store is a no-op.
            return

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor

        index_tensor = pointer - registered_storage.ptr
        # Values are cast to the destination storage's dtype on write.
        access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)
|
||||
@@ -1,641 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..language import core as lcore
|
||||
from . import torch_wrapper
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
def _primitive_to_tensor(x):
    """
    Converts various Python primitive data types to PyTorch tensor.
    """
    tensor_args = {"device": "cuda"}
    # NOTE: bool must be tested before int — bool is a subclass of int in
    # Python, so the int branch would otherwise capture True/False.
    if isinstance(x, bool):
        return torch.tensor([x], dtype=torch.bool, **tensor_args)
    elif isinstance(x, int):
        # Pick the narrowest of int32/int64 that can represent the value.
        if -(2**31) <= x < 2**31:
            return torch.tensor([x], dtype=torch.int32, **tensor_args)
        elif -(2**63) <= x < 2**63:
            return torch.tensor([x], dtype=torch.int64, **tensor_args)
        else:
            raise RuntimeError(f"Nonrepresentable integer {x}.")
    elif isinstance(x, float):
        return torch.tensor([x], dtype=torch.float32, **tensor_args)
    elif torch.is_tensor(x):
        # Already a tensor: passed through unchanged.
        return x
    elif isinstance(x, WrappedTensor):
        # Already wrapped: returned as-is (not unwrapped here).
        return x
    elif isinstance(x, debugger_constexpr):
        # Unwrap constexprs and convert the inner value recursively.
        if x.value is None:
            return None
        return _primitive_to_tensor(x.value)
    elif x is None:
        return None
    assert False, f"cannot convert {x} of type {type(x)} to tensor"
|
||||
|
||||
|
||||
def _infer_tensor(func):
    """
    A decorator function to harmonize function args:
    - converts primitives to PyTorch tensors
    - wraps PyTorch tensors with WrappedTensors
    """
    def wrapper(*args):
        as_tensors = [_primitive_to_tensor(a) for a in args]
        harmonized = tuple(
            WrappedTensor(a) if torch.is_tensor(a) else a for a in as_tensors
        )
        return func(*harmonized)

    return wrapper
|
||||
|
||||
|
||||
def _tensor_operation(func):
    """
    A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
    Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor).
    """
    def wrapper(*args, **kwargs):
        # Raw torch tensors must already have been wrapped upstream.
        for arg in args:
            assert not torch.is_tensor(arg), "unexpected tensor argument"

        def unwrap_tensor(v):
            # WrappedTensor -> raw tensor; debugger_constexpr -> raw value.
            if isinstance(v, WrappedTensor):
                return v.tensor
            if isinstance(v, debugger_constexpr):
                return v.value
            return v

        new_args = tuple(map(unwrap_tensor, args))
        new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}

        # NOTE: args[0] (the receiver) is deliberately passed through still
        # wrapped, so methods keep seeing their WrappedTensor self; only the
        # remaining positional args and kwargs are unwrapped.
        result = func(args[0], *new_args[1:], **new_kwargs)
        # Re-wrap tensor results so chained operations stay proxied.
        return WrappedTensor(result) if torch.is_tensor(result) else result

    return wrapper
|
||||
|
||||
|
||||
class debugger_constexpr:
    """Debugger stand-in for triton.language.constexpr: holds a plain Python
    value and forwards comparisons and bit-ops to it.

    NOTE(review): __eq__ is defined without __hash__, which makes instances
    unhashable in Python — confirm nothing uses them as dict/set keys.
    """

    def __init__(self, value):
        # Flatten nested constexprs so .value is always a raw Python value.
        if isinstance(value, debugger_constexpr):
            self.value = value.value
        else:
            self.value = value

    def __str__(self) -> str:
        return "debugger_constexpr(" + str(self.value) + ")"

    def __index__(self) -> int:
        return self.value

    def __bool__(self):
        return bool(self.value)

    # Each operator below unwraps a constexpr operand, then delegates to
    # the underlying value.
    def __ge__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value >= other

    def __gt__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value > other

    def __le__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value <= other

    def __lt__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value < other

    def __eq__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value == other

    def __or__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value | other

    def __ror__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value | other

    def __and__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value & other

    def __rand__(self, other):
        other = other.value if isinstance(other, debugger_constexpr) else other
        return self.value & other

    def to(self, dtype, bitcast=False, _builder=None):
        # Map the requested torch dtype onto a plain Python cast; only a
        # small set of dtypes is supported in the debugger.
        if dtype in [torch.int64]:
            ret_ty = int
        elif dtype == torch.bool:
            ret_ty = bool
        elif dtype in [torch.float64]:
            ret_ty = float
        else:
            raise ValueError("dtype not supported in debugger")
        return debugger_constexpr(ret_ty(self.value))
|
||||
|
||||
|
||||
class WrappedTensor:
    """Wraps a torch.Tensor so operator overloads route through the
    _infer_tensor/_tensor_operation decorators, mimicking tl.tensor
    semantics inside the debugger."""

    def __init__(self, tensor):
        # The underlying torch tensor.
        self.tensor = tensor

    def __index__(self) -> int:
        # Usable as an index; .item() requires a 1-element tensor.
        return self.tensor.item()

    def __str__(self) -> str:
        return "wrapped_" + str(self.tensor)

    def __bool__(self) -> bool:
        # Truthy only when every element is True.
        return torch.all(self.tensor == True).item()  # noqa: E712

    @property
    def dtype(self):
        return self.tensor.dtype

    # --- arithmetic operators (args harmonized/unwrapped by decorators) ---

    @_infer_tensor
    @_tensor_operation
    def __add__(self, other):
        return torch.add(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __radd__(self, other):
        return self.__add__(other)

    @_infer_tensor
    @_tensor_operation
    def __sub__(self, other):
        return torch.sub(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rsub__(self, other):
        return torch.sub(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __mul__(self, other):
        return torch.mul(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rmul__(self, other):
        return self.__mul__(other)

    @_infer_tensor
    @_tensor_operation
    def __truediv__(self, other):
        return torch.div(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rtruediv__(self, other):
        return torch.div(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __floordiv__(self, other):
        return torch.floor_divide(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rfloordiv__(self, other):
        return torch.floor_divide(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __mod__(self, other):
        return torch.remainder(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rmod__(self, other):
        return torch.remainder(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __neg__(self):
        return -self.tensor

    # --- bitwise operators -------------------------------------------------

    @_infer_tensor
    @_tensor_operation
    def __invert__(self):
        return ~self.tensor

    @_infer_tensor
    @_tensor_operation
    def __and__(self, other):
        return torch.bitwise_and(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __or__(self, other):
        return torch.bitwise_or(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __xor__(self, other):
        return torch.bitwise_xor(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __lshift__(self, other):
        return torch.bitwise_left_shift(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rshift__(self, other):
        return torch.bitwise_right_shift(self.tensor, other)

    # --- comparisons (elementwise, except __eq__/__ne__ which are whole-
    # tensor torch.equal checks) -------------------------------------------

    @_infer_tensor
    @_tensor_operation
    def __gt__(self, other):
        return self.tensor > other

    @_infer_tensor
    @_tensor_operation
    def __rgt__(self, other):
        return other > self.tensor

    @_infer_tensor
    @_tensor_operation
    def __ge__(self, other):
        return self.tensor >= other

    @_infer_tensor
    @_tensor_operation
    def __rge__(self, other):
        return other >= self.tensor

    @_infer_tensor
    @_tensor_operation
    def __lt__(self, other):
        return self.tensor < other

    @_infer_tensor
    @_tensor_operation
    def __rlt__(self, other):
        return other < self.tensor

    @_infer_tensor
    @_tensor_operation
    def __le__(self, other):
        return self.tensor <= other

    @_infer_tensor
    @_tensor_operation
    def __rle__(self, other):
        return other <= self.tensor

    @_infer_tensor
    @_tensor_operation
    def __eq__(self, other):
        return torch.equal(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __ne__(self, other):
        return not torch.equal(self.tensor, other)

    @_tensor_operation
    def __getitem__(self, slices):
        return self.tensor.__getitem__(slices)
        # if isinstance(slices, slice):
        #     slices = [slices]
        # src_shape = self.shape
        # dst_shape = []
        # curr = 0
        # for sl in slices:
        #     if isinstance(sl, constexpr) and sl.value is None:
        #         dst_shape.append(1)
        #     elif sl == slice(None, None, None):
        #         dst_shape.append(src_shape[curr].value)
        #         curr += 1
        # ret = torch.reshape(self.tensor, dst_shape, )
        # return ret

    @_tensor_operation
    def to(self, dtype, bitcast=False):
        # NOTE: bitcast is currently ignored; this is a value cast.
        return self.tensor.to(dtype)
        # if isinstance(bitcast, constexpr):
        #     bitcast = bitcast.value
        # if bitcast:
        #     return semantic.bitcast(self, dtype, )
        # return semantic.cast(self, dtype, )
|
||||
|
||||
|
||||
def _constexpr_to_value(v):
    """Unwrap a debugger_constexpr to its raw value; pass others through."""
    return v.value if isinstance(v, debugger_constexpr) else v
|
||||
|
||||
|
||||
class TritonLangProxy:
|
||||
_memory_map: MemoryMap
|
||||
_context: ExecutionContext
|
||||
|
||||
def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
|
||||
self._memory_map = memory_map
|
||||
self._context = context
|
||||
|
||||
# Types
|
||||
# Removed void, int1, float8, uint16, uint32, uint64, pi32_t
|
||||
|
||||
# constexpr = debugger_constexpr
|
||||
|
||||
# Program functions
|
||||
|
||||
@_tensor_operation
|
||||
def load(
|
||||
self,
|
||||
pointer: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
other=0.0,
|
||||
cache_modifier="",
|
||||
eviction_policy="",
|
||||
volatile=False,
|
||||
):
|
||||
return self._memory_map.load(pointer, mask, other)
|
||||
|
||||
@_tensor_operation
|
||||
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
|
||||
return self._memory_map.store(pointer, value, mask)
|
||||
|
||||
@_tensor_operation
|
||||
def program_id(self, axis):
|
||||
assert axis < len(self._context.program_id)
|
||||
return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def num_programs(self, axis):
|
||||
assert axis < len(self._context.program_size)
|
||||
return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def arange(self, start, end):
|
||||
return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def zeros(self, shape, dtype):
|
||||
for i, d in enumerate(shape):
|
||||
if not isinstance(d, debugger_constexpr):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr`")
|
||||
if not isinstance(d.value, int):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
shape = [x.value for x in shape]
|
||||
if isinstance(dtype, lcore.dtype):
|
||||
if dtype.is_fp32():
|
||||
dtype = torch.float32
|
||||
elif dtype.is_fp16():
|
||||
dtype = torch.float16
|
||||
elif dtype.is_bf16():
|
||||
dtype = torch.bfloat16
|
||||
elif dtype.is_int32():
|
||||
dtype = torch.int32
|
||||
elif dtype.is_int16():
|
||||
dtype = torch.int16
|
||||
elif dtype.is_int8():
|
||||
dtype = torch.int8
|
||||
else:
|
||||
raise TypeError(f"Unsupported dtype {dtype}")
|
||||
return torch.zeros(size=shape, dtype=dtype, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def dequantize(self, input, scale, shift, nbit, dst_ty=None):
|
||||
if dst_ty is None:
|
||||
dst_ty = torch.float16
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def broadcast(self, input, other):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def broadcast_to(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def cat(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def reshape(self, input, shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
|
||||
assert input.dtype == other.dtype
|
||||
if trans_a:
|
||||
input = input.T
|
||||
if trans_b:
|
||||
other = other.T
|
||||
return torch.matmul(input=input, other=other)
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_cas(self, pointer, cmp, val):
|
||||
stored = self._memory_map.load(pointer, None, 0.0)
|
||||
if not isinstance(cmp, torch.Tensor):
|
||||
cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
|
||||
if not isinstance(val, torch.Tensor):
|
||||
val = torch.tensor([val], dtype=stored.dtype, device="cuda")
|
||||
if stored == cmp:
|
||||
self._memory_map.store(pointer, val, None)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_xchg(self, pointer, val, mask=None):
|
||||
if isinstance(val, int):
|
||||
val = torch.tensor([val], dtype=torch.int32, device="cuda")
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
self._memory_map.store(pointer, val, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_add(self, pointer, val, mask=None):
|
||||
# arbitrary other value as it will masked during storing
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = stored + val
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_max(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = torch.maximum(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_min(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0.0)
|
||||
result = torch.minimum(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_and(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_and(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_or(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_or(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
|
||||
def atomic_xor(self, pointer, val, mask=None):
|
||||
stored = self._memory_map.load(pointer, mask, 0)
|
||||
result = torch.bitwise_xor(stored, val)
|
||||
self._memory_map.store(pointer, result, mask)
|
||||
return stored
|
||||
|
||||
@_tensor_operation
def where(self, condition, x, y):
    """Element-wise select: x where condition is true, y elsewhere."""
    condition, x, y = (_primitive_to_tensor(v) for v in (condition, x, y))
    return torch.where(condition, x, y)
|
||||
|
||||
@_tensor_operation
def umulhi(self, x, y):
    # High half of the unsigned widening multiply — no torch equivalent here.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def fdiv(self, x, y, ieee_rounding=False):
    # Float division with rounding-mode control — not emulated.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def exp(self, x):
    # Element-wise e**x.
    return torch.exp(x)
|
||||
|
||||
@_tensor_operation
def log(self, x):
    # Element-wise natural logarithm.
    return torch.log(x)
|
||||
|
||||
@_tensor_operation
def cos(self, x):
    # Element-wise cosine.
    return torch.cos(x)
|
||||
|
||||
@_tensor_operation
def sin(self, x):
    # Element-wise sine.
    return torch.sin(x)
|
||||
|
||||
@_tensor_operation
def sqrt(self, x):
    # Element-wise square root.
    return torch.sqrt(x)
|
||||
|
||||
@_tensor_operation
def globaltimer(self):
    # GPU global timer has no host-side emulation.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def clock(self):
    # SM clock counter has no host-side emulation.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def debug_barrier(self):
    # Block-level barrier is meaningless in a serial interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def multiple_of(self, input, values):
    # Compiler hint only — a no-op when interpreting.
    return input
|
||||
|
||||
@_tensor_operation
def max_contiguous(self, input, values):
    # Compiler hint only — a no-op when interpreting.
    return input
|
||||
|
||||
@_tensor_operation
def max_constancy(self, input, values):
    # Compiler hint only — a no-op when interpreting.
    return input
|
||||
|
||||
@_tensor_operation
def abs(self, x):
    # Element-wise absolute value.
    return torch.abs(x)
|
||||
|
||||
@_tensor_operation
def cdiv(self, x, div):
    # Ceiling division via the classic integer identity (valid for div > 0).
    return (x + div - 1) // div
|
||||
|
||||
@_tensor_operation
def minimum(self, x, y):
    """Element-wise minimum; Python ints are promoted to tensors because
    torch.minimum only accepts tensor operands."""
    x, y = (torch.tensor(v, device="cuda") if isinstance(v, int) else v
            for v in (x, y))
    return torch.minimum(x, y)
|
||||
|
||||
@_tensor_operation
def maximum(self, x, y):
    """Element-wise maximum.

    Promotes plain Python ints to tensors, mirroring `minimum`:
    torch.maximum rejects non-tensor operands, so without this the
    int case failed where `minimum` succeeded.
    """
    if isinstance(x, int):
        x = torch.tensor(x, device="cuda")
    if isinstance(y, int):
        y = torch.tensor(y, device="cuda")
    return torch.maximum(x, y)
|
||||
|
||||
@_tensor_operation
def sigmoid(self, x):
    # Not emulated in this interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def softmax(self, x, ieee_rounding=False):
    # Not emulated in this interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def ravel(self, x):
    # Not emulated in this interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def swizzle2d(self, i, j, size_i, size_j, size_g):
    # Not emulated in this interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def zeros_like(self, input):
    # Not emulated in this interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def max(self, input, axis=None):
    """Global max when axis is None, otherwise max along `axis` (values only)."""
    if axis is not None:
        return torch.max(input, dim=axis).values
    return torch.max(input)
|
||||
|
||||
@_tensor_operation
def argmax(self, input, axis):
    # Not emulated in this interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def min(self, input, axis=None):
    """Global min when axis is None, otherwise min along `axis` (values only)."""
    if axis is not None:
        return torch.min(input, dim=axis).values
    return torch.min(input)
|
||||
|
||||
@_tensor_operation
def argmin(self, input, axis):
    # Not emulated in this interpreter.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def sum(self, input, axis=None):
    """Global sum when axis is None, otherwise sum along `axis`."""
    if axis is not None:
        return torch.sum(input, dim=axis)
    return torch.sum(input)
|
||||
|
||||
@_tensor_operation
def xor_sum(self, input, axis):
    # XOR reduction has no direct torch equivalent here.
    raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
def cumsum(self, input, axis=None):
    """Inclusive cumulative sum.

    BUG FIX: torch.cumsum has no default `dim`, so the original
    ``torch.cumsum(input)`` in the axis-is-None branch always raised
    TypeError. Follow numpy's convention instead: with no axis, scan
    over the flattened tensor.
    """
    if axis is None:
        return torch.cumsum(input.flatten(), dim=0)
    return torch.cumsum(input, dim=axis)
|
||||
|
||||
@_tensor_operation
def cumprod(self, input, axis=None):
    """Inclusive cumulative product.

    BUG FIX: torch.cumprod has no default `dim`, so the original
    ``torch.cumprod(input)`` in the axis-is-None branch always raised
    TypeError. Follow numpy's convention instead: with no axis, scan
    over the flattened tensor.
    """
    if axis is None:
        return torch.cumprod(input.flatten(), dim=0)
    return torch.cumprod(input, dim=axis)
|
||||
@@ -1,18 +0,0 @@
|
||||
try:
    import torch as _torch
except ImportError:
    _torch = None


class TorchWrapper:
    """Lazy proxy around the ``torch`` module.

    Makes PyTorch an optional dependency: imports are attempted once at
    module load, and a clear error is raised only when torch is actually
    used while missing.
    """

    def __getattr__(self, name):
        if _torch is not None:
            return getattr(_torch, name)
        raise ImportError("Triton requires PyTorch to be installed")


torch = TorchWrapper()
|
||||
@@ -542,14 +542,15 @@ class tensor:
|
||||
self.shape = [constexpr(s) for s in self.shape]
|
||||
|
||||
def __str__(self) -> str:
|
||||
# ex. "float32[3,4]"
|
||||
return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']'
|
||||
# ex. "float32[16, 32]"
|
||||
return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
|
||||
|
||||
@builtin
|
||||
def __add__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.add(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __radd__(self, other, _builder=None):
|
||||
return self.__add__(other, _builder=_builder)
|
||||
|
||||
@@ -558,6 +559,7 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.sub(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rsub__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.sub(other, self, _builder)
|
||||
@@ -567,6 +569,7 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.mul(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rmul__(self, other, _builder=None):
|
||||
return self.__mul__(other, _builder=_builder)
|
||||
|
||||
@@ -575,6 +578,7 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.truediv(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __rtruediv__(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.truediv(other, self, _builder)
|
||||
@@ -666,8 +670,6 @@ class tensor:
|
||||
else:
|
||||
return semantic.lshr(other, self, _builder)
|
||||
|
||||
# comparison operators
|
||||
|
||||
# >
|
||||
@builtin
|
||||
def __gt__(self, other, _builder=None):
|
||||
@@ -745,7 +747,7 @@ class tensor:
|
||||
slices = [slices]
|
||||
ret = self
|
||||
for dim, sl in enumerate(slices):
|
||||
if isinstance(sl, constexpr) and sl.value is None:
|
||||
if sl is None or isinstance(sl, constexpr) and sl.value is None:
|
||||
ret = semantic.expand_dims(ret, dim, _builder)
|
||||
elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
|
||||
pass
|
||||
@@ -830,6 +832,8 @@ def arange(start, end, _builder=None):
|
||||
def _shape_check_impl(shape):
|
||||
shape = _constexpr_to_value(shape)
|
||||
for i, d in enumerate(shape):
|
||||
if isinstance(d, int):
|
||||
d = constexpr(d)
|
||||
if not isinstance(d, constexpr):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr`")
|
||||
if not isinstance(d.value, int):
|
||||
|
||||
@@ -1570,6 +1570,8 @@ def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno:
|
||||
|
||||
|
||||
def _convert_elem_to_ir_value(builder, elem, require_i64):
|
||||
if isinstance(elem, int):
|
||||
elem = tl.constexpr(elem)
|
||||
if isinstance(elem, tl.constexpr):
|
||||
return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value)
|
||||
elif isinstance(elem, tl.tensor):
|
||||
|
||||
@@ -160,7 +160,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
else:
|
||||
return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
|
||||
else:
|
||||
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
|
||||
if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
|
||||
if core.constexpr(input.dtype.is_floating()):
|
||||
input = input.to(core.float32)
|
||||
else:
|
||||
|
||||
525
python/triton/runtime/interpreter.py
Normal file
525
python/triton/runtime/interpreter.py
Normal file
@@ -0,0 +1,525 @@
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .._C.libtriton.triton import interpreter as _interpreter
|
||||
|
||||
|
||||
# TODO: duplicate
def str_to_ty(name):
    """Translate a signature type string (e.g. "*fp32", "i32") into the
    corresponding triton.language dtype; pointers recurse on the pointee."""
    language = tl
    if name.startswith("*"):
        return language.pointer_type(str_to_ty(name[1:]))
    return {
        "fp8e4nv": language.float8e4nv,
        "fp8e5": language.float8e5,
        "fp8e4b15": language.float8e4b15,
        "fp8e4b15x4": language.float8e4b15x4,
        "fp16": language.float16,
        "bf16": language.bfloat16,
        "fp32": language.float32,
        "fp64": language.float64,
        "i1": language.int1,
        "i8": language.int8,
        "i16": language.int16,
        "i32": language.int32,
        "i64": language.int64,
        "u8": language.uint8,
        "u16": language.uint16,
        "u32": language.uint32,
        "u64": language.uint64,
        "B": language.int1,
    }[name]
|
||||
|
||||
|
||||
class TensorHandle:
    """Minimal value wrapper used by the interpreter's Builder: a numpy
    array (`data`) tagged with its triton dtype (`dtype`)."""

    def __init__(self, data, dtype):
        self.data = data
        self.dtype = dtype

    def __bool__(self):
        # Truthy only when every element is non-zero (used for scalar conds).
        return bool(np.all(self.data))
|
||||
|
||||
|
||||
class BlockPointerHandle:
    """Interpreter-side handle for `tl.make_block_ptr` block pointers.

    Holds the scalar base pointer plus per-dimension shape/stride/offset
    TensorHandles, and can materialize the full grid of element addresses.
    """

    def __init__(self, base, shape, strides, offsets, tensor_shape, order):
        self.base = base                  # TensorHandle: scalar base address
        self.shape = shape                # per-dim logical extents (for bounds)
        self.strides = strides            # per-dim strides, in elements
        self.offsets = offsets            # per-dim starting offsets (mutated by advance)
        self.tensor_shape = tensor_shape  # shape of the materialized block
        self.order = order                # dim order hint; unused here

    def materialize_pointers(self, boundary_check):
        """Expand (base, strides, offsets) into a full array of byte
        addresses plus an in-bounds mask.

        Returns (ptrs, masks) where ptrs is a TensorHandle of uint64
        addresses shaped like `tensor_shape`, and masks is a bool ndarray
        that is False wherever a dim listed in `boundary_check` runs past
        `self.shape`.
        """
        dtype_tt = self.base.dtype.element_ty
        # Strides are in elements; convert to bytes via the element width.
        n_bytes = dtype_tt.primitive_bitwidth // 8
        tensor_shape = self.tensor_shape
        ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
        masks = np.ones(self.tensor_shape, dtype=bool)
        for dim in range(len(tensor_shape)):
            # Per-dim index vector reshaped so it broadcasts along `dim` only.
            bcast_dims = [1] * len(tensor_shape)
            bcast_dims[dim] = tensor_shape[dim]
            off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
            ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
            if dim in boundary_check:
                masks = np.logical_and(masks, off < self.shape[dim].data)
        ptrs = TensorHandle(ptrs, self.base.dtype)
        return ptrs, masks
|
||||
|
||||
|
||||
def wrap_ret(compute_ret_ty):
    """Decorator factory: re-tag a TensorHandle-returning function's result
    with the dtype computed by `compute_ret_ty(*args, **kwargs)`."""

    def wrapper(fn):

        def wrapped(*args, **kwargs):
            result = fn(*args, **kwargs)
            new_dtype = compute_ret_ty(*args, **kwargs)
            return TensorHandle(result.data, new_dtype)

        return wrapped

    return wrapper
|
||||
|
||||
|
||||
class Builder:
    """Eager, numpy-backed stand-in for the MLIR ir.builder.

    Each `create_*` method mirrors the signature of the real builder used by
    semantic analysis, but evaluates immediately on numpy arrays wrapped in
    TensorHandle instead of emitting IR. Loads/stores go through the native
    `_interpreter` module, which dereferences raw host addresses.
    """

    def __init__(self) -> None:
        # Queried by the compiler frontend; the interpreter targets no arch.
        self.arch = None

    def set_grid_idx(self, x, y, z):
        # Select the current program id; must lie inside the launch grid.
        assert x < self.grid_dim[0]
        assert y < self.grid_dim[1]
        assert z < self.grid_dim[2]
        self.grid_idx = (x, y, z)

    def set_grid_dim(self, nx, ny, nz):
        self.grid_dim = (nx, ny, nz)

    def np_dtype(self, tt_dtype):
        """Map a triton dtype to the numpy dtype used to store its values.

        Pointers are stored as raw 64-bit addresses. NOTE(review): int1,
        bfloat16 and the float8 variants are absent and raise KeyError —
        confirm whether they are expected to reach this path.
        """
        if isinstance(tt_dtype, tl.pointer_type):
            return np.dtype(np.uint64)
        np_types = {
            tl.float16: np.dtype(np.float16),
            tl.float32: np.dtype(np.float32),
            tl.float64: np.dtype(np.float64),
            tl.int8: np.dtype(np.int8),
            tl.uint8: np.dtype(np.uint8),
            tl.int16: np.dtype(np.int16),
            tl.uint16: np.dtype(np.uint16),
            tl.int32: np.dtype(np.int32),
            tl.uint32: np.dtype(np.uint32),
            tl.int64: np.dtype(np.int64),
            tl.uint64: np.dtype(np.uint64),
        }
        return np_types[tt_dtype]

    # constants
    def get_half_ty(self):
        return tl.float16

    def get_float_ty(self):
        return tl.float32

    def get_int64_ty(self):
        return tl.int64

    def get_ptr_ty(self, elt_ty, addr_space):
        return tl.pointer_type(elt_ty, addr_space)

    def get_block_ty(self, dtype, shape):
        # NOTE(review): returns tl.tensor(shape, dtype) where the real
        # builder returns a block *type* — confirm callers' expectations.
        return tl.tensor(shape, dtype)

    def get_int32(self, value):
        return TensorHandle(np.array([value], dtype=np.int32), tl.int32)

    def get_int64(self, value):
        return TensorHandle(np.array([value], dtype=np.int64), tl.int64)

    def get_fp16(self, value):
        return TensorHandle(np.array([value], dtype=np.float16), tl.float16)

    def get_fp32(self, value):
        return TensorHandle(np.array([value], dtype=np.float32), tl.float32)

    def get_null_value(self, type):
        return TensorHandle(np.array([0], dtype=self.np_dtype(type)), type)

    # programming model
    def create_get_program_id(self, axis):
        assert self.grid_idx is not None
        return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32)

    def create_get_num_programs(self, axis):
        return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32)

    # memory ops
    def create_load(self, ptr, _0, _1, is_volatile):
        # Unmasked load == masked load with an all-true mask.
        mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
        other = None
        return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile)

    def create_store(self, ptr, val, _0, _1):
        # Unmasked store == masked store with an all-true mask.
        mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
        return self.create_masked_store(ptr, val, mask, None, None)

    def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
        dtype_tt = ptrs.dtype.element_ty
        dtype_np = self.np_dtype(dtype_tt)
        if other is None:
            # Arbitrary fill for masked-off lanes; never observed by kernels
            # that respect their own masks.
            other = TensorHandle(np.ones_like(ptrs.data, dtype=dtype_np), dtype_tt)
        ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np)
        return TensorHandle(ret, dtype_tt)

    def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
        return _interpreter.store(ptrs.data, value.data, mask.data)

    # casting ops
    def cast_impl(self, src, dst_type):
        # Accept either a dtype or a tensor whose dtype we should match.
        if isinstance(dst_type, tl.tensor):
            dst_type = dst_type.dtype
        return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type)

    create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)

    def create_fp_to_fp(self, src, dst_type):
        # BUG FIX: the original `assert "float8 not NotImplemented yet"`
        # asserted on a non-empty string, so it never fired and the method
        # silently returned None. Fail loudly instead.
        raise NotImplementedError("float8 conversions not implemented in interpreter mode")

    def create_bitcast(self, src, dst_type):
        # Reinterpret the raw bytes under the destination dtype.
        return TensorHandle(src.data.view(self.np_dtype(dst_type)), dst_type)

    # binary operators
    def binary_op(self, lhs, rhs, op):
        # Result dtype is taken from the LHS operand.
        return TensorHandle(op(lhs.data, rhs.data), lhs.dtype)

    create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
    create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
    create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
    create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
    create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
    create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
    create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide)
    create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide)
    create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
    create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
    create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
    create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
    create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
    # NOTE(review): both lshr and ashr map to np.right_shift, whose
    # arithmetic/logical behavior follows the operand's signedness — a
    # logical shift on signed data is not faithfully reproduced.
    create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
    create_ashr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
    create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
    create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
    create_minf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
    create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
    create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
    create_maxf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
    create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
    create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
    create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
    create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
    create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
    create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
    create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
    create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
    create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
    create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
    # NOTE(review): ordered and unordered float comparisons share the same
    # numpy op, so NaN semantics are not distinguished.
    create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
    create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
    create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
    create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
    create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
    create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
    create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
    create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
    create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
    create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
    create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
    create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
    create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
    create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
    create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)

    # ternary functions
    def ternary_op(self, lhs, rhs, other, op):
        # Result dtype is taken from the third operand.
        return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)

    create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)

    # unary functions
    def unary_op(self, arg, op):
        return TensorHandle(op(arg.data), arg.dtype)

    create_exp = lambda self, arg: self.unary_op(arg, np.exp)
    create_cos = lambda self, arg: self.unary_op(arg, np.cos)
    create_sin = lambda self, arg: self.unary_op(arg, np.sin)
    create_log = lambda self, arg: self.unary_op(arg, np.log)
    create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
    create_fabs = lambda self, arg: self.unary_op(arg, np.abs)
    create_iabs = lambda self, arg: self.unary_op(arg, np.abs)

    # tensor operators
    create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype)
    create_trans = lambda self, arg: self.unary_op(arg, np.transpose)

    def create_dot(self, a, b, d, allow_tf32, maxNumImpreciseAcc):
        # CLEANUP: a dead `create_dot = lambda ...` that this def shadowed
        # was removed. Matrix product with accumulator `d`; tf32/precision
        # knobs are ignored by the interpreter.
        return TensorHandle(np.dot(a.data, b.data) + d.data, a.dtype)

    def create_make_range(self, start, stop):
        return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)

    # pointer arithmetic
    def create_addptr(self, ptr, offset):
        # Offsets are in elements; scale by element width to get bytes.
        dtype_tt = ptr.dtype.element_ty
        return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)

    def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
        ptrs, masks = ptr.materialize_pointers(boundary_check)
        assert padding_option is None
        other = None
        return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile)

    def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
        ptrs, masks = ptr.materialize_pointers(boundary_check)
        return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy)

    def create_expand_dims(self, arg, axis):
        return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype)

    def create_broadcast(self, arg, shape):
        return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype)

    def create_int_to_ptr(self, val, dst_ty):
        return TensorHandle(val.data.astype(np.uint64), dst_ty)

    def create_splat(self, arg, shape):
        # Assumes `arg` is a scalar handle (single-element array).
        return TensorHandle(np.full(shape, arg.data[0], dtype=self.np_dtype(arg.dtype)), arg.dtype)

    # --- builder API not yet implemented in interpreter mode ---
    # def create_cat(self, lhs, rhs): ...
    # def create_atomic_cas(self, ptr, cmp, val, sem): ...
    # def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem): ...
    # def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): ...
    # def create_reduce(self, operands, axis): ...
    # def create_reduce_ret(self, args): ...
    # def create_scan(self, operands, axis): ...
    # def create_scan_ret(self, args): ...
    # def create_ptr_to_int(self, val, type): ...
    # def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): ...
    # def create_print(self, prefix, values): ...
    # def create_assert(self, condition, message, fileName, funcName, lineNo): ...
    # def create_undef(self, type): ...
    # def create_barrier(self): ...

    def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
        return BlockPointerHandle(base, shape, strides, np.array(offsets), tensor_shape, order)

    def create_advance(self, ptr, offsets):
        assert len(ptr.offsets) == len(offsets)
        # NOTE(review): the new handle shares `offsets` with `ptr`, so the
        # in-place += below also advances the source pointer — confirm
        # callers never reuse the original handle afterwards.
        ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, ptr.offsets, ptr.tensor_shape, ptr.order)
        for i in range(len(offsets)):
            ret.offsets[i].data += offsets[i].data
        return ret
|
||||
|
||||
|
||||
def patch_attr(obj, name, member, builder):
    """Rebind ``obj.name`` so that every call routes through `member` with
    the interpreter `builder` injected as the `_builder` keyword.

    A `_builder` supplied by the caller is discarded so the interpreter's
    builder always wins.
    """

    def new_member(*args, **kwargs):
        forwarded = {k: v for k, v in kwargs.items() if k != '_builder'}
        return member(*args, **forwarded, _builder=builder)

    setattr(obj, name, new_member)
|
||||
|
||||
|
||||
def _patch_lang_tensor(tensor, builder):
    """Monkey-patch every @builtin method of the tl.tensor class so it runs
    against the interpreter builder instead of requiring a real _builder."""
    for name, member in inspect.getmembers(tensor):
        if tl.core.is_builtin(member):
            patch_attr(tensor, name, member, builder)
    # Let interpreted tensors be used as Python indices (scalar handles only)
    # and in `if`/`while` conditions (always truthy — branches are executed
    # eagerly by the interpreter).
    tensor.__index__ = lambda self: int(self.handle.data)
    tensor.__bool__ = lambda self: True
|
||||
|
||||
|
||||
def _patch_lang_core(lang, builder):
    """Monkey-patch every @builtin free function in triton.language so it
    runs against the interpreter builder."""
    for name, member in inspect.getmembers(lang):
        if tl.core.is_builtin(member):
            patch_attr(lang, name, member, builder)
    # reduce is better off with a separate patch due to how
    # the builder currently interfaces with custom functions

    def _new_reduce(input, axis, combine_fn):
        # Dispatch on the combine function's *name*; only max and sum
        # reductions are supported — anything else raises KeyError.
        fn = combine_fn.fn.__name__
        mapping = {
            'maximum': np.max,
            '_sum_combine': np.sum,
        }
        ret = mapping[fn](input.handle.data, axis=axis)
        ret_type = tl.block_type(input.dtype, ret.shape)
        return tl.core.tensor(TensorHandle(ret, input.dtype), ret_type)

    lang.reduce = _new_reduce
|
||||
|
||||
|
||||
def _patch_lang_math(lang, builder):
|
||||
math = lang.math
|
||||
mapping = {
|
||||
'abs': 'abs',
|
||||
'acos': 'arccos',
|
||||
'asin': 'arcsin',
|
||||
'exp2': 'exp2',
|
||||
'log2': 'log2',
|
||||
'max': 'maximum',
|
||||
}
|
||||
|
||||
def make_numpy(name):
|
||||
def impl(*args, **kwargs):
|
||||
ret_type = args[0].type # TODO: incorrect
|
||||
ret_dtype = args[0].dtype # TODO: incorrect
|
||||
args = [arg.handle.data for arg in args]
|
||||
kwargs = {k: v.handle.data for k, v in kwargs.items()}
|
||||
ret = getattr(np, mapping[name])(*args, **kwargs)
|
||||
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
|
||||
return ret
|
||||
return impl
|
||||
|
||||
def make_fallback(name):
|
||||
def fallback(*args, **kwargs):
|
||||
raise NotImplementedError(f"""
|
||||
{name} not supported in interpreter mode: no known numpy implementation.
|
||||
If you think that {name} in fact does have a numpy implementation, please add it
|
||||
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
|
||||
""")
|
||||
return fallback
|
||||
|
||||
for name, member in inspect.getmembers(math):
|
||||
if name in mapping:
|
||||
setattr(math, name, make_numpy(name))
|
||||
else:
|
||||
setattr(math, name, make_fallback(name))
|
||||
|
||||
|
||||
# TODO: wrap everything in triton tensors
def _implicit_cvt(arg):
    """Convert a kernel argument into an interpreter tl.tensor.

    Python ints become int32 scalar handles (NOTE(review): values outside
    int32 range would overflow — confirm large ints never reach here);
    anything with `data_ptr` (torch tensors and the like) becomes a uint64
    scalar holding its device address. Other arguments pass through as-is.
    """
    if isinstance(arg, int):
        ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
        handle = TensorHandle(np.array([arg], dtype=np.int32), ty)
        return tl.tensor(handle, ty)
    if hasattr(arg, 'data_ptr'):
        ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
        handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
        return tl.tensor(handle, ty)
    return arg
|
||||
|
||||
|
||||
def _unwrap(tensor):
    """Return the underlying tensor when given a triton.TensorWrapper,
    otherwise pass the argument through unchanged."""
    return tensor.base if isinstance(tensor, triton.TensorWrapper) else tensor
|
||||
|
||||
|
||||
# Single process-wide builder shared by every interpreted kernel launch.
builder = Builder()

# Launch-time kwargs that only make sense for the real GPU backend; they are
# stripped before calling the interpreted function.
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization']
|
||||
|
||||
|
||||
class GridExecutor:
    """Callable that runs an interpreted kernel serially over a launch grid.

    Created by InterpretedFunction; copies device tensors to host, patches
    triton.language to the interpreter builder, iterates every program id,
    then copies results back to propagate side effects.
    """

    def __init__(self, fn, arg_names, grid):
        from .jit import _normalize_ty  # TODO: modularize
        self.fn = fn
        self.arg_names = arg_names
        self.grid = grid
        # Identify constexpr parameters from the function's annotations so
        # they are passed through untouched (not converted to tensors).
        __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
        self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']

    def _patch_lang(self, builder):
        # Find the triton.language module object visible from the kernel's
        # globals; exactly one of tl / tl.core must be imported there.
        lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
        assert len(lang) == 1, "triton.language must be visible from within jit'd function"
        _patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
        _patch_lang_core(lang[0], builder)
        _patch_lang_math(lang[0], builder)

    def __call__(self, *args_dev, **kwargs):
        # we need to copy arguments to the host for the interpreter
        args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
        # removes reserved keywords from kwargs
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
        # remaps core language functions to interpreted ones
        self._patch_lang(builder)
        # implicitly convert tensor arguments to their base pointers
        args = inspect.getcallargs(self.fn, *args_hst, **kwargs)
        args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
        # iterate through grid
        grid = self.grid(args) if callable(self.grid) else self.grid
        assert len(grid) <= 3
        # pad the grid to 3 dimensions, the programming-model maximum
        grid = grid + (1,) * (3 - len(grid))
        builder.set_grid_dim(*grid)
        for x in range(grid[0]):
            for y in range(grid[1]):
                for z in range(grid[2]):
                    builder.set_grid_idx(x, y, z)
                    self.fn(**args)
        # copy arguments back to propagate side-effects
        for arg_dev, arg_hst in zip(args_dev, args_hst):
            if hasattr(arg_dev, 'data_ptr'):
                _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
|
||||
|
||||
|
||||
class InterpretedFunction:
    """Drop-in replacement for JITFunction used when TRITON_INTERPRET=1.

    Supports the `kernel[grid](...)` launch syntax, the `kernel.run(...,
    grid=...)` path, and direct calls (for jit'd functions called from
    inside another interpreted kernel).
    """

    def _patch_lang(self, builder):
        # NOTE(review): unlike GridExecutor._patch_lang this does not patch
        # tl.math — confirm whether that is intentional for nested calls.
        lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
        assert len(lang) == 1, "triton.language must be visible from within jit'd function"
        _patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
        _patch_lang_core(lang[0], builder)

    def __init__(self, fn) -> None:
        self.fn = fn

        def run(*args, **kwargs):
            # Launch path used by `fn.run(..., grid=...)`; `grid` is
            # mandatory here and stripped along with backend-only kwargs.
            grid = kwargs['grid']
            kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}

            return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
        self.run = run
        signature = inspect.signature(fn)
        self.arg_names = [v.name for v in signature.parameters.values()]

    def __getitem__(self, grid):
        # `kernel[grid](args...)` launch syntax.
        return GridExecutor(self.fn, self.arg_names, grid)

    def __call__(self, *args, **kwargs):
        # Direct call: run the Python body with the language patched, no grid.
        self._patch_lang(builder)
        return self.fn(*args, **kwargs)
|
||||
@@ -14,6 +14,7 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union,
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..language.core import dtype
|
||||
from .interpreter import InterpretedFunction
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
@@ -270,10 +271,6 @@ class JITFunction(KernelInterface[T]):
|
||||
tys[v] = v
|
||||
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
|
||||
|
||||
def _make_signature(self, sig_key):
|
||||
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
|
||||
return signature
|
||||
|
||||
def _make_constants(self, constexpr_key):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
@@ -568,7 +565,6 @@ def jit(
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
debug: Optional[bool] = None,
|
||||
noinline: Optional[bool] = None,
|
||||
interpret: Optional[bool] = None,
|
||||
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
@@ -590,9 +586,8 @@ def jit(
|
||||
|
||||
def decorator(fn: T) -> JITFunction[T]:
|
||||
assert callable(fn)
|
||||
if interpret:
|
||||
from ..interpreter.interpreter import GridSelector
|
||||
return GridSelector(fn)
|
||||
if os.getenv("TRITON_INTERPRET", "0") == "1":
|
||||
return InterpretedFunction(fn)
|
||||
else:
|
||||
return JITFunction(
|
||||
fn,
|
||||
|
||||
Reference in New Issue
Block a user