[BACKEND] Fix int8 dot (#1435)

This commit is contained in:
Keren Zhou
2023-03-28 20:18:17 -07:00
committed by GitHub
parent 3342cc1c0c
commit ee593fca0b
9 changed files with 21 additions and 17 deletions

View File

@@ -1,5 +1,5 @@
-#include "ConvertLayoutOpToLLVM.h"
-#include "Utility.h"
+#include "../ConvertLayoutOpToLLVM.h"
+#include "../Utility.h"
using ValueTable = std::map<std::pair<int, int>, Value>;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;

View File

@@ -1,5 +1,5 @@
-#include "ConvertLayoutOpToLLVM.h"
-#include "Utility.h"
+#include "../ConvertLayoutOpToLLVM.h"
+#include "../Utility.h"
using CoordTy = SmallVector<Value>;
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;

View File

@@ -1,5 +1,5 @@
-#include "ConvertLayoutOpToLLVM.h"
-#include "Utility.h"
+#include "../ConvertLayoutOpToLLVM.h"
+#include "../Utility.h"
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
@@ -585,7 +585,8 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
auto numRep = aEncoding.getMMAv2Rep(aTensorTy.getShape(), bitwidth);
-int numRepM = numRep[0], numRepK = numRep[1];
+int numRepM = numRep[0];
+int numRepK = numRep[1];
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
int wpt0 = mmaLayout.getWarpsPerCTA()[0];

View File

@@ -1,5 +1,5 @@
-#include "DotOpToLLVM.h"
-#include "Utility.h"
+#include "../DotOpToLLVM.h"
+#include "../Utility.h"
using namespace mlir;
using namespace mlir::triton;

View File

@@ -1,5 +1,5 @@
-#include "DotOpToLLVM.h"
-#include "Utility.h"
+#include "../DotOpToLLVM.h"
+#include "../Utility.h"
using namespace mlir;
using namespace mlir::triton;

View File

@@ -1,5 +1,5 @@
-#include "DotOpToLLVM.h"
-#include "Utility.h"
+#include "../DotOpToLLVM.h"
+#include "../Utility.h"
using namespace mlir;
using namespace mlir::triton;

View File

@@ -420,11 +420,11 @@ DotOperandEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
assert(mmaParent.isAmpere());
if (getOpIdx() == 0)
return {std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])),
-shape[1] / shapePerWarp[2]};
+std::max<int64_t>(1, shape[1] / shapePerWarp[2])};
else {
assert(getOpIdx() == 1);
return {
-shape[0] / shapePerWarp[2],
+std::max<int64_t>(1, shape[0] / shapePerWarp[2]),
std::max<int64_t>(1, shape[1] / (shapePerWarp[1] * warpsPerCTA[1]))};
}
}

View File

@@ -1284,8 +1284,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
[128, 128, 64, 4],
[64, 128, 128, 4],
[32, 128, 64, 2],
-# triggers nvptx/ptxas bug on V100 currently
-# [128, 128, 64, 2],
+[64, 64, 32, 4],
+[128, 128, 64, 2],
[64, 128, 128, 2]]
for allow_tf32 in [True]
for col_a in [True, False]

View File

@@ -1183,6 +1183,9 @@ def dot(lhs: tl.tensor,
and rhs.shape[1].value >= 16,\
"small blocks not supported!"
if lhs.type.scalar.is_int():
+assert lhs.type.scalar == tl.int8, "only int8 supported!"
+# TODO: This is CUDA specific, check if ROCm has the same limitation
+assert lhs.shape[1].value >= 32, "small blocks not supported!"
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():