mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03162023
This commit is contained in:
@@ -249,9 +249,12 @@ if(TRITON_BUILD_PYTHON_MODULE)
|
||||
${TRITON_LIBRARIES}
|
||||
)
|
||||
else()
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} z
|
||||
${TRITON_LIBRARIES}
|
||||
)
|
||||
# TODO: Figure out which target is sufficient to fix errors; triton is
|
||||
# apparently not enough
|
||||
link_libraries(stdc++fs)
|
||||
endif()
|
||||
|
||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||
|
||||
@@ -112,7 +112,8 @@ private:
|
||||
|
||||
/// The _divisibility_ information maps the `d`-th
|
||||
/// dimension to the largest power-of-two that
|
||||
/// divides the first element of all the values along it
|
||||
/// divides the first element of all groups of
|
||||
/// _contiguity_ values along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
@@ -123,6 +124,10 @@ private:
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
// would have divisibility [4, 1]
|
||||
// On the other hand:
|
||||
// [0, 1, 2, 0, 4, 5, 6, 7]
|
||||
// would have divisibility 1 because
|
||||
// _contiguity_=1
|
||||
DimVectorT divisibility;
|
||||
|
||||
/// The _constancy_ information maps the `d`-th
|
||||
|
||||
@@ -164,16 +164,16 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ConstantOpAxisInfoVisitor final
|
||||
: public AxisInfoVisitorImpl<arith::ConstantOp> {
|
||||
template <typename OpTy>
|
||||
class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
||||
public:
|
||||
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
|
||||
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
||||
|
||||
AxisInfo
|
||||
getAxisInfo(arith::ConstantOp op,
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
|
||||
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
|
||||
auto intAttr = op.getValue().template dyn_cast<IntegerAttr>();
|
||||
auto boolAttr = op.getValue().template dyn_cast<BoolAttr>();
|
||||
if (intAttr || boolAttr) {
|
||||
int64_t value{};
|
||||
if (intAttr)
|
||||
@@ -186,10 +186,10 @@ public:
|
||||
/*knownConstantValue=*/{value});
|
||||
}
|
||||
// TODO: generalize to dense attr
|
||||
auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
|
||||
auto splatAttr = op.getValue().template dyn_cast<SplatElementsAttr>();
|
||||
if (splatAttr && splatAttr.getElementType().isIntOrIndex()) {
|
||||
int64_t value = splatAttr.getSplatValue<APInt>().getZExtValue();
|
||||
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||
int64_t value = splatAttr.template getSplatValue<APInt>().getZExtValue();
|
||||
TensorType ty = splatAttr.getType().template cast<TensorType>();
|
||||
return AxisInfo(
|
||||
/*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1),
|
||||
/*divisibility=*/
|
||||
@@ -233,7 +233,8 @@ private:
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().has_value()) {
|
||||
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
|
||||
std::is_same_v<OpTy, triton::AddPtrOp>) {
|
||||
std::is_same_v<OpTy, triton::AddPtrOp> ||
|
||||
std::is_same_v<OpTy, LLVM::AddOp>) {
|
||||
return {lhs.getConstantValue().value() +
|
||||
rhs.getConstantValue().value()};
|
||||
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
|
||||
@@ -334,14 +335,11 @@ private:
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
lhs.getConstantValue().value() == 0)
|
||||
return lhs.getDivisibility(dim);
|
||||
// Case 2: rhs is constant
|
||||
if (rhs.getConstantValue().has_value()) {
|
||||
auto lhsDivisibility = lhs.getDivisibility(dim);
|
||||
auto rhsValue = rhs.getConstantValue().value();
|
||||
if (lhsDivisibility % rhsValue == 0)
|
||||
return lhsDivisibility / rhsValue;
|
||||
}
|
||||
// Case 3: both are not constant
|
||||
// Case 2: rhs is 1
|
||||
if (rhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().value() == 1)
|
||||
return lhs.getDivisibility(dim);
|
||||
// otherwise: return 1
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -815,11 +813,15 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
|
||||
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
|
||||
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
|
||||
CastOpAxisInfoVisitor<triton::BitcastOp>>();
|
||||
// TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
|
||||
// when scf.for supports integer induction variables
|
||||
visitors.append<MakeRangeOpAxisInfoVisitor>();
|
||||
visitors.append<ConstantOpAxisInfoVisitor>();
|
||||
visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
|
||||
ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
|
||||
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
|
||||
AddSubOpAxisInfoVisitor<arith::AddIOp>,
|
||||
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
|
||||
AddSubOpAxisInfoVisitor<arith::SubIOp>,
|
||||
AddSubOpAxisInfoVisitor<LLVM::AddOp>>();
|
||||
visitors.append<MulIOpAxisInfoVisitor>();
|
||||
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
|
||||
DivOpAxisInfoVisitor<arith::DivUIOp>>();
|
||||
|
||||
@@ -82,7 +82,7 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
return failure();
|
||||
ret = sliceEncoding.getParent();
|
||||
}
|
||||
if (auto view = dyn_cast<triton::ViewOp>(op)) {
|
||||
if (isa<triton::ViewOp, triton::CatOp>(op)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
@@ -179,8 +179,6 @@ int simulateBackwardRematerialization(
|
||||
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
|
||||
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
|
||||
continue;
|
||||
if (auto view = dyn_cast<triton::ViewOp>(opArgI))
|
||||
continue;
|
||||
|
||||
// We add one expensive conversion for the current operand
|
||||
numCvts += 1;
|
||||
|
||||
@@ -275,7 +275,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
||||
}
|
||||
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||
/*optLevel=*/3, /*sizeLevel=*/0,
|
||||
/*optLevel=*/0, /*sizeLevel=*/0,
|
||||
/*targetMachine=*/nullptr);
|
||||
|
||||
if (auto err = optPipeline(llvmModule.get())) {
|
||||
@@ -328,7 +328,6 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// llvm::outs() << module << "\n";
|
||||
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
|
||||
if (!llvmIR) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
|
||||
@@ -83,8 +83,8 @@ def get_thirdparty_packages(triton_cache_path):
|
||||
if p.syspath_var_name in os.environ:
|
||||
package_dir = os.environ[p.syspath_var_name]
|
||||
version_file_path = os.path.join(package_dir, "version.txt")
|
||||
if not os.path.exists(version_file_path) or\
|
||||
Path(version_file_path).read_text() != p.url:
|
||||
if p.syspath_var_name not in os.environ and\
|
||||
(not os.path.exists(version_file_path) or Path(version_file_path).read_text() != p.url):
|
||||
try:
|
||||
shutil.rmtree(package_root_dir)
|
||||
except Exception:
|
||||
|
||||
@@ -56,7 +56,7 @@ matmul_data = {
|
||||
'a100': {
|
||||
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
|
||||
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
|
||||
(2048, 2048, 2048): {'float16': 0.64, 'float32': 0.57, 'int8': 0.34},
|
||||
(2048, 2048, 2048): {'float16': 0.59, 'float32': 0.57, 'int8': 0.34},
|
||||
(4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46},
|
||||
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
|
||||
# tall-skinny
|
||||
|
||||
@@ -295,6 +295,43 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
|
||||
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
||||
|
||||
|
||||
def test_unsigned_name_mangling(device='cuda'):
|
||||
# Test that uint32 and int32 are mangled differently by the compiler
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
|
||||
@triton.jit
|
||||
def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
|
||||
off = tl.arange(0, SIZE)
|
||||
x = tl.load(X + off)
|
||||
y = tl.load(Y + off)
|
||||
out1 = tl.abs(x) # uint32 -> nop
|
||||
out2 = tl.abs(-y) # int32 -> should have an effect
|
||||
tl.store(O1 + off, out1)
|
||||
tl.store(O2 + off, out2)
|
||||
|
||||
dtype_x = 'uint32'
|
||||
dtype_y = 'int32'
|
||||
# inputs
|
||||
rs = RandomState(17)
|
||||
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
|
||||
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
|
||||
# reference result
|
||||
expect = (np.abs(x), np.abs(-y))
|
||||
# triton result
|
||||
x_tri = to_triton(x, device=device, dst_type=dtype_x)
|
||||
y_tri = to_triton(y, device=device, dst_type=dtype_y)
|
||||
actual = tuple(
|
||||
to_triton(np.empty_like(e), device=device)
|
||||
for e in expect
|
||||
)
|
||||
kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4)
|
||||
|
||||
# Bitwise op, so expect exact equality
|
||||
assert (expect[0] == to_numpy(actual[0])).all()
|
||||
assert (expect[1] == to_numpy(actual[1])).all()
|
||||
|
||||
|
||||
# ---------------
|
||||
# test bitwise ops
|
||||
# ---------------
|
||||
@@ -1140,7 +1177,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
[128, 128, 64, 4],
|
||||
[64, 128, 128, 4],
|
||||
[32, 128, 64, 2],
|
||||
[128, 128, 64, 2],
|
||||
# triggers nvptx/ptxas bug on V100 currently
|
||||
# [128, 128, 64, 2],
|
||||
[64, 128, 128, 2]]
|
||||
for allow_tf32 in [True]
|
||||
for col_a in [True, False]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""isort:skip_file"""
|
||||
__version__ = '2.0.0'
|
||||
__version__ = '2.1.0'
|
||||
|
||||
# ---------------------------------------
|
||||
# Note: import order is significant here.
|
||||
|
||||
@@ -63,7 +63,9 @@ def mangle_ty(ty):
|
||||
if ty.is_ptr():
|
||||
return 'P' + mangle_ty(ty.element_ty)
|
||||
if ty.is_int():
|
||||
return 'i' + str(ty.int_bitwidth)
|
||||
SIGNED = triton.language.dtype.SIGNEDNESS.SIGNED
|
||||
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
|
||||
return prefix + str(ty.int_bitwidth)
|
||||
if ty.is_fp8():
|
||||
return 'fp8'
|
||||
if ty.is_fp16():
|
||||
@@ -908,6 +910,22 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_NoneType(self, node):
|
||||
return None
|
||||
|
||||
def visit_JoinedStr(self, node):
|
||||
values = list(node.values)
|
||||
for i, value in enumerate(values):
|
||||
if isinstance(value, ast.Constant):
|
||||
values[i] = str(value.value)
|
||||
elif isinstance(value, ast.FormattedValue):
|
||||
conversion_code = value.conversion
|
||||
evaluated = self.visit(value.value)
|
||||
if not isinstance(evaluated, triton.language.constexpr):
|
||||
raise NotImplementedError("Cannot evaluate f-string containing non-constexpr conversion values,"
|
||||
" found conversion of type " + str(type(evaluated)))
|
||||
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
|
||||
else:
|
||||
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
|
||||
return ''.join(values)
|
||||
|
||||
def visit(self, node):
|
||||
if node is not None:
|
||||
self.last_node = node
|
||||
@@ -1001,7 +1019,7 @@ def build_triton_ir(fn, signature, specialization, constants, debug=False):
|
||||
generator.visit(fn.parse())
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||
if node is None or isinstance(e, CompilationError):
|
||||
raise e
|
||||
raise CompilationError(fn.src, node) from e
|
||||
ret = generator.module
|
||||
|
||||
@@ -37,8 +37,8 @@ def _to_tensor(x, builder):
|
||||
|
||||
|
||||
class dtype:
|
||||
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
|
||||
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
|
||||
UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
|
||||
FP_TYPES = ['fp8e4', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64']
|
||||
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
|
||||
OTHER_TYPES = ['void']
|
||||
|
||||
@@ -181,8 +181,8 @@ def _dsd_kernel(
|
||||
inc_b = tl.load(pinc)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
for k in range(K, 0, -TILE_K):
|
||||
a = tl.load(pa, mask=True)
|
||||
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
|
||||
a = tl.load(pa)
|
||||
b = tl.load(pb)
|
||||
acc += tl.dot(a, b)
|
||||
pa += inc_a
|
||||
pb += inc_b * stride_bk
|
||||
|
||||
@@ -316,7 +316,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = fn.__annotations__
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
|
||||
from triton.language.core import \
|
||||
constexpr # import here rather than at module level due to circular import tangle
|
||||
self.constexprs = [index for index, ty in self.annotations.items() if issubclass(ty, constexpr)]
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
# re-use docs of wrapped function
|
||||
|
||||
@@ -106,10 +106,8 @@ def allclose(x, y, atol=0, rtol=1e-2):
|
||||
return torch.sum(x ^ y) == 0
|
||||
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
rtol = 0
|
||||
diff = abs(x - y)
|
||||
x_max = torch.max(x)
|
||||
y_max = torch.max(y)
|
||||
return torch.max(diff) <= atol + rtol * torch.max(x_max, y_max)
|
||||
atol = 0
|
||||
return torch.allclose(x, y, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
|
||||
@@ -82,7 +82,7 @@ func.func @div() {
|
||||
%3 = arith.divui %1, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%4 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
|
||||
%5 = arith.divsi %0, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%6 = arith.divsi %4, %0 : tensor<128xi32>
|
||||
@@ -94,11 +94,12 @@ func.func @div() {
|
||||
%9 = arith.divui %0, %8 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = <none>
|
||||
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
|
||||
%11 = arith.divsi %10, %4 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @rem
|
||||
@@ -179,11 +180,11 @@ func.func @logic() {
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%1 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
|
||||
%2 = arith.divsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
||||
%3 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
||||
%4 = arith.divsi %0, %3 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%5 = arith.andi %0, %1 : tensor<128xi32>
|
||||
|
||||
Reference in New Issue
Block a user