mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix int8 dot (#1435)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "../ConvertLayoutOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using ValueTable = std::map<std::pair<int, int>, Value>;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "../ConvertLayoutOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using CoordTy = SmallVector<Value>;
|
||||
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "../ConvertLayoutOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
@@ -585,7 +585,8 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
|
||||
|
||||
auto numRep = aEncoding.getMMAv2Rep(aTensorTy.getShape(), bitwidth);
|
||||
int numRepM = numRep[0], numRepK = numRep[1];
|
||||
int numRepM = numRep[0];
|
||||
int numRepK = numRep[1];
|
||||
|
||||
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
|
||||
int wpt0 = mmaLayout.getWarpsPerCTA()[0];
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "DotOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "../DotOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "DotOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "../DotOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "DotOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "../DotOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
@@ -420,11 +420,11 @@ DotOperandEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
|
||||
assert(mmaParent.isAmpere());
|
||||
if (getOpIdx() == 0)
|
||||
return {std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])),
|
||||
shape[1] / shapePerWarp[2]};
|
||||
std::max<int64_t>(1, shape[1] / shapePerWarp[2])};
|
||||
else {
|
||||
assert(getOpIdx() == 1);
|
||||
return {
|
||||
shape[0] / shapePerWarp[2],
|
||||
std::max<int64_t>(1, shape[0] / shapePerWarp[2]),
|
||||
std::max<int64_t>(1, shape[1] / (shapePerWarp[1] * warpsPerCTA[1]))};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1284,8 +1284,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
[128, 128, 64, 4],
|
||||
[64, 128, 128, 4],
|
||||
[32, 128, 64, 2],
|
||||
# triggers nvptx/ptxas bug on V100 currently
|
||||
# [128, 128, 64, 2],
|
||||
[64, 64, 32, 4],
|
||||
[128, 128, 64, 2],
|
||||
[64, 128, 128, 2]]
|
||||
for allow_tf32 in [True]
|
||||
for col_a in [True, False]
|
||||
|
||||
@@ -1183,6 +1183,9 @@ def dot(lhs: tl.tensor,
|
||||
and rhs.shape[1].value >= 16,\
|
||||
"small blocks not supported!"
|
||||
if lhs.type.scalar.is_int():
|
||||
assert lhs.type.scalar == tl.int8, "only int8 supported!"
|
||||
# TODO: This is CUDA specific, check if ROCm has the same limitation
|
||||
assert lhs.shape[1].value >= 32, "small blocks not supported!"
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
|
||||
|
||||
Reference in New Issue
Block a user