mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Remove wrong dependency between TritonGPU and NVGPU dialect (#2276)
This commit is contained in:
@@ -40,5 +40,6 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) {
|
||||
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
|
||||
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
|
||||
mlir::gpu::GPUDialect>();
|
||||
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
|
||||
mlir::triton::nvgpu::NVGPUDialect>();
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonGPU depends on Triton
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||
|
||||
@@ -16,7 +16,6 @@ def TritonGPU_Dialect : Dialect {
|
||||
|
||||
let dependentDialects = [
|
||||
"triton::TritonDialect",
|
||||
"mlir::triton::nvgpu::NVGPUDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"tensor::TensorDialect",
|
||||
];
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Conversion/MLIRTypes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
||||
// Operators
|
||||
|
||||
@@ -6,4 +6,5 @@ add_mlir_dialect_library(NVGPUIR
|
||||
NVGPUAttrDefsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLLVMDialect
|
||||
)
|
||||
|
||||
@@ -36,5 +36,7 @@ def test_op(M, N, dtype, mode):
|
||||
x.grad = None
|
||||
th_y.backward(dy)
|
||||
th_dx = x.grad.clone()
|
||||
|
||||
torch.testing.assert_close(th_dx, tt_dx)
|
||||
if dtype == 'float16':
|
||||
torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001)
|
||||
else:
|
||||
torch.testing.assert_close(th_dx, tt_dx)
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
*/
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
|
||||
#include "DumpLayout.h"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user