Add waves_per_eu as kernel parameter (#319)

* Add waves_per_eu as kernel parameter

* Fix failing tests

* Add a default value for waves_per_eu in the ttgir_to_llir function

* Remove aot.py
This commit is contained in:
oplavsic
2023-10-06 19:08:34 +02:00
committed by GitHub
parent be95edc63f
commit e801638b40
7 changed files with 36 additions and 26 deletions

View File

@@ -126,11 +126,13 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::LLVMContext llvmContext;
mlir::triton::gpu::TMAMetadataTy tmaInfos;
#ifdef USE_ROCM
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), tmaInfos, Target::ROCDL);
auto llvmir =
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(),
tmaInfos, Target::ROCDL, 0 /*wavesPerEU*/);
#else
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), tmaInfos, Target::Default);
auto llvmir =
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(),
tmaInfos, Target::Default, 0 /*wavesPerEU*/);
#endif
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";