Files
concrete/frontends/concrete-python/examples/tfhers-ml/example.py

177 lines
5.3 KiB
Python

import os
import typing
from functools import partial
import click
import numpy as np
import concrete.fhe as fhe
from concrete.fhe import tfhers
### Options ###########################
# These parameters were saved by running the tfhers_utils utility:
# tfhers_utils save-params tfhers_params.json
TFHERS_PARAMS_FILE = "tfhers_params.json"
FHEUINT_PRECISION = 8
IS_SIGNED = True
#######################################
tfhers_type = tfhers.get_type_from_params(
TFHERS_PARAMS_FILE,
is_signed=IS_SIGNED,
precision=FHEUINT_PRECISION,
)
tfhers_int = partial(tfhers.TFHERSInteger, tfhers_type)
#### Model Parameters ##################
q_weights = np.array([[-25, 21, -10], [42, -20, -37], [-128, -15, 127], [-58, -51, 94]])
q_bias = np.array([[35167, 9417, -44584]])
weight_quantizer_zero_point = -5
########################################
rounder = fhe.AutoRounder(target_msbs=8) # We want to keep 8 MSBs
@typing.no_type_check
def ml_inference(q_X: np.ndarray) -> np.ndarray:
y_pred = q_X @ q_weights - weight_quantizer_zero_point * np.sum(q_X, axis=1, keepdims=True)
y_pred += q_bias
y_pred = fhe.round_bit_pattern(y_pred, rounder)
y_pred = y_pred >> rounder.lsbs_to_remove
return y_pred
def compute(tfhers_x):
####### TFHE-rs to Concrete #########
# x and y are supposed to be TFHE-rs values.
# to_native will use type information from x and y to do
# a correct conversion from TFHE-rs to Concrete
concrete_x = tfhers.to_native(tfhers_x)
####### TFHE-rs to Concrete #########
####### Concrete Computation ########
concrete_res = ml_inference(concrete_x)
####### Concrete Computation ########
####### Concrete to TFHE-rs #########
tfhers_res = tfhers.from_native(
concrete_res, tfhers_type
) # we have to specify the type we want to convert to
####### Concrete to TFHE-rs #########
return tfhers_res
def ccompilee():
compiler = fhe.Compiler(
compute,
{
"tfhers_x": "encrypted",
},
)
inputset = [
(
tfhers_int(
np.array(
[
[36, -17, -85, -124],
[29, -33, -85, -124],
[23, -26, -88, -124],
[19, -30, -82, -124],
[32, -13, -85, -124],
]
)
),
)
]
# Add the auto-adjustment before compilation
fhe.AutoRounder.adjust(compute, inputset)
# Print the number of bits rounded
print(f"lsbs_to_remove: {rounder.lsbs_to_remove}")
circuit = compiler.compile(inputset)
tfhers_bridge = tfhers.new_bridge(circuit)
return circuit, tfhers_bridge
@click.group()
def cli():
pass
@cli.command()
@click.option("-s", "--secret-key", type=str, required=False)
@click.option("-o", "--output-secret-key", type=str, required=True)
@click.option("-k", "--concrete-keyset-path", type=str, required=True)
def keygen(output_secret_key: str, secret_key: str, concrete_keyset_path: str):
"""Concrete Key Generation"""
circuit, tfhers_bridge = ccompilee()
if os.path.exists(concrete_keyset_path):
print(f"removing old keyset at '{concrete_keyset_path}'")
os.remove(concrete_keyset_path)
if secret_key:
print(f"partial keygen from sk at '{secret_key}'")
# load the initial secret key to use for keygen
with open(
secret_key,
"rb",
) as f:
buff = f.read()
input_idx_to_key = {0: buff}
tfhers_bridge.keygen_with_initial_keys(input_idx_to_key_buffer=input_idx_to_key)
else:
print("full keygen")
circuit.keygen()
print("saving Concrete Evaluation Keys")
with open(concrete_keyset_path, "wb") as f:
f.write(circuit.client.evaluation_keys.serialize())
print(f"saved Concrete Evaluation Keys to '{concrete_keyset_path}'")
sk: bytes = tfhers_bridge.serialize_input_secret_key(input_idx=0)
print(f"writing secret key of size {len(sk)} to '{output_secret_key}'")
with open(output_secret_key, "wb") as f:
f.write(sk)
@cli.command()
@click.option("-c", "--rust-ct", type=str, required=True)
@click.option("-o", "--output-rust-ct", type=str, required=True)
@click.option("-k", "--concrete-keyset-path", type=str, required=True)
def run(rust_ct: str, output_rust_ct: str, concrete_keyset_path: str):
"""Run circuit"""
circuit, tfhers_bridge = ccompilee()
if not os.path.exists(concrete_keyset_path):
raise RuntimeError("cannot find keys, you should run keygen before")
print(f"loading keys from '{concrete_keyset_path}'")
with open(concrete_keyset_path, "rb") as f:
eval_keys = fhe.EvaluationKeys.deserialize(f.read())
# read tfhers int from file
with open(rust_ct, "rb") as f:
buff = f.read()
# import fheuint8 and get its description
tfhers_uint8_x = tfhers_bridge.import_value(buff, input_idx=0)
print("Homomorphic evaluation...")
encrypted_result = circuit.server.run(tfhers_uint8_x, evaluation_keys=eval_keys)
print("exporting Rust ciphertexts")
# export fheuint8
buff = tfhers_bridge.export_value(encrypted_result, output_idx=0)
# write it to file
with open(output_rust_ct, "wb") as f:
f.write(buff)
if __name__ == "__main__":
cli()