Remove wrong dependency between TritonGPU and NVGPU dialect (#2276)

This commit is contained in:
Thomas Raoux
2023-09-11 16:30:13 -07:00
committed by GitHub
parent ec4a968d44
commit a9db6b94b9
7 changed files with 9 additions and 5 deletions

View File

@@ -40,5 +40,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect>();
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
mlir::triton::nvgpu::NVGPUDialect>();
}

View File

@@ -7,7 +7,6 @@
#include "mlir/IR/Dialect.h"
// TritonGPU depends on Triton
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"

View File

@@ -16,7 +16,6 @@ def TritonGPU_Dialect : Dialect {
let dependentDialects = [
"triton::TritonDialect",
"mlir::triton::nvgpu::NVGPUDialect",
"mlir::gpu::GPUDialect",
"tensor::TensorDialect",
];

View File

@@ -6,6 +6,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "triton/Dialect/NVGPU/IR/Dialect.h"
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// Operators

View File

@@ -6,4 +6,5 @@ add_mlir_dialect_library(NVGPUIR
NVGPUAttrDefsIncGen
LINK_LIBS PUBLIC
MLIRLLVMDialect
)

View File

@@ -36,5 +36,7 @@ def test_op(M, N, dtype, mode):
x.grad = None
th_y.backward(dy)
th_dx = x.grad.clone()
torch.testing.assert_close(th_dx, tt_dx)
if dtype == 'float16':
torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001)
else:
torch.testing.assert_close(th_dx, tt_dx)

View File

@@ -22,6 +22,7 @@
*/
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "DumpLayout.h"