Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-08 11:35:02 -05:00)
feat(compiler): add lowering of batched mapped bootstrap operations to wrappers and SDFG, with support in the runtime.
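For orientation, the new tensor-level op introduced below carries a batch of LWE ciphertexts, a batch of lookup tables, and the bootstrap parameters as attributes. A minimal sketch of what one instance could look like in generic MLIR form (shapes and parameter values are illustrative only, assuming an input LWE dimension of 600, GLWE dimension 1, polynomial size 1024 and a batch of 3; each result row then has glweDimension * polySize + 1 = 1025 entries, and the LUT tensor holds either one row per ciphertext or a single shared row, per the asserts in the CUDA wrapper further down):

    %res = "Concrete.batched_mapped_bootstrap_lwe_tensor"(%cts, %luts)
             {inputLweDim = 600 : i32, polySize = 1024 : i32, level = 3 : i32,
              baseLog = 4 : i32, glweDimension = 1 : i32, bskIndex = 0 : i32}
             : (tensor<3x601xi64>, tensor<3x1024xi64>) -> tensor<3x1025xi64>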
@@ -27,6 +27,7 @@ def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
def Concrete_BatchPlaintextTensor : 1DTensorOf<[I64]>;
def Concrete_BatchLutTensor : 2DTensorOf<[I64]>;

def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
@@ -35,6 +36,7 @@ def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
def Concrete_BatchPlaintextBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_BatchLutBuffer : MemRefRankOf<[I64], [2]>;

class Concrete_Op<string mnemonic, list<Trait> traits = []> :
    Op<Concrete_Dialect, mnemonic, traits>;
@@ -359,6 +361,38 @@ def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_bu
  );
}

def Concrete_BatchedMappedBootstrapLweTensorOp : Concrete_Op<"batched_mapped_bootstrap_lwe_tensor", [Pure]> {
  let summary = "Batched, mapped version of BootstrapLweOp, which performs the same operation on multiple elements";

  let arguments = (ins
    Concrete_BatchLweTensor:$input_ciphertext,
    Concrete_BatchLutTensor:$lookup_table_vector,
    I32Attr:$inputLweDim,
    I32Attr:$polySize,
    I32Attr:$level,
    I32Attr:$baseLog,
    I32Attr:$glweDimension,
    I32Attr:$bskIndex
  );
  let results = (outs Concrete_BatchLweTensor:$result);
}

def Concrete_BatchedMappedBootstrapLweBufferOp : Concrete_Op<"batched_mapped_bootstrap_lwe_buffer"> {
  let summary = "Batched, mapped version of BootstrapLweOp, which performs the same operation on multiple elements";

  let arguments = (ins
    Concrete_BatchLweBuffer:$result,
    Concrete_BatchLweBuffer:$input_ciphertext,
    Concrete_BatchLutBuffer:$lookup_table_vector,
    I32Attr:$inputLweDim,
    I32Attr:$polySize,
    I32Attr:$level,
    I32Attr:$baseLog,
    I32Attr:$glweDimension,
    I32Attr:$bskIndex
  );
}

def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [Pure]> {
  let summary = "Keyswitches an LWE ciphertext";
@@ -97,6 +97,7 @@ def ProcessKindBatchMulEintIntCst : I32EnumAttrCase<"batched_mul_eint_int_cst",
def ProcessKindBatchNegEint : I32EnumAttrCase<"batched_neg_eint", 11>;
def ProcessKindBatchKeyswitch : I32EnumAttrCase<"batched_keyswitch", 12>;
def ProcessKindBatchBootstrap : I32EnumAttrCase<"batched_bootstrap", 13>;
def ProcessKindBatchMapBootstrap : I32EnumAttrCase<"batched_mapped_bootstrap", 14>;

def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind",
    [ProcessKindAddEint, ProcessKindAddEintInt, ProcessKindMulEintInt,
@@ -104,7 +105,8 @@ def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind",
     ProcessKindBatchAddEint, ProcessKindBatchAddEintInt,
     ProcessKindBatchAddEintIntCst, ProcessKindBatchMulEintInt,
     ProcessKindBatchMulEintIntCst, ProcessKindBatchNegEint,
     ProcessKindBatchKeyswitch, ProcessKindBatchBootstrap]> {
     ProcessKindBatchKeyswitch, ProcessKindBatchBootstrap,
     ProcessKindBatchMapBootstrap]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::concretelang::SDFG";
}
@@ -36,12 +36,12 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
                                                                    void *sout);
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
    void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
    uint32_t ksk_index, void *context);
    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
    uint32_t output_size, void *context);
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
    uint32_t output_size, uint32_t bsk_index, void *context);
    uint32_t bsk_index, uint32_t output_size, void *context);

void stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process(
    void *dfg, void *sin1, void *sin2, void *sout);
@@ -57,12 +57,16 @@ void stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process(
    void *dfg, void *sin1, void *sout);
void stream_emulator_make_memref_batched_keyswitch_lwe_u64_process(
    void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
    void *context);
    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
    uint32_t output_size, void *context);
void stream_emulator_make_memref_batched_bootstrap_lwe_u64_process(
    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
    uint32_t output_size, void *context);
    uint32_t bsk_index, uint32_t output_size, void *context);
void stream_emulator_make_memref_batched_mapped_bootstrap_lwe_u64_process(
    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
    uint32_t bsk_index, uint32_t output_size, void *context);

void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
void stream_emulator_put_uint64(void *stream, uint64_t e);
@@ -41,6 +41,8 @@ char memref_batched_negate_lwe_ciphertext_u64[] =
    "memref_batched_negate_lwe_ciphertext_u64";
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
char memref_batched_mapped_bootstrap_lwe_u64[] =
    "memref_batched_mapped_bootstrap_lwe_u64";

char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
@@ -51,6 +53,8 @@ char memref_batched_keyswitch_lwe_cuda_u64[] =
    "memref_batched_keyswitch_lwe_cuda_u64";
char memref_batched_bootstrap_lwe_cuda_u64[] =
    "memref_batched_bootstrap_lwe_cuda_u64";
char memref_batched_mapped_bootstrap_lwe_cuda_u64[] =
    "memref_batched_mapped_bootstrap_lwe_cuda_u64";
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
    "memref_expand_lut_in_trivial_glwe_ct_u64";
@@ -175,6 +179,13 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
                                        memref1DType, i32Type, i32Type, i32Type,
                                        i32Type, i32Type, i32Type, contextType},
                                       {});
  } else if (funcName == memref_batched_mapped_bootstrap_lwe_u64 ||
             funcName == memref_batched_mapped_bootstrap_lwe_cuda_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {memref2DType, memref2DType,
                                        memref2DType, i32Type, i32Type, i32Type,
                                        i32Type, i32Type, i32Type, contextType},
                                       {});
  } else if (funcName == memref_await_future) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
@@ -584,6 +595,11 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
                       memref_batched_bootstrap_lwe_cuda_u64>>(
          &getContext(),
          bootstrapAddOperands<Concrete::BatchedBootstrapLweBufferOp>);
      patterns.add<ConcreteToCAPICallPattern<
          Concrete::BatchedMappedBootstrapLweBufferOp,
          memref_batched_mapped_bootstrap_lwe_cuda_u64>>(
          &getContext(),
          bootstrapAddOperands<Concrete::BatchedMappedBootstrapLweBufferOp>);
    } else {
      patterns.add<ConcreteToCAPICallPattern<Concrete::KeySwitchLweBufferOp,
                                             memref_keyswitch_lwe_u64>>(
@@ -601,6 +617,11 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
                       memref_batched_bootstrap_lwe_u64>>(
          &getContext(),
          bootstrapAddOperands<Concrete::BatchedBootstrapLweBufferOp>);
      patterns.add<
          ConcreteToCAPICallPattern<Concrete::BatchedMappedBootstrapLweBufferOp,
                                    memref_batched_mapped_bootstrap_lwe_u64>>(
          &getContext(),
          bootstrapAddOperands<Concrete::BatchedMappedBootstrapLweBufferOp>);
    }

    patterns.add<ConcreteToCAPICallPattern<Concrete::WopPBSCRTLweBufferOp,
@@ -62,6 +62,8 @@ char stream_emulator_make_memref_batched_keyswitch_lwe_u64_process[] =
    "stream_emulator_make_memref_batched_keyswitch_lwe_u64_process";
char stream_emulator_make_memref_batched_bootstrap_lwe_u64_process[] =
    "stream_emulator_make_memref_batched_bootstrap_lwe_u64_process";
char stream_emulator_make_memref_batched_mapped_bootstrap_lwe_u64_process[] =
    "stream_emulator_make_memref_batched_mapped_bootstrap_lwe_u64_process";

char stream_emulator_make_memref_stream[] =
    "stream_emulator_make_memref_stream";
@@ -242,6 +244,11 @@ struct LowerSDFGMakeProcess
    case SDFG::ProcessKind::batched_bootstrap:
      funcName = stream_emulator_make_memref_batched_bootstrap_lwe_u64_process;
      [[fallthrough]];
    case SDFG::ProcessKind::batched_mapped_bootstrap:
      if (funcName == nullptr)
        funcName =
            stream_emulator_make_memref_batched_mapped_bootstrap_lwe_u64_process;
      [[fallthrough]];
    case SDFG::ProcessKind::bootstrap:
      if (funcName == nullptr)
        funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process;
@@ -219,6 +219,43 @@ struct BatchedBootstrapGLWEOpPattern
  }
};

struct BatchedMappedBootstrapGLWEOpPattern
    : public mlir::OpConversionPattern<TFHE::BatchedMappedBootstrapGLWEOp> {

  BatchedMappedBootstrapGLWEOpPattern(mlir::MLIRContext *context,
                                      mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::BatchedMappedBootstrapGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::BatchedMappedBootstrapGLWEOp bmbsOp,
                  TFHE::BatchedMappedBootstrapGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    TFHE::GLWECipherTextType inputElementType =
        bmbsOp.getCiphertexts()
            .getType()
            .cast<mlir::RankedTensorType>()
            .getElementType()
            .cast<TFHE::GLWECipherTextType>();

    auto polySize = adaptor.getKey().getPolySize();
    auto glweDimension = adaptor.getKey().getGlweDim();
    auto levels = adaptor.getKey().getLevels();
    auto baseLog = adaptor.getKey().getBaseLog();
    auto inputLweDimension =
        inputElementType.getKey().getNormalized().value().dimension;
    auto bskIndex = bmbsOp.getKeyAttr().getIndex();

    rewriter.replaceOpWithNewOp<Concrete::BatchedMappedBootstrapLweTensorOp>(
        bmbsOp, this->getTypeConverter()->convertType(bmbsOp.getType()),
        adaptor.getCiphertexts(), adaptor.getLookupTable(), inputLweDimension,
        polySize, levels, baseLog, glweDimension, bskIndex);

    return mlir::success();
  }
};

struct KeySwitchGLWEOpPattern
    : public mlir::OpConversionPattern<TFHE::KeySwitchGLWEOp> {
@@ -811,7 +848,8 @@ void TFHEToConcretePass::runOnOperation() {
  patterns.insert<ZeroOpPattern<mlir::concretelang::TFHE::ZeroGLWEOp>,
                  ZeroOpPattern<mlir::concretelang::TFHE::ZeroTensorGLWEOp>,
                  SubIntGLWEOpPattern, BootstrapGLWEOpPattern,
                  BatchedBootstrapGLWEOpPattern, KeySwitchGLWEOpPattern,
                  BatchedBootstrapGLWEOpPattern,
                  BatchedMappedBootstrapGLWEOpPattern, KeySwitchGLWEOpPattern,
                  BatchedKeySwitchGLWEOpPattern, WopPBSGLWEOpPattern>(
      &getContext(), converter);
@@ -159,6 +159,11 @@ void mlir::concretelang::Concrete::
  Concrete::BatchedBootstrapLweTensorOp::attachInterface<
      TensorToMemrefOp<Concrete::BatchedBootstrapLweTensorOp,
                       Concrete::BatchedBootstrapLweBufferOp>>(*ctx);
  // batched_mapped_bootstrap_lwe_tensor =>
  // batched_mapped_bootstrap_lwe_buffer
  Concrete::BatchedMappedBootstrapLweTensorOp::attachInterface<
      TensorToMemrefOp<Concrete::BatchedMappedBootstrapLweTensorOp,
                       Concrete::BatchedMappedBootstrapLweBufferOp>>(*ctx);
  // wop_pbs_crt_lwe_tensor => wop_pbs_crt_lwe_buffer
  Concrete::WopPBSCRTLweTensorOp::attachInterface<TensorToMemrefOp<
      Concrete::WopPBSCRTLweTensorOp, Concrete::WopPBSCRTLweBufferOp>>(*ctx);
@@ -98,6 +98,8 @@ mlir::LogicalResult MakeProcess::verify() {
    return checkStreams(1, 1);
  case ProcessKind::batched_bootstrap:
    return checkStreams(2, 1);
  case ProcessKind::batched_mapped_bootstrap:
    return checkStreams(2, 1);
  }

  return mlir::failure();
@@ -29,6 +29,7 @@ char batched_mul_eint_int_cst[] = "batched_mul_eint_int_cst";
char batched_neg_eint[] = "batched_neg_eint";
char batched_keyswitch[] = "batched_keyswitch";
char batched_bootstrap[] = "batched_bootstrap";
char batched_mapped_bootstrap[] = "batched_mapped_bootstrap";
} // namespace

template <typename Op, char const *processName, bool copyAttributes = false>
@@ -129,6 +130,10 @@ void registerSDFGConvertibleOpInterfaceExternalModels(
        ReplaceWithProcessSDFGConversionInterface<
            mlir::concretelang::Concrete::BatchedBootstrapLweTensorOp,
            batched_bootstrap, true>>(*ctx);
    mlir::concretelang::Concrete::BatchedMappedBootstrapLweTensorOp::
        attachInterface<ReplaceWithProcessSDFGConversionInterface<
            mlir::concretelang::Concrete::BatchedMappedBootstrapLweTensorOp,
            batched_mapped_bootstrap, true>>(*ctx);
  });
}
} // namespace SDFG
@@ -184,13 +184,18 @@ struct Process {
  Param output_lwe_dim;
  Param poly_size;
  Param glwe_dim;
  Param sk_index;
  Param output_size;
  Context ctx;
  void (*fun)(Process *);
  char name[80];
};

static inline void schedule_kernel(Process *p) { p->fun(p); }
static inline void schedule_kernel(Process *p) {
  std::cout << " Scheduling a " << p->name << " on GPU " << p->dfg->gpu_idx
            << "\n";
  p->fun(p);
}

struct Stream {
  stream_type type;
@@ -368,7 +373,7 @@ sdfg_gpu_debug_compare_memref(MemRef2 &a, MemRef2 &b, char const *msg) {
      a.strides[0] != b.strides[0] || a.strides[1] != b.strides[1])
    return false;
  size_t data_size = memref_get_data_size(a);
  for (int i = 0; i < data_size / sizeof(uint64_t); ++i)
  for (size_t i = 0; i < data_size / sizeof(uint64_t); ++i)
    if ((a.aligned + a.offset)[i] != (b.aligned + b.offset)[i]) {
      std::cout << msg << " - memrefs differ at position " << i << " "
                << (a.aligned + a.offset)[i] << " " << (b.aligned + b.offset)[i]
@@ -380,6 +385,7 @@ sdfg_gpu_debug_compare_memref(MemRef2 &a, MemRef2 &b, char const *msg) {
// Stream emulator processes
void memref_keyswitch_lwe_u64_process(Process *p) {
  assert(p->sk_index.val == 0 && "multiple ksk is not yet implemented on GPU");
  Dependence *idep = p->input_streams[0]->get(p->dfg->gpu_idx);
  uint64_t num_samples = idep->host_data.sizes[0];
  MemRef2 out = {
@@ -402,6 +408,7 @@ void memref_keyswitch_lwe_u64_process(Process *p) {
}

void memref_bootstrap_lwe_u64_process(Process *p) {
  assert(p->sk_index.val == 0 && "multiple bsk is not yet implemented on GPU");
  assert(p->output_size.val == p->glwe_dim.val * p->poly_size.val + 1);
  void *fbsk_gpu = p->ctx.val->get_bsk_gpu(
      p->input_lwe_dim.val, p->poly_size.val, p->level.val, p->glwe_dim.val,
@@ -409,18 +416,23 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
  Dependence *idep0 = p->input_streams[0]->get(p->dfg->gpu_idx);
  void *ct0_gpu = idep0->device_data;

  uint64_t glwe_ct_len = p->poly_size.val * (p->glwe_dim.val + 1);
  uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
  uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
  Dependence *idep1 = p->input_streams[1]->get(host_location);
  MemRef2 &mtlu = idep1->host_data;
  uint32_t num_lut_vectors = mtlu.sizes[0];
  uint64_t glwe_ct_len =
      p->poly_size.val * (p->glwe_dim.val + 1) * num_lut_vectors;
  uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
  uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
  auto tlu = mtlu.aligned + mtlu.offset;
  // Glwe trivial encryption
  for (size_t i = 0; i < p->poly_size.val * p->glwe_dim.val; i++) {
    glwe_ct[i] = 0;
  }
  for (size_t i = 0; i < p->poly_size.val; i++) {
    glwe_ct[p->poly_size.val * p->glwe_dim.val + i] = tlu[i];
  size_t pos = 0, postlu = 0;
  for (size_t l = 0; l < num_lut_vectors; ++l) {
    for (size_t i = 0; i < p->poly_size.val * p->glwe_dim.val; i++) {
      glwe_ct[pos++] = 0;
    }
    for (size_t i = 0; i < p->poly_size.val; i++) {
      glwe_ct[pos++] = tlu[postlu++];
    }
  }
  void *glwe_ct_gpu = cuda_malloc_async(
      glwe_ct_size, (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx);
@@ -434,15 +446,21 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
  void *out_gpu = cuda_malloc_async(
      data_size, (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx);
  cudaMemsetAsync(out_gpu, 0, data_size, *(cudaStream_t *)p->dfg->gpu_stream);

  // Move test vector indexes to the GPU, the test vector indexes is set of 0
  uint32_t num_test_vectors = 1, lwe_idx = 0,
           test_vector_idxes_size = num_samples * sizeof(uint64_t);
  void *test_vector_idxes = malloc(test_vector_idxes_size);
  memset(test_vector_idxes, 0, test_vector_idxes_size);
  uint32_t lwe_idx = 0, test_vector_idxes_size = num_samples * sizeof(uint64_t);
  uint64_t *test_vector_idxes = (uint64_t *)malloc(test_vector_idxes_size);
  if (num_lut_vectors == 1) {
    memset((void *)test_vector_idxes, 0, test_vector_idxes_size);
  } else {
    assert(num_lut_vectors == num_samples);
    for (size_t i = 0; i < num_lut_vectors; ++i)
      test_vector_idxes[i] = i;
  }
  void *test_vector_idxes_gpu =
      cuda_malloc_async(test_vector_idxes_size,
                        (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx);
  cuda_memcpy_async_to_gpu(test_vector_idxes_gpu, test_vector_idxes,
  cuda_memcpy_async_to_gpu(test_vector_idxes_gpu, (void *)test_vector_idxes,
                           test_vector_idxes_size,
                           (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx);
  // Schedule the bootstrap kernel on the GPU
@@ -452,7 +470,7 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
      (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx, out_gpu, glwe_ct_gpu,
      test_vector_idxes_gpu, ct0_gpu, fbsk_gpu, (int8_t *)pbs_buffer,
      p->input_lwe_dim.val, p->glwe_dim.val, p->poly_size.val, p->base_log.val,
      p->level.val, num_samples, num_test_vectors, lwe_idx,
      p->level.val, num_samples, num_lut_vectors, lwe_idx,
      cuda_get_max_shared_memory(p->dfg->gpu_idx));
  cuda_drop_async(test_vector_idxes_gpu, (cudaStream_t *)p->dfg->gpu_stream,
                  p->dfg->gpu_idx);
@@ -573,14 +591,15 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,

void stream_emulator_make_memref_keyswitch_lwe_u64_process(
    void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
    void *context) {
    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
    uint32_t output_size, void *context) {
  Process *p =
      make_process_1_1(dfg, sin1, sout, memref_keyswitch_lwe_u64_process);
  p->level.val = level;
  p->base_log.val = base_log;
  p->input_lwe_dim.val = input_lwe_dim;
  p->output_lwe_dim.val = output_lwe_dim;
  p->sk_index.val = ksk_index;
  p->output_size.val = output_size;
  p->ctx.val = (RuntimeContext *)context;
  static int count = 0;
@@ -590,7 +609,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
    uint32_t output_size, void *context) {
    uint32_t bsk_index, uint32_t output_size, void *context) {
  // The TLU does not need to be sent to GPU
  ((Stream *)sin2)->type = TS_STREAM_TYPE_X86_TO_X86_LSAP;
  Process *p =
@@ -600,6 +619,7 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
  p->level.val = level;
  p->base_log.val = base_log;
  p->glwe_dim.val = glwe_dim;
  p->sk_index.val = bsk_index;
  p->output_size.val = output_size;
  p->ctx.val = (RuntimeContext *)context;
  static int count = 0;
@@ -642,20 +662,29 @@ void stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process(
void stream_emulator_make_memref_batched_keyswitch_lwe_u64_process(
    void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
    void *context) {
    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
    uint32_t output_size, void *context) {
  stream_emulator_make_memref_keyswitch_lwe_u64_process(
      dfg, sin1, sout, level, base_log, input_lwe_dim, output_lwe_dim,
      output_size, context);
      ksk_index, output_size, context);
}

void stream_emulator_make_memref_batched_bootstrap_lwe_u64_process(
    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
    uint32_t output_size, void *context) {
    uint32_t bsk_index, uint32_t output_size, void *context) {
  stream_emulator_make_memref_bootstrap_lwe_u64_process(
      dfg, sin1, sin2, sout, input_lwe_dim, poly_size, level, base_log,
      glwe_dim, output_size, context);
      glwe_dim, bsk_index, output_size, context);
}

void stream_emulator_make_memref_batched_mapped_bootstrap_lwe_u64_process(
    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
    uint32_t bsk_index, uint32_t output_size, void *context) {
  stream_emulator_make_memref_bootstrap_lwe_u64_process(
      dfg, sin1, sin2, sout, input_lwe_dim, poly_size, level, base_log,
      glwe_dim, bsk_index, output_size, context);
}

void *stream_emulator_make_uint64_stream(const char *name, stream_type stype) {
@@ -247,7 +247,8 @@ void memref_batched_mapped_bootstrap_lwe_cuda_u64(
  assert(bsk_index == 0 && "multiple bsk is not yet implemented on GPU");
  assert(out_size0 == ct0_size0);
  assert(out_size1 == glwe_dim * poly_size + 1);
  assert((out_size0 == tlu_size0 || tlu_size0 == 1) && "Number of LUTs does not match batch size");
  assert((out_size0 == tlu_size0 || tlu_size0 == 1) &&
         "Number of LUTs does not match batch size");
  // TODO: Multi GPU
  uint32_t gpu_idx = 0;
  uint32_t num_samples = out_size0;
@@ -291,8 +292,7 @@ void memref_batched_mapped_bootstrap_lwe_cuda_u64(
      glwe_ct, 0, glwe_ct_size, gpu_idx, (cudaStream_t *)stream);

  // Move test vector indexes to the GPU, the test vector indexes is set of 0
  uint32_t lwe_idx = 0,
           test_vector_idxes_size = num_samples * sizeof(uint64_t);
  uint32_t lwe_idx = 0, test_vector_idxes_size = num_samples * sizeof(uint64_t);
  uint64_t *test_vector_idxes = (uint64_t *)malloc(test_vector_idxes_size);
  if (num_lut_vectors == 1) {
    memset((void *)test_vector_idxes, 0, test_vector_idxes_size);
@@ -262,3 +262,33 @@ TEST(SDFG_unit_tests, batched_tree) {
  ASSERT_TRUE(res);
  ASSERT_EQ_OUTCOME(res, expected);
}

TEST(SDFG_unit_tests, batched_tree_mapped_tlu) {
  std::string source = R"(
func.func @main(%t: tensor<3x3x!FHE.eint<3>>, %a1: tensor<3x3xi4>, %a2: tensor<3x3xi4>) -> tensor<3x3x!FHE.eint<4>> {
  %lut_vec = arith.constant dense<[[1,3,5,7,9,11,13,15],
                                   [2,4,6,8,10,12,14,0],
                                   [3,6,9,12,15,2,5,8],
                                   [4,8,12,0,4,8,12,0]]> : tensor<4x8xi64>
  %map = arith.constant dense<[[0, 1, 2], [3, 2, 1], [1, 2, 3]]> : tensor<3x3xindex>
  %b1 = "FHELinalg.add_eint_int"(%t, %a1) : (tensor<3x3x!FHE.eint<3>>, tensor<3x3xi4>) -> tensor<3x3x!FHE.eint<3>>
  %b2 = "FHELinalg.add_eint_int"(%t, %a2) : (tensor<3x3x!FHE.eint<3>>, tensor<3x3xi4>) -> tensor<3x3x!FHE.eint<3>>
  %c = "FHELinalg.add_eint"(%b1, %b2) : (tensor<3x3x!FHE.eint<3>>, tensor<3x3x!FHE.eint<3>>) -> tensor<3x3x!FHE.eint<3>>
  %res = "FHELinalg.apply_mapped_lookup_table"(%c, %lut_vec, %map) : (tensor<3x3x!FHE.eint<3>>, tensor<4x8xi64>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<4>>
  return %res : tensor<3x3x!FHE.eint<4>>
}
)";
  using tensor2_in = std::array<std::array<uint8_t, 3>, 3>;
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  auto lambda =
      load<TestTypedLambda<tensor2_out, tensor2_in, tensor2_in, tensor2_in>>(
          outputLib);
  tensor2_in t = {{{0, 1, 2}, {3, 0, 1}, {2, 3, 0}}};
  tensor2_in a1 = {{{0, 1, 0}, {0, 1, 0}, {0, 1, 0}}};
  tensor2_in a2 = {{{1, 0, 1}, {1, 0, 1}, {1, 0, 1}}};
  tensor2_out expected = {{{3, 8, 2}, {0, 6, 8}, {12, 8, 8}}};
  auto res = lambda.call(t, a1, a2);
  ASSERT_TRUE(res);
  ASSERT_EQ_OUTCOME(res, expected);
}
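As a sanity check on the expected values in this test: each element goes through c = (t + a1) + (t + a2) and is then looked up in the %lut_vec row selected by %map. For position (0,0), c = (0 + 0) + (0 + 1) = 1 and map[0][0] = 0, so the result is row 0 at index 1, i.e. 3; for position (1,0), c = (3 + 0) + (3 + 1) = 7 and map[1][0] = 3, so the result is row 3 at index 7, i.e. 0, matching the first column of expected.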