chore(frontend-python): Formatting

This commit is contained in:
Bourgerie Quentin
2024-09-25 16:49:35 +02:00
committed by Quentin Bourgerie
parent d2d4613afc
commit ccabaaf8f5
2 changed files with 64 additions and 21 deletions

View File

@@ -147,7 +147,9 @@ class Circuit:
initial_keys (Optional[Dict[int, LweSecretKey]] = None):
initial keys to set before keygen
"""
self._module.keygen(force=force, seed=seed, encryption_seed=encryption_seed, initial_keys)
self._module.keygen(
force=force, seed=seed, encryption_seed=encryption_seed, initial_keys=initial_keys
)
def encrypt(
self,

View File

@@ -111,12 +111,21 @@
"\n",
"assert database_output_bits == 5\n",
"\n",
"# For now, we have not compiled our functions so here, all the computations \n",
"# For now, we have not compiled our functions so here, all the computations\n",
"# in the following asserts are done in the clear, just to check the semantic\n",
"# of the functions\n",
"assert get_ith_element_of_database(make_one_hot_vector(0, size=database_length), database) == database[0]\n",
"assert get_ith_element_of_database(make_one_hot_vector(3, size=database_length), database) == database[3]\n",
"assert get_ith_element_of_database(make_one_hot_vector(4, size=database_length), database) == database[4]"
"assert (\n",
" get_ith_element_of_database(make_one_hot_vector(0, size=database_length), database)\n",
" == database[0]\n",
")\n",
"assert (\n",
" get_ith_element_of_database(make_one_hot_vector(3, size=database_length), database)\n",
" == database[3]\n",
")\n",
"assert (\n",
" get_ith_element_of_database(make_one_hot_vector(4, size=database_length), database)\n",
" == database[4]\n",
")"
]
},
{
@@ -176,9 +185,11 @@
" # values in get_ith_element_of_database\n",
" inputset.append((make_one_hot_vector(np.argmax(database), database_length), database))\n",
"\n",
" compiler = fhe.Compiler(get_ith_element_of_database, {\"one_hot_vector\": \"encrypted\", \"database\": \"clear\"})\n",
" compiler = fhe.Compiler(\n",
" get_ith_element_of_database, {\"one_hot_vector\": \"encrypted\", \"database\": \"clear\"}\n",
" )\n",
" circuit = compiler.compile(inputset, **kwargs)\n",
" \n",
"\n",
" return circuit\n",
"\n",
"\n",
@@ -286,7 +297,20 @@
"source": [
"how_many_tests = 10\n",
"\n",
"sample_list = [(4, 8), (4, 16), (8, 8), (8, 16), (9, 8), (9, 16), (10, 4), (10, 8), (12, 4), (12, 8), (14, 4), (14, 8)]\n",
"sample_list = [\n",
" (4, 8),\n",
" (4, 16),\n",
" (8, 8),\n",
" (8, 16),\n",
" (9, 8),\n",
" (9, 16),\n",
" (10, 4),\n",
" (10, 8),\n",
" (12, 4),\n",
" (12, 8),\n",
" (14, 4),\n",
" (14, 8),\n",
"]\n",
"timings_dic = {}\n",
"\n",
"for database_input_bits, database_output_bits in sample_list:\n",
@@ -378,11 +402,16 @@
" )\n",
" for _ in range(inputset_length)\n",
" ]\n",
" compiler = fhe.Compiler(get_ith_element_of_database, {\"one_hot_vector\": \"encrypted\", \n",
" \"database0\": \"clear\", \n",
" \"database1\": \"clear\", \n",
" \"database2\": \"clear\", \n",
" \"database3\": \"clear\"})\n",
" compiler = fhe.Compiler(\n",
" get_ith_element_of_database,\n",
" {\n",
" \"one_hot_vector\": \"encrypted\",\n",
" \"database0\": \"clear\",\n",
" \"database1\": \"clear\",\n",
" \"database2\": \"clear\",\n",
" \"database3\": \"clear\",\n",
" },\n",
" )\n",
" circuit = compiler.compile(inputset, **kwargs)\n",
" return circuit\n",
"\n",
@@ -520,19 +549,31 @@
"# Finding the best combination\n",
"def find_best_combination(expected_total_bits):\n",
" best_combination = None\n",
" \n",
"\n",
" for database_input_bits, database_output_bits in timings_dic.keys():\n",
" remaining_bits = expected_total_bits - database_input_bits\n",
" assert remaining_bits > 0\n",
" number_of_subdatabases = np.ceil(2**remaining_bits / database_output_bits).astype(np.int32)\n",
" estimated_time = np.ceil(number_of_subdatabases * timings_dic[(database_input_bits, database_output_bits)])\n",
" \n",
" print(f\"Estimated time would be {str(estimated_time):>8s} seconds for {str(number_of_subdatabases):>8s} DBs of {(database_input_bits, database_output_bits)}\")\n",
" \n",
" estimated_time = np.ceil(\n",
" number_of_subdatabases * timings_dic[(database_input_bits, database_output_bits)]\n",
" )\n",
"\n",
" print(\n",
" f\"Estimated time would be {str(estimated_time):>8s} seconds for {str(number_of_subdatabases):>8s} DBs of {(database_input_bits, database_output_bits)}\"\n",
" )\n",
"\n",
" if best_combination is None or estimated_time < best_combination[0]:\n",
" best_combination = (estimated_time, number_of_subdatabases, database_input_bits, database_output_bits)\n",
" \n",
" print(f\"\\nBest combination: {best_combination[0]} seconds for a DB of {expected_total_bits} bits\\n\")\n",
" best_combination = (\n",
" estimated_time,\n",
" number_of_subdatabases,\n",
" database_input_bits,\n",
" database_output_bits,\n",
" )\n",
"\n",
" print(\n",
" f\"\\nBest combination: {best_combination[0]} seconds for a DB of {expected_total_bits} bits\\n\"\n",
" )\n",
"\n",
"\n",
"find_best_combination(30)\n",
"find_best_combination(20)"