Initial commit to resolve merge conflicts

rename tl.float8e4 to tl.float8e4nv to align with upstream

ROCM IFU: Fix python arch issues

ROCM IFU: Fix kernel launcher

ROCM IFU: Fix merge conflicts

fix debug build

Set correct threadsPerCTA
This commit is contained in:
Jason Furmanek
2023-09-12 20:43:59 +00:00
parent 74fd8e9754
commit e5d7bb4fae
36 changed files with 414 additions and 1005 deletions

View File

@@ -21,18 +21,8 @@ enum Target { NVVM, ROCDL, Default = NVVM };
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
<<<<<<< HEAD
#ifdef USE_ROCM
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = true);
#else
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = false);
#endif
=======
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
} // namespace triton
} // namespace mlir

View File

@@ -147,11 +147,11 @@ compared to 1*64 when the hasLeadingOffset is false.
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
int maxPhase = SIMDWidth / perPhase;
return $_get(context, vecSize, perPhase, maxPhase, order);
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
} else {
// Do not swizzle in case k dimension is not innermost.
// In this case accesses will go in different banks even without swizzling.
return $_get(context, 1, 1, 1, order);
return get(context, 1, 1, 1, order, CTALayout);
}
}
#endif
@@ -185,20 +185,12 @@ compared to 1*64 when the hasLeadingOffset is false.
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
<<<<<<< HEAD
int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getKWidth());
=======
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
<<<<<<< HEAD
return $_get(context, 1, 1, 1, order);
=======
return get(context, 1, 1, 1, order, CTALayout);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand

View File

@@ -52,7 +52,6 @@ def TritonGPU_Dialect : Dialect {
}
return threadsPerWarp.cast<IntegerAttr>().getInt();
}
<<<<<<< HEAD
static int getSharedSize(ModuleOp mod) {
Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared");
if(!sharedAttr) {
@@ -61,8 +60,6 @@ def TritonGPU_Dialect : Dialect {
return sharedAttr.cast<IntegerAttr>().getInt();
}
=======
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}];
let useDefaultAttributePrinterParser = 1;