Mirror of https://github.com/MPCStats/zk-stats-lib.git, synced 2026-01-09 13:38:02 -05:00

Merge branch 'tmp/recheck_func' of https://github.com/ZKStats/zk-stats-lib into tmp/recheck_func
```diff
@@ -2,13 +2,14 @@ import csv
 from pathlib import Path
 from typing import Type, Sequence, Mapping, Union, Literal, Callable
+from enum import Enum
+import torch
+import ezkl
 import os
 import numpy as np
 import json
 import time

-import torch
-import ezkl

 from zkstats.computation import IModel


```
```diff
@@ -277,11 +278,7 @@ def generate_data_commitment(data_path: str, scales: Sequence[int], data_commitm
     Generate and store data commitment maps for different scales so that verifiers can verify
     proofs with different scales.

-    :param data_path: path to the data file. The data file should be a JSON file with the following format:
-    {
-        "column_0": [number_0, number_1, ...],
-        "column_1": [number_0, number_1, ...],
-    }
+    :param data_path: data file path. The format must be anything defined in `DataExtension`
     :param scales: a list of scales to use for the commitments
     :param data_commitment_path: path to store the generated data commitment maps
     """
```
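A minimal usage sketch of `generate_data_commitment` as the updated docstring describes it; the import path `zkstats.core` and the example file contents are assumptions, not taken from this diff:

```python
import json

# Assumed import path; only the function name and signature appear in the diff.
from zkstats.core import generate_data_commitment

# A column-oriented JSON data file, the format the old docstring spelled out.
with open("data.json", "w") as f:
    json.dump({
        "column_0": [1.0, 2.0, 3.0],
        "column_1": [4.0, 5.0, 6.0],
    }, f)

# Commit to the data at several fixed-point scales so a verifier can check
# proofs generated at any of them.
generate_data_commitment("data.json", scales=[2, 8], data_commitment_path="commitments.json")
```

With this change the same call should also accept a `.csv` file, since the accepted formats are now whatever `DataExtension` lists.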
```diff
@@ -421,6 +418,50 @@ def _preprocess_data_file_to_json(data_path: Union[Path, str], out_data_json_pat
     preprocess_function(data_path, out_data_json_path)


+def _csv_file_to_json(old_file_path: Union[Path, str], out_data_json_path: Union[Path, str], *, delimiter: str = ",") -> None:
+    data_csv_path = Path(old_file_path)
+    with open(data_csv_path, 'r') as f_csv:
+        reader = csv.reader(f_csv, delimiter=delimiter, strict=True)
+        # Read all data from the reader to `rows`
+        rows_with_column_name = tuple(reader)
+    if len(rows_with_column_name) < 1:
+        raise ValueError("No column names in the CSV file")
+    if len(rows_with_column_name) < 2:
+        raise ValueError("No data in the CSV file")
+    column_names = rows_with_column_name[0]
+    rows = rows_with_column_name[1:]
+
+    columns = [
+        [
+            float(rows[j][i])
+            for j in range(len(rows))
+        ]
+        for i in range(len(rows[0]))
+    ]
+    data = {
+        column_name: column_data
+        for column_name, column_data in zip(column_names, columns)
+    }
+    with open(out_data_json_path, "w") as f_json:
+        json.dump(data, f_json)
+
+
+class DataExtension(Enum):
+    CSV = ".csv"
+    JSON = ".json"
+
+
+DATA_FORMAT_PREPROCESSING_FUNCTION: dict[DataExtension, Callable[[Union[Path, str], Path], None]] = {
+    DataExtension.CSV: _csv_file_to_json,
+    DataExtension.JSON: lambda old_file_path, out_data_json_path: Path(out_data_json_path).write_text(Path(old_file_path).read_text())
+}
+
+
+def _preprocess_data_file_to_json(data_path: Union[Path, str], out_data_json_path: Path):
+    data_file_extension = DataExtension(data_path.suffix)
+    preprocess_function = DATA_FORMAT_PREPROCESSING_FUNCTION[data_file_extension]
+    preprocess_function(data_path, out_data_json_path)
+
+
 def _process_data(
     data_path: Union[str| Path],
     col_array: list[str],
```
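A small round-trip sketch of what the new `_csv_file_to_json` does: the first CSV row is taken as column names, every remaining cell is parsed as a float, and the output is the same column-oriented JSON mapping as before. The import path is an assumption, and calling the private helper directly is for illustration only:

```python
import json
from pathlib import Path

# Assumed import path; this private helper is defined in the diff above.
from zkstats.core import _preprocess_data_file_to_json

# A tiny CSV: a header row of column names, then numeric rows.
Path("data.csv").write_text("column_0,column_1\n1,4\n2,5\n3,6\n")

# Dispatches on the file suffix via DataExtension: .csv is converted,
# .json is copied through unchanged.
_preprocess_data_file_to_json(Path("data.csv"), Path("data.json"))

print(json.loads(Path("data.json").read_text()))
# {'column_0': [1.0, 2.0, 3.0], 'column_1': [4.0, 5.0, 6.0]}
```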
```diff
@@ -428,11 +469,8 @@ def _process_data(
 ) -> list[torch.Tensor]:
     data_tensor_array=[]
     sel_data = []
-    data_path: Path = Path(data_path)
-    # Convert data file to json under the same directory but with suffix .json
-    data_json_path = Path(data_path).with_suffix(DataExtension.JSON.value)
-    _preprocess_data_file_to_json(data_path, data_json_path)
-    data_onefile = json.loads(open(data_json_path, "r").read())
+    data_onefile = json.loads(open(data_path, "r").read())
+
     for col in col_array:
         data = data_onefile[col]
         data_tensor = torch.tensor(data, dtype = torch.float32)
```
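For reference, the loop this hunk leaves in place converts each selected column into a float32 tensor. A standalone sketch of that step, with illustrative data:

```python
import torch

# Stand-in for the parsed JSON data file.
data_onefile = {"column_0": [1.0, 2.0, 3.0], "column_1": [4.0, 5.0, 6.0]}
col_array = ["column_0", "column_1"]

# One float32 tensor per requested column, as in _process_data.
data_tensor_array = [
    torch.tensor(data_onefile[col], dtype=torch.float32) for col in col_array
]
print([t.tolist() for t in data_tensor_array])
# [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```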
```diff
@@ -451,4 +489,4 @@ def _get_commitment_for_column(column: list[float], scale: int) -> str:
     res_poseidon_hash = ezkl.poseidon_hash(serialized_data)[0]
     # res_hex = ezkl.vecu64_to_felt(res_poseidon_hash[0])

-    return res_poseidon_hash
+    return res_poseidon_hash
```
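And for completeness, how a single column commitment could be obtained through the helper this hunk touches; the import path is an assumption, and only the signature and the `ezkl.poseidon_hash` call are visible in the diff:

```python
# Assumed import path; _get_commitment_for_column is private to the library.
from zkstats.core import _get_commitment_for_column

# One Poseidon-hash commitment per (column, scale) pair: committing the same
# column at a different fixed-point scale yields a different string.
commitment = _get_commitment_for_column([1.0, 2.0, 3.0], scale=8)
print(commitment)
```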