Files
2024-10-09 12:05:59 +02:00

206 lines
6.6 KiB
Python

import os
from functools import partial
import click
import numpy as np
from concrete import fhe
from concrete.fhe import tfhers
from numpy.random import randint
# FIXME: should we move this to Concrete library directly, hidden to the user
def get_tfhers_params_and_type_and_int(precision):
    """Build the TFHE-rs parameter set, integer type and integer factory.

    The cryptographic parameters come from the module-level constants
    (LWE_DIM, GLWE_DIM, POLY_SIZE, ...), which are read at call time.

    Args:
        precision: bit width of the (unsigned) TFHE-rs integer type.

    Returns:
        A ``(params, integer_type, make_integer)`` tuple where ``params`` is a
        ``tfhers.CryptoParams``, ``integer_type`` is an unsigned
        ``tfhers.TFHERSIntegerType`` of ``precision`` bits, and
        ``make_integer(value)`` is ``tfhers.TFHERSInteger(integer_type, value)``.
    """
    params = tfhers.CryptoParams(
        lwe_dimension=LWE_DIM,
        glwe_dimension=GLWE_DIM,
        polynomial_size=POLY_SIZE,
        pbs_base_log=PBS_BASE_LOG,
        pbs_level=PBS_LEVEL,
        lwe_noise_distribution=LWE_NOISE_DISTR,
        glwe_noise_distribution=GLWE_NOISE_DISTR,
        encryption_key_choice=ENCRYPTION_KEY_CHOICE,
    )
    integer_type = tfhers.TFHERSIntegerType(
        is_signed=False,
        bit_width=precision,
        carry_width=CARRY_WIDTH,
        msg_width=MSG_WIDTH,
        params=params,
    )
    return params, integer_type, partial(tfhers.TFHERSInteger, integer_type)
# FIXME Params: users shouldn't change them, should we hide it
# TFHE-rs cryptographic parameters. They are forwarded verbatim to
# tfhers.CryptoParams / tfhers.TFHERSIntegerType by
# get_tfhers_params_and_type_and_int; presumably they must match the
# parameters used on the TFHE-rs side — confirm before changing any of them.
LWE_DIM = 909
GLWE_DIM = 1
POLY_SIZE = 4096
PBS_BASE_LOG = 15
PBS_LEVEL = 2
MSG_WIDTH = 2
CARRY_WIDTH = 3
ENCRYPTION_KEY_CHOICE = tfhers.EncryptionKeyChoice.BIG
LWE_NOISE_DISTR = 0
GLWE_NOISE_DISTR = 2.168404344971009e-19
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would silently let an unsupported configuration through.
if GLWE_DIM != 1:
    raise ValueError("glwe dim must be 1")
# Options: the user can change the following
# FIXME: explain FHEUINT_PRECISION, ie can it be changed
# Bit width of the unsigned TFHE-rs integers exchanged with Concrete (passed as
# `bit_width` of the TFHERSIntegerType); 8 presumably corresponds to FheUint8
# on the TFHE-rs side.
FHEUINT_PRECISION = 8
tfhers_params, tfhers_type, tfhers_int = get_tfhers_params_and_type_and_int(FHEUINT_PRECISION)
# Describe the function you want to apply, on Concrete ciphertexts
def server_side_function_in_concrete(concrete_vars):
    """Return ``(v0 + v1 + 2 * v2 - v3) mod 47`` over the first four inputs.

    Every term is reduced modulo 47 before it is combined. The modulus is added
    before ``v3 % 47`` is subtracted, so the operand of the final reduction
    stays non-negative. The result lies in ``[0, 46]``.

    Args:
        concrete_vars: indexable sequence whose items 0..3 are the (Concrete)
            values to combine.
    """
    modulus = 47
    v0, v1, v2, v3 = (concrete_vars[k] for k in range(4))
    pair_sum = (v0 + v1) % modulus
    doubled = (2 * v2) % modulus
    return (pair_sum + doubled + modulus - (v3 % modulus)) % modulus
# The user must specify the range of the TFHE-rs inputs: 100 sample tuples,
# one TFHE-rs integer per circuit input, each drawn uniformly from [0, 127]
# (numpy's randint upper bound is exclusive).
# FIXME: why can't we set the limit at 256? It's needed for FHEUint8
# FIXME(vectorisation): make that we can use a tensor here
inputset_of_tfhe_rs_inputs = [
    tuple(tfhers_int(randint(128)) for _input in range(4))
    for _sample in range(100)
]
# End of options
# This is the compiled function: user doesn't have to change this, except to
# add more inputs (i.e., tfhers_vars_4, etc.)
# FIXME(vectorisation): make that we can use a tensor here
def function_to_run_in_concrete(tfhers_vars_0, tfhers_vars_1, tfhers_vars_2, tfhers_vars_3):
    """Circuit body traced by the Concrete compiler.

    Each argument is a value in TFHE-rs format. The inputs are converted to
    Concrete's native format and combined by ``server_side_function_in_concrete``.
    The result is then converted back to ``tfhers_type`` so it can be exported
    to TFHE-rs.
    """
    # FIXME(vectorisation): make that we can use a tensor here
    tfhers_inputs = (tfhers_vars_0, tfhers_vars_1, tfhers_vars_2, tfhers_vars_3)
    # TFHE-rs format -> Concrete format
    native_inputs = [tfhers.to_native(value) for value in tfhers_inputs]
    # Apply the user-defined computation in Concrete
    native_result = server_side_function_in_concrete(native_inputs)
    # Concrete format -> TFHE-rs format (the target type must be given explicitly)
    return tfhers.from_native(native_result, tfhers_type)
