mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Handle scan of function non commutative (#2362)
Make sure we accumulate in the right order for scans so that non commutative operations are handled correctly.
This commit is contained in:
@@ -1716,10 +1716,16 @@ scan_configs = [
|
||||
for type in ['int32', 'float32']
|
||||
for axis in [1, 0]
|
||||
for shape in scan2d_shapes
|
||||
for op in ['cumsum', 'cumprod']
|
||||
for op in ['cumsum', 'cumprod', 'get_first_element']
|
||||
]
|
||||
|
||||
|
||||
@triton.jit
|
||||
# trivial associative but not commutative function
|
||||
def get_first_element(a, b):
|
||||
return a
|
||||
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs)
|
||||
def test_scan2d(op, dtype_str, shape, axis, num_warps, device):
|
||||
if is_hip():
|
||||
@@ -1735,15 +1741,26 @@ def test_scan2d(op, dtype_str, shape, axis, num_warps, device):
|
||||
z = GENERATE_TEST_HERE
|
||||
tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'})
|
||||
if op == 'cumsum' or op == 'cumprod':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'})
|
||||
else:
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.associative_scan(x, axis={axis}, combine_fn={op})'})
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
|
||||
z = np.empty_like(x)
|
||||
x_tri = to_triton(x, device=device)
|
||||
numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op]
|
||||
z_dtype_str = dtype_str
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
if op == 'cumsum' or op == 'cumprod':
|
||||
numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op]
|
||||
z_dtype_str = dtype_str
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
else:
|
||||
assert op == 'get_first_element'
|
||||
z_ref = x
|
||||
if axis == 0:
|
||||
z_ref[1:] = x[0]
|
||||
else:
|
||||
z_ref[:, 1:] = x[:, 0:1]
|
||||
# triton result
|
||||
z_tri = to_triton(z, device=device)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
|
||||
|
||||
Reference in New Issue
Block a user