mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-09 13:07:59 -05:00
actual correctness for ntt64
This commit is contained in:
@@ -423,43 +423,67 @@ namespace mxntt {
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
|
||||
|
||||
// if (threadIdx.x == 0) {
|
||||
printf(
|
||||
"T BEFORE: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
|
||||
threadIdx.x,
|
||||
engine.X[0].limbs_storage.limbs[0],
|
||||
engine.X[0].limbs_storage.limbs[1],
|
||||
engine.X[0].limbs_storage.limbs[2],
|
||||
engine.X[0].limbs_storage.limbs[3],
|
||||
engine.X[0].limbs_storage.limbs[4],
|
||||
engine.X[0].limbs_storage.limbs[5],
|
||||
engine.X[0].limbs_storage.limbs[6],
|
||||
engine.X[0].limbs_storage.limbs[7]
|
||||
);
|
||||
// printf(
|
||||
// "T BEFORE: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
|
||||
// threadIdx.x,
|
||||
// engine.X[0].limbs_storage.limbs[0],
|
||||
// engine.X[0].limbs_storage.limbs[1],
|
||||
// engine.X[0].limbs_storage.limbs[2],
|
||||
// engine.X[0].limbs_storage.limbs[3],
|
||||
// engine.X[0].limbs_storage.limbs[4],
|
||||
// engine.X[0].limbs_storage.limbs[5],
|
||||
// engine.X[0].limbs_storage.limbs[6],
|
||||
// engine.X[0].limbs_storage.limbs[7]
|
||||
// );
|
||||
// }
|
||||
engine.loadBasicTwiddlesGeneric64(basic_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv, false);
|
||||
#pragma unroll 1
|
||||
for (uint32_t phase = 0; phase < 2; phase++) {
|
||||
printf(
|
||||
"T BEFORE: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
|
||||
threadIdx.x,
|
||||
engine.X[0].limbs_storage.limbs[0],
|
||||
engine.X[1].limbs_storage.limbs[0],
|
||||
engine.X[2].limbs_storage.limbs[0],
|
||||
engine.X[3].limbs_storage.limbs[0],
|
||||
engine.X[4].limbs_storage.limbs[0],
|
||||
engine.X[5].limbs_storage.limbs[0],
|
||||
engine.X[6].limbs_storage.limbs[0],
|
||||
engine.X[7].limbs_storage.limbs[0]
|
||||
);
|
||||
engine.ntt8();
|
||||
|
||||
// if (threadIdx.x == 0) {
|
||||
printf(
|
||||
"T AFTER: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
|
||||
threadIdx.x,
|
||||
engine.X[0].limbs_storage.limbs[0],
|
||||
engine.X[0].limbs_storage.limbs[1],
|
||||
engine.X[0].limbs_storage.limbs[2],
|
||||
engine.X[0].limbs_storage.limbs[3],
|
||||
engine.X[0].limbs_storage.limbs[4],
|
||||
engine.X[0].limbs_storage.limbs[5],
|
||||
engine.X[0].limbs_storage.limbs[6],
|
||||
engine.X[0].limbs_storage.limbs[7]
|
||||
);
|
||||
// printf(
|
||||
// "T AFTER: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
|
||||
// threadIdx.x,
|
||||
// engine.X[0].limbs_storage.limbs[0],
|
||||
// engine.X[1].limbs_storage.limbs[0],
|
||||
// engine.X[2].limbs_storage.limbs[0],
|
||||
// engine.X[3].limbs_storage.limbs[0],
|
||||
// engine.X[4].limbs_storage.limbs[0],
|
||||
// engine.X[5].limbs_storage.limbs[0],
|
||||
// engine.X[6].limbs_storage.limbs[0],
|
||||
// engine.X[7].limbs_storage.limbs[0]
|
||||
// );
|
||||
// }
|
||||
if (phase == 0) {
|
||||
engine.loadBasicTwiddlesGeneric64(basic_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv, true);
|
||||
engine.SharedData64Columns8(shmem, true, false, strided); // store
|
||||
__syncthreads();
|
||||
engine.SharedData64Rows8(shmem, false, false, strided); // load
|
||||
printf(
|
||||
"T AFTER: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
|
||||
threadIdx.x,
|
||||
engine.X[0].limbs_storage.limbs[0],
|
||||
engine.X[1].limbs_storage.limbs[0],
|
||||
engine.X[2].limbs_storage.limbs[0],
|
||||
engine.X[3].limbs_storage.limbs[0],
|
||||
engine.X[4].limbs_storage.limbs[0],
|
||||
engine.X[5].limbs_storage.limbs[0],
|
||||
engine.X[6].limbs_storage.limbs[0],
|
||||
engine.X[7].limbs_storage.limbs[0]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -700,7 +724,7 @@ namespace mxntt {
|
||||
|
||||
int stage = log_size - 1;
|
||||
uint32_t stage_rev = 0;
|
||||
S* stage_ptr = basic_twiddles;
|
||||
S* stage_ptr = basic_twiddles + (stage * (1 << stage));
|
||||
const int NOF_BLOCKS = (stage >= 8) ? (1 << (stage - 8)) : 1;
|
||||
const int NOF_THREADS = (stage >= 8) ? 256 : (1 << stage);
|
||||
// std::cout << "Stage: " << stage << "; nof_blocks: " << NOF_BLOCKS << "; nof_threads: " << NOF_THREADS << "; step:
|
||||
@@ -709,7 +733,7 @@ namespace mxntt {
|
||||
CHK_IF_RETURN(cudaPeekAtLastError());
|
||||
|
||||
for (--stage; stage >= 0; stage--) {
|
||||
stage_ptr += 1 << (log_size - 1);
|
||||
stage_ptr -= 1 << (log_size - 1);
|
||||
stage_rev++;
|
||||
// std::cout << "Stage: " << stage << "; nof_blocks: " << NOF_BLOCKS << "; nof_threads: " << NOF_THREADS << ";
|
||||
// step: " << step << "; temp_root: " << temp_root <<"; stage_ptr: " << stage_ptr<< std::endl;
|
||||
|
||||
@@ -72,12 +72,12 @@ public:
|
||||
uint32_t block_offset = s_meta.ntt_inp_id * 4;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (phase) {
|
||||
exp = phase_offset + s_meta.ntt_inp_id + (stage * 4 + i) * 8;
|
||||
exp = phase_offset + stage_offset + block_offset + i;
|
||||
} else {
|
||||
exp = stage_offset + block_offset + i;
|
||||
exp = s_meta.ntt_inp_id + (stage * 4 + i) * 8;
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// if (threadIdx.x == 0) {
|
||||
printf(
|
||||
"T: %d, I: %d, stage_offset: %d, block_offset: %d, exp: %d, tw: 0x%x\n",
|
||||
threadIdx.x,
|
||||
@@ -87,7 +87,7 @@ public:
|
||||
exp,
|
||||
basic_twiddles[exp].limbs_storage.limbs[0]
|
||||
);
|
||||
}
|
||||
// }
|
||||
|
||||
WB[stage * 4 + i] = basic_twiddles[(inv && exp) ? ((1 << tw_log_size) - exp) : exp];
|
||||
}
|
||||
@@ -130,12 +130,12 @@ public:
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
|
||||
} else {
|
||||
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * s_meta.th_stride;
|
||||
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
|
||||
}
|
||||
|
||||
UNROLL
|
||||
for (uint32_t i = 0; i < 8; i++) {
|
||||
X[i] = data[i * data_stride_u64];
|
||||
X[i] = data[s_meta.th_stride * i * data_stride_u64];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -404,16 +404,16 @@ public:
|
||||
{
|
||||
E T;
|
||||
|
||||
// Stage 2
|
||||
X[1] = X[1] * WB[0];
|
||||
X[3] = X[3] * WB[1];
|
||||
X[5] = X[5] * WB[2];
|
||||
// Stage 0
|
||||
X[4] = X[4] * WB[0];
|
||||
X[5] = X[5] * WB[1];
|
||||
X[6] = X[6] * WB[2];
|
||||
X[7] = X[7] * WB[3];
|
||||
|
||||
BF(T, X[0], X[1]);
|
||||
BF(T, X[2], X[3]);
|
||||
BF(T, X[4], X[5]);
|
||||
BF(T, X[6], X[7]);
|
||||
BF(T, X[0], X[4]);
|
||||
BF(T, X[1], X[5]);
|
||||
BF(T, X[2], X[6]);
|
||||
BF(T, X[3], X[7]);
|
||||
|
||||
// Stage 1
|
||||
X[2] = X[2] * WB[4];
|
||||
@@ -426,16 +426,16 @@ public:
|
||||
BF(T, X[4], X[6]);
|
||||
BF(T, X[5], X[7]);
|
||||
|
||||
// Stage 0
|
||||
X[4] = X[4] * WB[8];
|
||||
X[5] = X[5] * WB[9];
|
||||
X[6] = X[6] * WB[10];
|
||||
// Stage 2
|
||||
X[1] = X[1] * WB[8];
|
||||
X[3] = X[3] * WB[9];
|
||||
X[5] = X[5] * WB[10];
|
||||
X[7] = X[7] * WB[11];
|
||||
|
||||
BF(T, X[0], X[4]);
|
||||
BF(T, X[1], X[5]);
|
||||
BF(T, X[2], X[6]);
|
||||
BF(T, X[3], X[7]);
|
||||
BF(T, X[0], X[1]);
|
||||
BF(T, X[2], X[3]);
|
||||
BF(T, X[4], X[5]);
|
||||
BF(T, X[6], X[7]);
|
||||
}
|
||||
|
||||
DEVICE_INLINE void SharedData64Columns8(E* shmem, bool store, bool high_bits, bool stride)
|
||||
|
||||
Reference in New Issue
Block a user