mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Add a performance test for reductions (#2125)
Also stop promoting integer types, as it doesn't give better perf; this will allow more vectorization opportunity in the future.
This commit is contained in:
@@ -307,8 +307,10 @@ private:
|
||||
Operation *yield = block->getTerminator();
|
||||
Operation *reduceOp = yield->getOperand(0).getDefiningOp();
|
||||
if (!reduceOp || reduceOp->getNumOperands() != 2 ||
|
||||
reduceOp->getNumResults() != 1 ||
|
||||
!reduceOp->getResultTypes()[0].isInteger(32))
|
||||
reduceOp->getNumResults() != 1)
|
||||
return std::nullopt;
|
||||
auto intType = reduceOp->getResultTypes()[0].dyn_cast<IntegerType>();
|
||||
if (!intType || intType.getWidth() > 32)
|
||||
return std::nullopt;
|
||||
if (reduceOp->getOperand(0) != block->getArgument(0) ||
|
||||
reduceOp->getOperand(1) != block->getArgument(1))
|
||||
@@ -382,8 +384,19 @@ private:
|
||||
mask = shl(i32_val(bitmask),
|
||||
and_(laneId, i32_val(~(numLaneToReduce - 1))));
|
||||
}
|
||||
acc[0] = rewriter.create<NVVM::ReduxOp>(loc, acc[0].getType(), acc[0],
|
||||
*kind, mask);
|
||||
for (unsigned i = 0; i < acc.size(); ++i) {
|
||||
unsigned bitwidth = acc[i].getType().cast<IntegerType>().getWidth();
|
||||
if (bitwidth < 32) {
|
||||
if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX)
|
||||
acc[i] = sext(i32_ty, acc[i]);
|
||||
else
|
||||
acc[i] = zext(i32_ty, acc[i]);
|
||||
}
|
||||
acc[i] = rewriter.create<NVVM::ReduxOp>(loc, acc[i].getType(), acc[0],
|
||||
*kind, mask);
|
||||
if (bitwidth < 32)
|
||||
acc[i] = trunc(int_ty(bitwidth), acc[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -225,3 +225,59 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
|
||||
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)]
|
||||
print_perf(ms, cur_gpu_util, ref_gpu_util)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
|
||||
|
||||
|
||||
#######################
|
||||
# Reduction
|
||||
#######################
|
||||
|
||||
|
||||
@triton.jit
def _sum(x_ptr, y_ptr, output_ptr, n_elements,
         BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous BLOCK_SIZE chunk.
    prog = tl.program_id(axis=0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    a = tl.load(x_ptr + idx, mask=in_bounds)
    b = tl.load(y_ptr + idx, mask=in_bounds)
    # Repeat the reduction many times so the kernel is compute bound
    # rather than memory bound.
    for _ in range(100):
        a = tl.sum(a, axis=0) + b

    tl.store(output_ptr + idx, a, mask=in_bounds)
|
||||
|
||||
|
||||
# Reference GPU utilization per device, problem size N, and element dtype.
# test_reductions reads reduction_data[DEVICE_NAME][N][dtype_str] as the
# expected value for cur_gpu_perf / max_gpu_perf.
reduction_data = {
    'a100': {
        1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.015, 'int32': 0.031},
        1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.015, 'int32': 0.032},
    }
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('N', reduction_data[DEVICE_NAME].keys())
@pytest.mark.parametrize("dtype_str", ['float16', 'float32', 'int16', 'int32'])
def test_reductions(N, dtype_str):
    """Benchmark the _sum reduction kernel and check that its GPU utilization
    matches the per-device reference in reduction_data."""
    cuda_stream = torch.cuda.Stream()
    torch.cuda.set_stream(cuda_stream)
    torch.manual_seed(0)
    torch_dtype = {'float16': torch.float16, 'float32': torch.float32, 'int16': torch.int16, 'int32': torch.int32}[dtype_str]
    expected_util = reduction_data[DEVICE_NAME][N][dtype_str]
    sm_clock = nvsmi(['clocks.current.sm'])[0]
    peak_perf = get_max_tensorcore_tflops(torch_dtype, clock_rate=sm_clock * 1e3)
    out = torch.empty((N, ), dtype=torch_dtype, device='cuda')
    if torch_dtype in (torch.float16, torch.float32):
        a = torch.randn_like(out)
        b = torch.randn_like(out)
    else:
        bounds = torch.iinfo(torch_dtype)
        a = torch.randint(bounds.min, bounds.max, (N,), dtype=torch_dtype, device='cuda')
        b = torch.randint(bounds.min, bounds.max, (N,), dtype=torch_dtype, device='cuda')
    grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
    run = lambda: _sum[grid](a, b, out, N, BLOCK_SIZE=1024)
    ms = triton.testing.do_bench_cudagraph(run)
    # 100 kernel loop iterations, ~2 ops per element per iteration.
    measured_perf = 100. * 2. * N / ms * 1e-9
    measured_util = measured_perf / peak_perf
    print_perf(ms, measured_util, expected_util)
    triton.testing.assert_close(measured_util, expected_util, atol=0.02, rtol=0.01)
|
||||
|
||||
@@ -1485,8 +1485,6 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
def get_reduced_dtype(dtype_str, op):
    """Return the dtype string that reducing `dtype_str` with `op` produces.

    argmin/argmax always yield int32 indices; narrow integer inputs are
    widened to int32 and bfloat16 accumulates in float32; everything else
    reduces in its own dtype.
    """
    if op in ('argmin', 'argmax'):
        return 'int32'
    widened = {
        'int8': 'int32',
        'uint8': 'int32',
        'int16': 'int32',
        'uint16': 'int32',
        'bfloat16': 'float32',
    }
    return widened.get(dtype_str, dtype_str)
|
||||
|
||||
@@ -1351,11 +1351,6 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
@builtin
|
||||
def _promote_reduction_input(t, _builder=None):
|
||||
scalar_ty = t.type.scalar
|
||||
# input is extended to 32-bits if necessary
|
||||
# this increases numerical accuracy and can be done pretty much for free
|
||||
# on GPUs
|
||||
if scalar_ty.is_int() and scalar_ty.int_bitwidth < 32:
|
||||
return t.to(int32, _builder=_builder)
|
||||
|
||||
# hardware doesn't support FMAX, FMIN, CMP for bfloat16
|
||||
if scalar_ty is bfloat16:
|
||||
|
||||
@@ -377,9 +377,9 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None)
|
||||
assert dtype == torch.float16
|
||||
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
|
||||
else:
|
||||
if dtype == torch.float32:
|
||||
if dtype in [torch.float32, torch.int32]:
|
||||
ops_per_sub_core = 256
|
||||
elif dtype in [torch.float16, torch.bfloat16]:
|
||||
elif dtype in [torch.float16, torch.bfloat16, torch.int16]:
|
||||
ops_per_sub_core = 512
|
||||
elif dtype in [torch.int8, tl.float8e4, tl.float8e4b15, tl.float8e5]:
|
||||
ops_per_sub_core = 1024
|
||||
|
||||
Reference in New Issue
Block a user