MSM - fixed bug in reduction phase (#549)

This PR fixes a bug in the iterative reduction algorithm: unsynchronized threads were reading from and writing to
the same addresses, which caused MSM to fail a small percentage of the time. This is now fixed.
The kernel now takes an explicit orig_block_size for its read indexing, the non-strided write path indexes by
read_ind instead of write_ind, and the temporary bucket buffers are allocated at the full source_buckets_count size.
Author:    HadarIngonyama
Date:      2024-06-30 12:05:55 +03:00
Committed: GitHub
Parent:    f812f071fa
Commit:    4fef542346

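The race described above is the general hazard of a pairwise reduction whose reads and writes cover overlapping
address ranges within a single launch. The sketch below is illustrative only; it is not the ICICLE kernel, and
pairwise_reduce and all sizes are made up for this example. It shows why an in-place pairwise reduction races and
how writing each stage into a buffer disjoint from the one being read avoids it.

#include <cstdio>
#include <cuda_runtime.h>

// Pairwise reduction stage: out[i] = in[i] + in[i + jump].
// UNSAFE if `out` aliases `in` and the read and write ranges overlap: a thread may then
// read a slot that another thread of the same launch has already overwritten, and there
// is no ordering guarantee between blocks within one launch.
__global__ void pairwise_reduce(const float* in, float* out, unsigned jump, unsigned n) {
  unsigned i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] + in[i + jump];
}

int main() {
  const unsigned n = 1 << 20;
  float *buf, *scratch;
  cudaMalloc(&buf, 2 * n * sizeof(float));
  cudaMalloc(&scratch, n * sizeof(float)); // scratch sized so writes never alias reads
  cudaMemset(buf, 0, 2 * n * sizeof(float));

  // Safe: reads come from `buf`, writes go to the disjoint `scratch` buffer.
  pairwise_reduce<<<(n + 255) / 256, 256>>>(buf, scratch, n, n);

  // Racy variant (do NOT do this): in place, so some threads read slots that other
  // threads of the same launch are concurrently overwriting.
  // pairwise_reduce<<<(n + 255) / 256, 256>>>(buf, buf, 1, n);

  cudaDeviceSynchronize();
  cudaFree(scratch);
  cudaFree(buf);
  printf("reduction stage done\n");
  return 0;
}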

@@ -88,6 +88,7 @@ namespace msm {
   __global__ void single_stage_multi_reduction_kernel(
     const P* v,
     P* v_r,
+    unsigned orig_block_size,
     unsigned block_size,
     unsigned write_stride,
     unsigned buckets_per_bm,
@@ -107,11 +108,11 @@ namespace msm {
     // only for write_phase=1 because of its read pattern.
     const int shifted_block_id = write_phase ? block_id + (block_id + step) / step : block_id;
     const int block_tid = shifted_tid % jump;
-    const unsigned read_ind = block_size * shifted_block_id + block_tid;
+    const unsigned read_ind = orig_block_size * shifted_block_id + block_tid;
     const unsigned write_ind = jump * shifted_block_id + block_tid;
     const unsigned v_r_key =
       write_stride ? ((write_ind / buckets_per_bm) * 2 + write_phase) * write_stride + write_ind % buckets_per_bm
-                   : write_ind;
+                   : read_ind;
     v_r[v_r_key] = v[read_ind] + v[read_ind + jump];
   }
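For concreteness, here is a small host-side sketch that evaluates the read/write index formulas from the hunk above
for a handful of thread ids. The values of jump, step, buckets_per_bm, write_stride and the thread-to-block mapping
are assumptions chosen only to exercise the arithmetic; their real derivations are not shown in this hunk. With these
sample values, each slot a thread writes is read only by that same thread.

#include <cstdio>

int main() {
  const unsigned orig_block_size = 8; // read stride (the new parameter); assumed value
  const unsigned jump = 4;            // write stride inside a block; assumed value
  const unsigned step = 2;            // assumed, only used when write_phase == 1
  const unsigned buckets_per_bm = 4;  // assumed
  const unsigned write_stride = 0;    // 0 selects the plain (read_ind) addressing path
  const int write_phase = 0;

  for (unsigned tid = 0; tid < 8; tid++) {
    const int block_id = tid / jump;  // assumed thread-to-block mapping
    const unsigned shifted_tid = tid; // assumed
    const int shifted_block_id = write_phase ? block_id + (block_id + step) / step : block_id;
    const int block_tid = shifted_tid % jump;
    const unsigned read_ind = orig_block_size * shifted_block_id + block_tid;
    const unsigned write_ind = jump * shifted_block_id + block_tid;
    const unsigned v_r_key =
      write_stride ? ((write_ind / buckets_per_bm) * 2 + write_phase) * write_stride + write_ind % buckets_per_bm
                   : read_ind;
    printf("tid %2u: reads v[%2u] and v[%2u], writes v_r[%2u]\n", tid, read_ind, read_ind + jump, v_r_key);
  }
  return 0;
}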
@@ -745,7 +746,7 @@ namespace msm {
       NUM_BLOCKS = (nof_bms_in_batch + NUM_THREADS - 1) / NUM_THREADS;
       big_triangle_sum_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, final_results, nof_bms_in_batch, c);
     } else {
-      // the recursive reduction algorithm works with 2 types of reduction that can run on parallel streams
+      // the iterative reduction algorithm works with 2 types of reduction that can run on parallel streams
       cudaStream_t stream_reduction;
       cudaEvent_t event_finished_reduction;
       CHK_IF_RETURN(cudaStreamCreate(&stream_reduction));
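As context for the stream and event handles created here, the following standalone sketch (not ICICLE's code; the
kernels, buffers and sizes are dummies invented for illustration) shows the usual CUDA pattern for running two
independent reductions on separate streams and making the main stream wait for the side stream via an event.

#include <cstdio>
#include <cuda_runtime.h>

// Dummy stand-ins for the two reduction types; illustrative only.
__global__ void type1_dummy_kernel(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1.0f;
}

__global__ void type2_dummy_kernel(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

int main() {
  const int n = 1 << 10;
  float *a, *b;
  cudaMalloc(&a, n * sizeof(float));
  cudaMalloc(&b, n * sizeof(float));
  cudaMemset(a, 0, n * sizeof(float));
  cudaMemset(b, 0, n * sizeof(float));

  cudaStream_t stream, stream_reduction;
  cudaEvent_t event_finished_reduction;
  cudaStreamCreate(&stream);
  cudaStreamCreate(&stream_reduction);
  cudaEventCreate(&event_finished_reduction);

  // The two independent reductions can overlap because they run on different streams.
  type1_dummy_kernel<<<(n + 255) / 256, 256, 0, stream>>>(a, n);
  type2_dummy_kernel<<<(n + 255) / 256, 256, 0, stream_reduction>>>(b, n);

  // Record completion of the side stream and make the main stream wait for it,
  // so later work queued on `stream` sees both reductions finished.
  cudaEventRecord(event_finished_reduction, stream_reduction);
  cudaStreamWaitEvent(stream, event_finished_reduction, 0);

  cudaStreamSynchronize(stream);
  cudaEventDestroy(event_finished_reduction);
  cudaStreamDestroy(stream_reduction);
  cudaStreamDestroy(stream);
  cudaFree(a);
  cudaFree(b);
  printf("both reductions complete\n");
  return 0;
}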
@@ -766,10 +767,10 @@ namespace msm {
       const unsigned target_buckets_count = target_windows_count << target_bits_count; // new_bms*2^new_c
       CHK_IF_RETURN(cudaMallocAsync(&target_buckets, sizeof(P) * target_buckets_count * batch_size, stream));
       CHK_IF_RETURN(cudaMallocAsync(
-        &temp_buckets1, sizeof(P) * source_buckets_count / 2 * batch_size,
+        &temp_buckets1, sizeof(P) * source_buckets_count * batch_size,
         stream)); // for type1 reduction (strided, bottom window - evens)
       CHK_IF_RETURN(cudaMallocAsync(
-        &temp_buckets2, sizeof(P) * source_buckets_count / 2 * batch_size,
+        &temp_buckets2, sizeof(P) * source_buckets_count * batch_size,
         stream)); // for type2 reduction (serial, top window - odds)
       initialize_buckets_kernel<<<(target_buckets_count * batch_size + 255) / 256, 256>>>(
         target_buckets, target_buckets_count * batch_size); // initialization is needed for the odd c case
@@ -788,9 +789,9 @@ namespace msm {
         if (!is_odd_c || !is_first_iter) { // skip if c is odd and it's the first iteration
           single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
             is_first_iter || (is_second_iter && is_odd_c) ? source_buckets : temp_buckets1,
-            is_last_iter ? target_buckets : temp_buckets1, 1 << (source_bits_count - j + (is_odd_c ? 1 : 0)),
-            is_last_iter ? 1 << target_bits_count : 0, 1 << target_bits_count, 0 /*=write_phase*/,
-            (1 << target_bits_count) - 1, nof_threads);
+            is_last_iter ? target_buckets : temp_buckets1, 1 << source_bits_count,
+            1 << (source_bits_count - j + (is_odd_c ? 1 : 0)), is_last_iter ? 1 << target_bits_count : 0,
+            1 << target_bits_count, 0 /*=write_phase*/, (1 << target_bits_count) - 1, nof_threads);
         }
         nof_threads = (((source_windows_count << (source_bits_count - target_bits_count)) - source_windows_count)
@@ -801,7 +802,7 @@ namespace msm {
           NUM_BLOCKS = (nof_threads + NUM_THREADS - 1) / NUM_THREADS;
           single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream_reduction>>>(
             is_first_iter ? source_buckets : temp_buckets2, is_last_iter ? target_buckets : temp_buckets2,
-            1 << (target_bits_count - j), is_last_iter ? 1 << target_bits_count : 0,
+            1 << target_bits_count, 1 << (target_bits_count - j), is_last_iter ? 1 << target_bits_count : 0,
             1 << (target_bits_count - (is_odd_c ? 1 : 0)), 1 /*=write_phase*/,
             (1 << (target_bits_count - (is_odd_c ? 1 : 0))) - 1, nof_threads);
         }