mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Initial commit to resolve merge conflicts
This commit is contained in:
@@ -48,16 +48,11 @@ def test_assert(func: str):
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_assert":
|
||||
<<<<<<< HEAD
|
||||
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
=======
|
||||
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "no_debug":
|
||||
# TRITON_DEBUG=True can override the debug flag
|
||||
kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])
|
||||
>>>>>>> oai/main
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
|
||||
@@ -643,13 +643,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_program_id
|
||||
tt.func @basic_program_id() {
|
||||
<<<<<<< HEAD
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
=======
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
>>>>>>> oai/main
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -1533,24 +1528,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
<<<<<<< HEAD
|
||||
%blockidx = tt.get_program_id {axis=0:i32} : i32
|
||||
%blockidy = tt.get_program_id {axis=1:i32} : i32
|
||||
%blockidz = tt.get_program_id {axis=2:i32} : i32
|
||||
%blockidx = tt.get_program_id x : i32
|
||||
%blockidy = tt.get_program_id y : i32
|
||||
%blockidz = tt.get_program_id z : i32
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.x
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.y
|
||||
// PTX: nvvm.read.ptx.sreg.ctaid.z
|
||||
// GCN: rocdl.workgroup.id.x
|
||||
// GCN: rocdl.workgroup.id.y
|
||||
// GCN: rocdl.workgroup.id.z
|
||||
=======
|
||||
%blockidx = tt.get_program_id x : i32
|
||||
%blockidy = tt.get_program_id y : i32
|
||||
%blockidz = tt.get_program_id z : i32
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.y
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.z
|
||||
>>>>>>> oai/main
|
||||
%v0 = arith.addi %blockidx, %blockidy : i32
|
||||
%v1 = arith.addi %v0, %blockidz : i32
|
||||
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
|
||||
|
||||
Reference in New Issue
Block a user