mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Fix cast when both src_ty and dst_ty are of block_type (#1301)
Commonly used in atomic_rmw ops
This commit is contained in:
@@ -683,6 +683,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
|
||||
def test_tensor_atomic_rmw_block(device="cuda"):
    """Exercise ``tl.atomic_min`` with a 2-D block of pointers and values.

    The kernel casts a block of integer offsets to ``tl.float32`` and uses
    the result as the operand of an atomic min against a buffer of ones.

    Args:
        device: Torch device on which the buffer is allocated.
    """
    shape = (8, 8)

    @triton.jit
    def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        # Row-major linear offsets covering a (SHAPE0, SHAPE1) block.
        offs = off0[:, None] * SHAPE1 + off1[None, :]
        # Block-to-block cast: the integer offsets become float32 operands.
        val = offs.to(tl.float32)
        x = X + offs
        tl.atomic_min(x, val)

    # Allocate from `shape` (rather than a repeated literal) so the buffer
    # always matches the offsets the kernel derives from SHAPE0/SHAPE1.
    x = torch.ones(shape, device=device, dtype=torch.float32)
    # The kernel does not read a program id, so both launched programs
    # apply the same atomic min to the same addresses.
    kernel[(2,)](x, shape[0], shape[1])
    # Offset 0 yields min(1.0, 0.0) == 0.0; every other offset is >= 1.
    assert torch.min(x).item() == 0.0
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user