# This is where we compile the function with Concrete: user doesn't have to
# change this, except to add more inputs (i.e., tfhers_vars_4, etc.)
def compile_concrete_function():
    """Compile ``function_to_run_in_concrete`` and create its TFHE-rs bridge.

    Every circuit input is declared ``"encrypted"``. Compilation uses
    ``inputset_of_tfhe_rs_inputs``.

    Returns:
        ``(circuit, tfhers_bridge)``: the compiled circuit and the bridge
        created from it with ``tfhers.new_bridge``.
    """
    # FIXME(vectorisation): make that we can use a tensor here
    encryption_status = {f"tfhers_vars_{idx}": "encrypted" for idx in range(4)}
    compiled = fhe.Compiler(function_to_run_in_concrete, encryption_status).compile(
        inputset_of_tfhe_rs_inputs
    )
    return compiled, tfhers.new_bridge(circuit=compiled)
@click.group()
def cli():
    # Root click command group; the `keygen` and `run` subcommands attach to it
    # via @cli.command(). No docstring on purpose: click would use it as help text.
    ...
def read_var_from_file(tfhers_bridge, filename, input_idx):
    """Read a serialized value from ``filename`` and import it via the bridge.

    Args:
        tfhers_bridge: bridge object created by ``tfhers.new_bridge``.
        filename: path of the file holding the serialized value (read as bytes).
        input_idx: index of the circuit input the value is imported for.

    Returns:
        The result of ``tfhers_bridge.import_value`` for that input.
    """
    with open(filename, "rb") as ct_file:
        return tfhers_bridge.import_value(ct_file.read(), input_idx=input_idx)
@cli.command()
@click.option("-s", "--secret-key", type=str, required=True)
@click.option("-k", "--concrete-keyset-path", type=str, required=True)
# This is where we generate the evaluation key at the Concrete format, from the
# secret key coming from TFHE-rs, on the client side
def keygen(secret_key: str, concrete_keyset_path: str):
    """Concrete Key Generation"""
    # Compile the Concrete function
    circuit, tfhers_bridge = compile_concrete_function()

    # Remove any previous keyset so a stale one is never left at this path
    if os.path.exists(concrete_keyset_path):
        os.remove(concrete_keyset_path)

    # Load the initial (TFHE-rs) secret key to use for keygen
    with open(secret_key, "rb") as f:
        buff = f.read()

    # All circuit inputs are TFHE-rs integers of the same type, encrypted under
    # this single secret key. Register it for EVERY input, not only 0 and 1 (a
    # leftover from the two-input version). Unlisted inputs would presumably
    # only get the right key if they happen to share a key ID with a listed one.
    # The inputset tuples have one entry per circuit input.
    n_inputs = len(inputset_of_tfhe_rs_inputs[0])
    input_idx_to_key = {input_idx: buff for input_idx in range(n_inputs)}
    tfhers_bridge.keygen_with_initial_keys(input_idx_to_key_buffer=input_idx_to_key)

    # FIXME: remove the secret key before saving. The secret key can be used for
    # debugging but should really be removed in production
    circuit.client.keys.save(concrete_keyset_path)
@cli.command()
@click.option("-c", "--rust-ct", type=str, required=True)
@click.option("-o", "--output-rust-ct", type=str, required=True)
@click.option("-k", "--concrete-keyset-path", type=str, required=True)
# This is the actual FHE computation, on the server side
def run(rust_ct: str, output_rust_ct: str, concrete_keyset_path: str):
    """Run circuit"""
    # Fail fast, before the (costly) compilation, when keygen was never run
    if not os.path.exists(concrete_keyset_path):
        raise RuntimeError("cannot find keys, you should run keygen before")

    circuit, tfhers_bridge = compile_concrete_function()
    circuit.client.keys.load(concrete_keyset_path)

    # `rust_ct` is a whitespace-separated list of ciphertext files, one per
    # circuit input, in input order (so paths cannot contain spaces).
    ct_paths = rust_ct.split()
    n_inputs = len(inputset_of_tfhe_rs_inputs[0])
    if len(ct_paths) != n_inputs:
        raise click.BadParameter(
            f"expected {n_inputs} ciphertext files, got {len(ct_paths)}",
            param_hint="'--rust-ct'",
        )
    tfhers_vars = tuple(
        read_var_from_file(tfhers_bridge, ct_path, input_idx=idx)
        for idx, ct_path in enumerate(ct_paths)
    )
    encrypted_result = circuit.run(tfhers_vars)

    # Export the result
    buff = tfhers_bridge.export_value(encrypted_result, output_idx=0)
    with open(output_rust_ct, "wb") as f:
        f.write(buff)

    # BCM BEGIN: to debug computations
    # Decryption works here only because the keyset written by `keygen` still
    # contains the client secret keys, and it is loaded into circuit.client.keys
    # above. In production the server must not hold them.
    # for idx, tfhers_var in enumerate(tfhers_vars):
    #     raw = circuit.decrypt(tfhers_var)
    #     decoded = tfhers_type.decode(raw)
    #     print(f"Concrete decryption of input {idx}: raw({raw}), decoded({decoded})")
    # result = circuit.decrypt(encrypted_result)
    # decoded = tfhers_type.decode(result)
    # print(f"Concrete decryption result: raw({result}), decoded({decoded})")
    # BCM END
if __name__ == "__main__":
    # Script entry point: dispatch to the click command group (keygen / run).
    cli()