mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
refine after merge
@@ -87,8 +87,8 @@ def prover_gen_settings(
     """
     data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)
 
     # export onnx file
     _export_onnx(prover_model, data_tensor_array, prover_model_path)
 
     # gen + calibrate setting
     _gen_settings(sel_data_path, prover_model_path, scale, mode, settings_path)
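For context, a minimal sketch of the ONNX export step invoked above. The real `_export_onnx` in zkstats may differ; the function name, input names, and output name here are illustrative assumptions.

import torch

def export_onnx_sketch(model: torch.nn.Module, data_tensor_array: list[torch.Tensor], onnx_path: str) -> None:
    # Export the prover model with one input per selected column, mirroring
    # the tensors produced by _process_data. Names are illustrative only.
    model.eval()
    torch.onnx.export(
        model,
        tuple(data_tensor_array),
        onnx_path,
        export_params=True,
        input_names=[f"input_{i}" for i in range(len(data_tensor_array))],
        output_names=["output"],
    )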
@@ -282,6 +282,7 @@ def generate_data_commitment(data_path: str, scales: Sequence[int], data_commitment_path
     :param scales: a list of scales to use for the commitments
     :param data_commitment_path: path to store the generated data commitment maps
     """
 
     # Convert `data_path` to json file `data_json_path`
+    data_path: Path = Path(data_path)
     data_json_path = Path(data_path).with_suffix(DataExtension.JSON.value)
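The conversion produces a sibling `.json` path next to the input file via pathlib. A quick illustration with a hypothetical file name:

from pathlib import Path

data_path = Path("data/measurements.csv")          # hypothetical input
data_json_path = data_path.with_suffix(".json")    # DataExtension.JSON.value == ".json"
print(data_json_path)                              # data/measurements.json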
@@ -354,7 +355,7 @@ def _gen_settings(
     # Poseidon is not homomorphic additive, maybe consider Pedersens or Dory commitment.
     gip_run_args = ezkl.PyRunArgs()
     gip_run_args.input_visibility = "hashed"  # one commitment (values hashed) for each column
-    gip_run_args.param_visibility = "fixed"  # no parameters shown
+    gip_run_args.param_visibility = "private"  # no parameters shown
     gip_run_args.output_visibility = "public"  # should be `(torch.Tensor(1.0), output)`
 
     # generate settings
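A hedged sketch of how these visibility flags feed into settings generation with ezkl. Exact function signatures vary across ezkl releases, and the file paths are placeholders.

import ezkl

run_args = ezkl.PyRunArgs()
run_args.input_visibility = "hashed"    # inputs committed via Poseidon hash
run_args.param_visibility = "private"   # the value this hunk switches to
run_args.output_visibility = "public"   # result of the computation is revealed

# Generate circuit settings for the exported ONNX model.
ezkl.gen_settings("model.onnx", "settings.json", py_run_args=run_args)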
@@ -374,49 +375,6 @@ def _gen_settings(
     print("scale: ", scale)
     print("setting: ", f_setting.read())
 
-def _csv_file_to_json(old_file_path: Union[Path, str], out_data_json_path: Union[Path, str], *, delimiter: str = ",") -> None:
-    data_csv_path = Path(old_file_path)
-    with open(data_csv_path, 'r') as f_csv:
-        reader = csv.reader(f_csv, delimiter=delimiter, strict=True)
-        # Read all data from the reader to `rows`
-        rows_with_column_name = tuple(reader)
-        if len(rows_with_column_name) < 1:
-            raise ValueError("No column names in the CSV file")
-        if len(rows_with_column_name) < 2:
-            raise ValueError("No data in the CSV file")
-        column_names = rows_with_column_name[0]
-        rows = rows_with_column_name[1:]
-
-        columns = [
-            [
-                float(rows[j][i])
-                for j in range(len(rows))
-            ]
-            for i in range(len(rows[0]))
-        ]
-        data = {
-            column_name: column_data
-            for column_name, column_data in zip(column_names, columns)
-        }
-        with open(out_data_json_path, "w") as f_json:
-            json.dump(data, f_json)
-
-
-class DataExtension(Enum):
-    CSV = ".csv"
-    JSON = ".json"
-
-
-DATA_FORMAT_PREPROCESSING_FUNCTION: dict[DataExtension, Callable[[Union[Path, str], Path], None]] = {
-    DataExtension.CSV: _csv_file_to_json,
-    DataExtension.JSON: lambda old_file_path, out_data_json_path: Path(out_data_json_path).write_text(Path(old_file_path).read_text())
-}
-
-def _preprocess_data_file_to_json(data_path: Union[Path, str], out_data_json_path: Path):
-    data_file_extension = DataExtension(data_path.suffix)
-    preprocess_function = DATA_FORMAT_PREPROCESSING_FUNCTION[data_file_extension]
-    preprocess_function(data_path, out_data_json_path)
-
 
 def _csv_file_to_json(old_file_path: Union[Path, str], out_data_json_path: Union[Path, str], *, delimiter: str = ",") -> None:
     data_csv_path = Path(old_file_path)
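The deleted block is a duplicate of helpers that survive below it (a leftover from the merge this commit refines). The surviving `_csv_file_to_json` transposes CSV rows into column-major JSON; a small worked example with made-up data:

import csv, io, json

csv_text = "age,height\n30,170.5\n42,165.0\n"      # hypothetical data
rows = list(csv.reader(io.StringIO(csv_text)))
header, body = rows[0], rows[1:]
columns = {name: [float(row[i]) for row in body] for i, name in enumerate(header)}
print(json.dumps(columns))  # {"age": [30.0, 42.0], "height": [170.5, 165.0]}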
@@ -463,13 +421,17 @@ def _preprocess_data_file_to_json(data_path: Union[Path, str], out_data_json_path
 
 
 def _process_data(
-    data_path: Union[str| Path],
+    data_path: Union[str | Path],
     col_array: list[str],
     sel_data_path: list[str],
 ) -> list[torch.Tensor]:
     data_tensor_array=[]
     sel_data = []
-    data_onefile = json.loads(open(data_path, "r").read())
+    data_path: Path = Path(data_path)
+    # Convert data file to json under the same directory but with suffix .json
+    data_json_path = Path(data_path).with_suffix(DataExtension.JSON.value)
+    _preprocess_data_file_to_json(data_path, data_json_path)
+    data_onefile = json.loads(open(data_json_path, "r").read())
 
     for col in col_array:
         data = data_onefile[col]
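After the new preprocessing step, `_process_data` reads the column-major JSON and builds one tensor per selected column. A sketch of that loop; the helper name, dtype, and shape are assumptions based only on the declared `list[torch.Tensor]` return type.

import json
import torch

def columns_to_tensors(data_json_path: str, col_array: list[str]) -> list[torch.Tensor]:
    with open(data_json_path, "r") as f:
        data_onefile = json.load(f)
    # One float tensor per requested column; dtype/shape are illustrative.
    return [torch.tensor(data_onefile[col], dtype=torch.float32) for col in col_array]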
@@ -489,4 +451,4 @@ def _get_commitment_for_column(column: list[float], scale: int) -> str:
     res_poseidon_hash = ezkl.poseidon_hash(serialized_data)[0]
     # res_hex = ezkl.vecu64_to_felt(res_poseidon_hash[0])
 
     return res_poseidon_hash
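For the commitment above, column values are first mapped to fixed-point integers at the given scale before hashing. Only the scale arithmetic below is certain; how the integers are serialized into `serialized_data` for `ezkl.poseidon_hash` depends on the ezkl version and is left out.

def to_fixed_point(column: list[float], scale: int) -> list[int]:
    # scale=2 maps 1.25 -> round(1.25 * 2**2) = 5; these integers are then
    # serialized to field elements and fed to ezkl.poseidon_hash as above.
    return [round(v * (1 << scale)) for v in column]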