[FRONTEND] cleaned up language; added frontend function for globaltimer special register (#1525)

This commit is contained in:
Philippe Tillet
2023-04-14 15:29:27 -07:00
committed by GitHub
parent 0d76c4ca95
commit e5c7d2a83c
21 changed files with 1299 additions and 1194 deletions

View File

@@ -420,14 +420,13 @@ def TT_ReduceReturnOp: TT_Op<"reduce.return",
//
// External Elementwise op
//
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let summary = "ext_elemwise";
class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
TT_Op<mnemonic,
traits # [SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
return $libpath/$libname:$symbol($args...)
}];
@@ -435,7 +434,17 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAnd
let results = (outs TT_Type:$result);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
}
def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> {
let summary = "FFI for pure element-wise extern LLVM bitcode functions";
}
def TT_ImpureExternElementwiseOp : TT_ExternElementwiseOpBase<"impure_extern_elementwise", [MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>]> {
let summary = "FFI for impure element-wise extern LLVM bitcode functions";
}
//

View File

@@ -734,20 +734,20 @@ struct CmpFOpConversion
}
};
struct ExtElemwiseOpConversion
: public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
ExtElemwiseOpConversion> {
using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
ExtElemwiseOpConversion>;
template <class T>
struct ExternElementwiseOpConversion
: public ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>> {
using Base = ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
typedef typename Base::OpAdaptor OpAdaptor;
Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
Value createDestOp(T op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
StringRef funcName = op.getSymbol();
if (funcName.empty())
llvm::errs() << "ExtElemwiseOpConversion";
llvm::errs() << "ExternElementwiseOpConversion";
Type funcType = getFunctionType(elemTy, operands);
LLVM::LLVMFuncOp funcOp =
@@ -761,8 +761,7 @@ private:
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
triton::ExtElemwiseOp op,
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter, T op,
StringRef funcName, Type funcType) const {
using LLVM::LLVMFuncOp;
@@ -771,7 +770,8 @@ private:
if (funcOp)
return cast<LLVMFuncOp>(*funcOp);
mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
auto parent = ((Operation *)op)->getParentOfType<mlir::LLVM::LLVMFuncOp>();
mlir::OpBuilder b(parent);
auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
ret.getOperation()->setAttr(
"libname", StringAttr::get(op->getContext(), op.getLibname()));
@@ -1117,7 +1117,11 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
patterns.add<ExternElementwiseOpConversion<triton::PureExternElementwiseOp>>(
typeConverter, benefit);
patterns
.add<ExternElementwiseOpConversion<triton::ImpureExternElementwiseOp>>(
typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call

View File

@@ -411,14 +411,16 @@ struct TritonAtomicRMWPattern
}
};
struct TritonExtElemwisePattern
: public OpConversionPattern<triton::ExtElemwiseOp> {
using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;
template <class T>
struct TritonExternElementwisePattern : public OpConversionPattern<T> {
using OpConversionPattern<T>::OpConversionPattern;
using OpConversionPattern<T>::typeConverter;
typedef typename OpConversionPattern<T>::OpAdaptor OpAdaptor;
LogicalResult
matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor,
matchAndRewrite(T op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
addNamedAttrs(rewriter.replaceOpWithNewOp<T>(
op, typeConverter->convertType(op.getType()),
adaptor.getArgs(), adaptor.getLibname(),
adaptor.getLibpath(), adaptor.getSymbol()),
@@ -539,7 +541,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
TritonReducePattern, TritonReduceReturnPattern, TritonTransPattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
TritonLoadPattern, TritonStorePattern,
TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>(
typeConverter, context);
}

View File

@@ -40,6 +40,9 @@ OpTrait::impl::verifySameOperandsEncoding(Operation *op,
LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding(
Operation *op, bool allowTensorPointerType) {
if (op->getNumOperands() == 0)
return success();
if (failed(verifyAtLeastNOperands(op, 1)) ||
failed(verifyAtLeastNResults(op, 1)))
return failure();

View File

@@ -234,9 +234,10 @@ static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
if (!isROCM) {
if (name == "libdevice") {
linkLibdevice(module);
} else {
assert(false && "unknown extern lib: ");
}
// else {
// assert(false && "unknown extern lib: ");
// }
}
return false;

View File

@@ -1,3 +1,4 @@
graft src
graft triton/third_party
graft triton/runtime/backends/
graft triton/language/extra

View File

@@ -224,6 +224,7 @@ setup(
"triton/common",
"triton/compiler",
"triton/language",
"triton/language/extra",
"triton/ops",
"triton/ops/blocksparse",
"triton/runtime",

12
python/src/extra/cuda.ll Normal file
View File

@@ -0,0 +1,12 @@
; ~/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/bin/llvm-as ./src/extra/cuda.ll -o ./triton/language/extra/cuda.bc
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "nvptx64-nvidia-cuda"
; Wrapper that reads the PTX %globaltimer special register into an i64.
; `sideeffect` on the inline asm keeps repeated reads from being CSE'd,
; so two calls bracketing a region really produce two timer reads.
define i64 @globaltimer() #0 {
  %1 = call i64 asm sideeffect "mov.u64 $0, %globaltimer;", "=l"() nounwind
  ret i64 %1
}
attributes #0 = { alwaysinline nounwind }

View File

@@ -1267,14 +1267,18 @@ void init_triton_ir(py::module &&m) {
ptr, val, mask);
})
// External
.def("create_external_elementwise",
.def("create_extern_elementwise",
[](mlir::OpBuilder &self, const std::string &libName,
const std::string &libPath, const std::string &symbol,
std::vector<mlir::Value> &argList,
mlir::Type retType) -> mlir::Value {
std::vector<mlir::Value> &argList, mlir::Type retType,
bool isPure) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::ExtElemwiseOp>(
loc, retType, argList, libName, libPath, symbol);
if (isPure)
return self.create<mlir::triton::PureExternElementwiseOp>(
loc, retType, argList, libName, libPath, symbol);
else
return self.create<mlir::triton::ImpureExternElementwiseOp>(
loc, retType, argList, libName, libPath, symbol);
})
// Built-in instruction
.def("create_get_program_id",

View File

@@ -2312,12 +2312,35 @@ def test_while():
# print(m[0])
# print(n[0])
# -----------------------
# test extra
# -----------------------
def test_globaltimer():
    # Checks that tl.extra.cuda.globaltimer() works end-to-end: the elapsed
    # time across the store loop is positive, and the %globaltimer register
    # actually appears in the generated PTX.
    @triton.jit
    def kernel(Out1, Out2):
        start = tl.extra.cuda.globaltimer()
        off = tl.arange(0, 128)
        for i in range(100):
            tl.store(Out1 + off, tl.load(Out1 + off) + 1)
        end = tl.extra.cuda.globaltimer()
        tl.store(Out2, end - start)

    out1 = to_triton(np.zeros((128,), dtype=np.int64), device='cuda')
    out2 = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
    h = kernel[(1,)](out1, out2)
    # the timer must have advanced across the loop
    assert out2[0] > 0
    # 2 inlined globaltimers + one extra in the wrapper extern function
    assert h.asm["ptx"].count("%globaltimer") == 3
# -----------------------
# test layout conversions
# -----------------------
# TODO: backend should be tested separately
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),

View File

@@ -2,6 +2,16 @@
# Import order is significant here.
from . import math
from . import extra
from .standard import (
cdiv,
sigmoid,
softmax,
ravel,
swizzle2d,
zeros,
zeros_like,
)
from .core import (
abs,
advance,
@@ -21,7 +31,6 @@ from .core import (
broadcast,
broadcast_to,
cat,
cdiv,
constexpr,
cos,
debug_barrier,
@@ -56,18 +65,14 @@ from .core import (
pi32_t,
pointer_type,
program_id,
ravel,
reduce,
reshape,
sigmoid,
sin,
softmax,
sqrt,
static_assert,
static_print,
store,
sum,
swizzle2d,
static_range,
tensor,
trans,
@@ -81,8 +86,6 @@ from .core import (
void,
where,
xor_sum,
zeros,
zeros_like,
)
from .random import (
pair_uniform_to_normal,
@@ -127,6 +130,7 @@ __all__ = [
"dot",
"dtype",
"exp",
"extra",
"fdiv",
"float16",
"float32",

View File

@@ -994,10 +994,36 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
@builtin
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
    """
    Returns a pointer to a block in a parent tensor

    :param base: The base pointer to the parent tensor
    :param shape: The shape of the parent tensor
    :param strides: The strides of the parent tensor
    :param offsets: The offsets to the block
    :param block_shape: The shape of the block
    :param order: The order of the original data format
    :return: the block pointer produced by the semantic layer (see also :func:`advance`)
    """
    # Thin builtin wrapper: all validation/IR construction lives in `semantic`.
    return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
@builtin
def advance(base: tensor, offsets, _builder=None):
    """
    Advance a block pointer

    :param base: the block pointer to advance
    :param offsets: the offsets to advance, a tuple by dimension
    :return: the advanced block pointer
    """
    # Thin builtin wrapper: delegates to the semantic layer.
    return semantic.advance(base, offsets, _builder)
# -----------------------
# Atomic Memory Operations
# -----------------------
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
def _decorator(func: T) -> T:
@@ -1080,7 +1106,6 @@ def atomic_xor(pointer, val, mask=None, _builder=None):
# Conditioning
# -----------------------
@builtin
def where(condition, x, y, _builder=None):
"""
@@ -1266,6 +1291,32 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
return rindices
@triton.jit
def minimum(x, y):
    """
    Computes the element-wise minimum of :code:`x` and :code:`y`.

    :param x: the first input tensor
    :type x: Block
    :param y: the second input tensor
    :type y: Block
    """
    return where(x < y, x, y)
@triton.jit
def maximum(x, y):
    """
    Computes the element-wise maximum of :code:`x` and :code:`y`.

    :param x: the first input tensor
    :type x: Block
    :param y: the second input tensor
    :type y: Block
    """
    return where(x > y, x, y)
@triton.jit
def _max_combine(a, b):
    # Reduction combine function: keeps the larger of the two elements.
    return maximum(a, b)
@@ -1395,127 +1446,6 @@ def max_contiguous(input, values, _builder=None):
values = [x.value for x in values]
return semantic.max_contiguous(input, values)
# -----------------------
# Standard library
# -----------------------
@triton.jit
def cdiv(x, div):
"""
Computes the ceiling division of :code:`x` by :code:`div`
:param x: the input number
:type input: Block
:param div: the divisor
:param div: Block
"""
return (x + div - 1) // div
@triton.jit
def minimum(x, y):
"""
Computes the element-wise minimum of :code:`x` and :code:`y`.
:param input: the first input tensor
:type input: Block
:param other: the second input tensor
:type other: Block
"""
return triton.language.where(x < y, x, y)
@triton.jit
def maximum(x, y):
"""
Computes the element-wise maximum of :code:`x` and :code:`y`.
:param input: the first input tensor
:type input: Block
:param other: the second input tensor
:type other: Block
"""
return triton.language.where(x > y, x, y)
@triton.jit
@_add_math_1arg_docstr("sigmoid")
def sigmoid(x):
return 1 / (1 + triton.language.exp(-x))
@triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
z = x - triton.language.max(x, 0)
num = triton.language.exp(z)
den = triton.language.sum(num, 0)
return fdiv(num, den, ieee_rounding)
@triton.jit
def ravel(x):
"""
Returns a contiguous flattened view of :code:`x`.
:param x: the input tensor
:type x: Block
"""
return triton.language.view(x, [x.numel])
@triton.jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms indices of a row-major size_i*size_j matrix into those
    of one where indices are row major for each group of size_j rows.
    For example, for size_i = size_j = 4 and size_g = 2, it will transform
    [[0 , 1 , 2 , 3 ],
     [4 , 5 , 6 , 7 ],
     [8 , 9 , 10, 11],
     [12, 13, 14, 15]]
    into
    [[0, 2, 4 , 6 ],
     [1, 3, 5 , 7 ],
     [8, 10, 12, 14],
     [9, 11, 13, 15]]
    """
    # "unrolled index in array"
    ij = i * size_j + j
    # number of elements in `size_g` groups
    # of `size_j` columns
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows
    size_g = minimum(size_i - off_i, size_g)
    # New row and column indices. The row offset must come from the index
    # *within the group* (ij % size_gj): using `ij` directly is wrong once
    # `size_g` has been clamped for the last group, because the group's base
    # index is then not a multiple of the clamped `size_g`.
    new_i = off_i + (ij % size_gj) % size_g
    new_j = (ij % size_gj) // size_g
    return new_i, new_j
@triton.jit
def zeros(shape, dtype):
"""
Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
:type dtype: DType
"""
return full(shape, 0, dtype)
@triton.jit
def zeros_like(input):
return zeros(input.shape, input.dtype)
# -----------------------
# Debugging functions
# -----------------------
@@ -1568,32 +1498,6 @@ def device_assert(cond, msg="", _builder=None):
return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
@builtin
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
"""
Returns a pointer to a block in a parent tensor
:param base: The base pointer to the parent tensor
:param shape: The shape of the parent tensor
:param strides: The strides of the parent tensor
:param offsets: The offsets to the block
:param block_shape: The shape of the block
:param order: The order of the original data format
"""
return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
@builtin
def advance(base: tensor, offsets, _builder=None):
"""
Advance a block pointer
:param base: the block pointer to advance
:param offsets: the offsets to advance, a tuple by dimension
"""
return semantic.advance(base, offsets, _builder)
# -----------------------
# Iterators
# -----------------------
@@ -1623,3 +1527,86 @@ class static_range:
def __next__(self):
raise RuntimeError("static_range can only be used in @triton.jit'd functions")
# -----------------------
# Extern functions
# -----------------------
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
    '''
    Dispatch a function to a library.

    :param func: the builder function used to create the extern-elementwise op
    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: maps tuples of argument types to (symbol, ret_type)
    :param ret_shape: the shape of the return value
    :param is_pure: whether the external function is side-effect free
    :param _builder: the builder
    :return: the return value of the function
    '''
    if len(arg_type_symbol_dict) == 0:
        raise ValueError("arg_type_symbol_dict is empty")

    # Every signature in the dict has the same arity; take it from the first.
    num_args = len(list(arg_type_symbol_dict.keys())[0])
    if len(args) != num_args:
        # The expected count comes from the symbol table (num_args); the
        # actual count is len(args). These were previously swapped.
        raise ValueError(f"length of input args does not match. "
                         f"Expect {num_args}, got {len(args)}")

    arg_types = []
    arg_list = []
    for arg in args:
        if isinstance(arg, tensor):
            arg_types.append(arg.dtype)
            arg_list.append(arg.handle)
        else:
            # Scalars/constexprs are matched against the table by Python type.
            arg_types.append(type(arg))
            arg_list.append(arg)
    arg_types = tuple(arg_types)

    if arg_types not in arg_type_symbol_dict:
        raise ValueError(f"input arg type does not match. "
                         f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
    else:
        symbol = arg_type_symbol_dict[arg_types][0]
        ret_type = arg_type_symbol_dict[arg_types][1]
        if ret_shape:
            ret_type = block_type(ret_type, ret_shape)
        return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
    '''
    Dispatch an elementwise function to a library

    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: the type of the arguments
    :param is_pure: whether the external function is side-effect free
    :param _builder: the builder
    :return: the return value of the function
    '''
    # Convert every argument to a tensor up front.
    converted = [_to_tensor(arg, _builder) for arg in args]
    has_block = any(item.type.is_block() for item in converted)
    ret_shape = None
    if has_block:
        # First pass: fold all arguments into a common broadcast type.
        common = converted[0]
        for item in converted:
            _, common = semantic.binary_op_type_checking_impl(
                item, common, _builder)
        # Second pass: broadcast each argument to that common shape.
        for idx in range(len(converted)):
            converted[idx], _ = semantic.binary_op_type_checking_impl(
                converted[idx], common, _builder)
        ret_shape = common.shape
    builder_fn = _builder.create_extern_elementwise
    return dispatch(builder_fn, lib_name, lib_path, converted, arg_type_symbol_dict, ret_shape, is_pure, _builder)
def extern(fn):
    """A decorator for external functions. Currently an alias for :func:`builtin`."""
    return builtin(fn)

View File

@@ -1,82 +0,0 @@
from __future__ import annotations # remove after python 3.11
from . import core, semantic
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None):
    '''
    Dispatch a function to a library.

    :param func: the builder function used to create the extern-elementwise op
    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: maps tuples of argument types to (symbol, ret_type)
    :param ret_shape: the shape of the return value
    :param _builder: the builder
    :return: the return value of the function
    '''
    if len(arg_type_symbol_dict) == 0:
        raise ValueError("arg_type_symbol_dict is empty")

    # Every signature in the dict has the same arity; take it from the first.
    num_args = len(list(arg_type_symbol_dict.keys())[0])
    if len(args) != num_args:
        # The expected count comes from the symbol table (num_args); the
        # actual count is len(args). These were previously swapped.
        raise ValueError(f"length of input args does not match. "
                         f"Expect {num_args}, got {len(args)}")

    arg_types = []
    arg_list = []
    for arg in args:
        if isinstance(arg, core.tensor):
            arg_types.append(arg.dtype)
            arg_list.append(arg.handle)
        else:
            # Scalars/constexprs are matched against the table by Python type.
            arg_types.append(type(arg))
            arg_list.append(arg)
    arg_types = tuple(arg_types)

    if arg_types not in arg_type_symbol_dict:
        raise ValueError(f"input arg type does not match. "
                         f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
    else:
        symbol = arg_type_symbol_dict[arg_types][0]
        ret_type = arg_type_symbol_dict[arg_types][1]
        if ret_shape:
            ret_type = core.block_type(ret_type, ret_shape)
        return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type)
def extern(fn):
    """A decorator for external functions. Currently an alias for :func:`core.builtin`."""
    return core.builtin(fn)
def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None):
    '''
    Dispatch an elementwise function to a library

    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: the type of the arguments
    :param _builder: the builder
    :return: the return value of the function
    '''
    dispatch_args = args.copy()
    all_scalar = True
    ret_shape = None
    # Convert every argument to a tensor and remember whether any is a block.
    for i in range(len(dispatch_args)):
        dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
        if dispatch_args[i].type.is_block():
            all_scalar = False
    if not all_scalar:
        broadcast_arg = dispatch_args[0]
        # Get the broadcast shape over all the arguments
        for i, item in enumerate(dispatch_args):
            _, broadcast_arg = semantic.binary_op_type_checking_impl(
                item, broadcast_arg, _builder)
        # Change the shape of each argument based on the broadcast shape
        for i in range(len(dispatch_args)):
            dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
                dispatch_args[i], broadcast_arg, _builder)
        ret_shape = broadcast_arg.shape
    # Builder entry point that creates the extern-elementwise op in the IR.
    func = getattr(_builder, "create_external_elementwise")
    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)

View File

@@ -0,0 +1,3 @@
from . import cuda
__all__ = ['cuda']

Binary file not shown.

View File

@@ -0,0 +1,12 @@
import os
from .. import core
__path__ = os.path.dirname(os.path.abspath(__file__))
@core.extern
def globaltimer(_builder=None):
    # Map the empty-argument signature to the `globaltimer` symbol in the
    # bundled cuda.bc, returning int64. Marked impure (is_pure=False) so the
    # timer read is not optimized away or reordered as a pure value.
    arg_type_symbol_dict = {(): ("globaltimer", core.dtype("int64"))}
    return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"),
                                   [], arg_type_symbol_dict,
                                   is_pure=False, _builder=_builder)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from ..runtime.jit import jit
from . import core
# -----------------------
# Standard library
# -----------------------
@jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`

    :param x: the input number
    :type x: Block
    :param div: the divisor
    :type div: Block
    """
    return (x + div - 1) // div
@jit
@core._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    # 1 / (1 + e^-x)
    return 1 / (1 + core.exp(-x))
@jit
@core._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
    # Subtract the max along axis 0 before exponentiating so exp() cannot
    # overflow; this does not change the softmax result.
    z = x - core.max(x, 0)
    num = core.exp(z)
    den = core.sum(num, 0)
    return core.fdiv(num, den, ieee_rounding)
@jit
def ravel(x):
    """
    Returns a contiguous flattened view of :code:`x`.

    :param x: the input tensor
    :type x: Block
    """
    # A view with a single flat dimension of x.numel elements.
    return core.view(x, [x.numel])
@jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms indices of a row-major size_i*size_j matrix into those
    of one where indices are row major for each group of size_j rows.
    For example, for size_i = size_j = 4 and size_g = 2, it will transform
    [[0 , 1 , 2 , 3 ],
     [4 , 5 , 6 , 7 ],
     [8 , 9 , 10, 11],
     [12, 13, 14, 15]]
    into
    [[0, 2, 4 , 6 ],
     [1, 3, 5 , 7 ],
     [8, 10, 12, 14],
     [9, 11, 13, 15]]
    """
    # "unrolled index in array"
    ij = i * size_j + j
    # number of elements in `size_g` groups
    # of `size_j` columns
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows
    size_g = core.minimum(size_i - off_i, size_g)
    # New row and column indices. The row offset must come from the index
    # *within the group* (ij % size_gj): using `ij` directly is wrong once
    # `size_g` has been clamped for the last group, because the group's base
    # index is then not a multiple of the clamped `size_g`.
    new_i = off_i + (ij % size_gj) % size_g
    new_j = (ij % size_gj) // size_g
    return new_i, new_j
@jit
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    :return: a zero-filled tensor of the requested shape and dtype
    """
    return full(shape, 0, dtype)
@jit
def zeros_like(input):
    # Zero-filled tensor with the same shape and dtype as `input`.
    return zeros(input.shape, input.dtype)

View File

@@ -61,6 +61,9 @@ class CudaDriver(DriverBase):
cls.instance = super(CudaDriver, cls).__new__(cls)
return cls.instance
def get_extern_path(self):
    # Directory of extra extern bitcode libraries shipped under third_party/cuda/lib.
    return os.path.join(self.third_party_dir(), "cuda", "lib")
def get_libdevice_path(self):
    # Location of NVIDIA's libdevice bitcode bundled under third_party/cuda/lib.
    return os.path.join(self.third_party_dir(), "cuda", "lib", "libdevice.10.bc")
@@ -68,6 +71,7 @@ class CudaDriver(DriverBase):
self.utils = CudaUtils()
self.backend = self.CUDA
self.libdevice_path = self.get_libdevice_path()
self.extern_path = self.get_extern_path()
# -----------------------------
# HIP

View File

@@ -0,0 +1,15 @@
class OutOfResources(Exception):
    """Raised when a kernel requires more of a hardware resource than is available."""

    def __init__(self, required, limit, name):
        msg = f'out of resource: {name}, '
        msg += f'Required: {required}, '
        msg += f'Hardware limit: {limit}'
        msg += '. Reducing block sizes or `num_stages` may help.'
        self.message = msg
        self.required = required
        self.limit = limit
        self.name = name
        super().__init__(self.message)

    def __reduce__(self):
        # Exceptions with a non-trivial __init__ signature need an explicit
        # __reduce__ so they can round-trip through pickle.
        return (type(self), (self.required, self.limit, self.name))

View File

@@ -156,6 +156,7 @@ class Libdevice(ExternLibrary):
'''
super().__init__("libdevice", path)
self._symbol_groups = {}
self.is_pure = True
@staticmethod
def _extract_symbol(line) -> Optional[Symbol]:
@@ -287,7 +288,7 @@ class Libdevice(ExternLibrary):
# def <op_name>(<args>, _builder=None):
# arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n"
import_str = "from . import core\n"
import_str += "from ..runtime import driver\n"
import_str += "import os\n"
@@ -295,13 +296,13 @@ class Libdevice(ExternLibrary):
header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", driver.libdevice_path)\n"
func_str = ""
for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n"
func_str += "@core.extern\n"
func_name_str = f"def {symbols[0].op_name}("
for arg_name in symbols[0].arg_names:
func_name_str += f"{arg_name}, "
func_name_str += "_builder=None):\n"
return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
return_str = f"\treturn core.extern_elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
for arg_name in symbols[0].arg_names:
return_str += f"{arg_name}, "
return_str += "], \n"
@@ -316,7 +317,8 @@ class Libdevice(ExternLibrary):
arg_type_symbol_dict_str += "}"
return_str += arg_type_symbol_dict_str
return_str += ", _builder)\n"
return_str += f", is_pure={self.is_pure}"
return_str += ", _builder=_builder)\n"
func_str += func_name_str + return_str + "\n"
file_str = import_str + header_str + func_str