Merge branch 'tmp/recheck_func' of https://github.com/ZKStats/zk-stats-lib into tmp/recheck_func

This commit is contained in:
JernKunpittaya
2024-05-10 18:33:50 +07:00

View File

@@ -2,13 +2,14 @@ import csv
from pathlib import Path
from typing import Type, Sequence, Mapping, Union, Literal, Callable
from enum import Enum
import torch
import ezkl
import os
import numpy as np
import json
import time
import torch
import ezkl
from zkstats.computation import IModel
@@ -277,11 +278,7 @@ def generate_data_commitment(data_path: str, scales: Sequence[int], data_commitm
Generate and store data commitment maps for different scales so that verifiers can verify
proofs with different scales.
:param data_path: path to the data file. The data file should be a JSON file with the following format:
{
"column_0": [number_0, number_1, ...],
"column_1": [number_0, number_1, ...],
}
:param data_path: data file path. The format must be anything defined in `DataExtension`
:param scales: a list of scales to use for the commitments
:param data_commitment_path: path to store the generated data commitment maps
"""
@@ -421,6 +418,50 @@ def _preprocess_data_file_to_json(data_path: Union[Path, str], out_data_json_pat
preprocess_function(data_path, out_data_json_path)
def _csv_file_to_json(old_file_path: Union[Path, str], out_data_json_path: Union[Path, str], *, delimiter: str = ",") -> None:
data_csv_path = Path(old_file_path)
with open(data_csv_path, 'r') as f_csv:
reader = csv.reader(f_csv, delimiter=delimiter, strict=True)
# Read all data from the reader to `rows`
rows_with_column_name = tuple(reader)
if len(rows_with_column_name) < 1:
raise ValueError("No column names in the CSV file")
if len(rows_with_column_name) < 2:
raise ValueError("No data in the CSV file")
column_names = rows_with_column_name[0]
rows = rows_with_column_name[1:]
columns = [
[
float(rows[j][i])
for j in range(len(rows))
]
for i in range(len(rows[0]))
]
data = {
column_name: column_data
for column_name, column_data in zip(column_names, columns)
}
with open(out_data_json_path, "w") as f_json:
json.dump(data, f_json)
class DataExtension(Enum):
CSV = ".csv"
JSON = ".json"
DATA_FORMAT_PREPROCESSING_FUNCTION: dict[DataExtension, Callable[[Union[Path, str], Path], None]] = {
DataExtension.CSV: _csv_file_to_json,
DataExtension.JSON: lambda old_file_path, out_data_json_path: Path(out_data_json_path).write_text(Path(old_file_path).read_text())
}
def _preprocess_data_file_to_json(data_path: Union[Path, str], out_data_json_path: Path):
data_file_extension = DataExtension(data_path.suffix)
preprocess_function = DATA_FORMAT_PREPROCESSING_FUNCTION[data_file_extension]
preprocess_function(data_path, out_data_json_path)
def _process_data(
data_path: Union[str| Path],
col_array: list[str],
@@ -428,11 +469,8 @@ def _process_data(
) -> list[torch.Tensor]:
data_tensor_array=[]
sel_data = []
data_path: Path = Path(data_path)
# Convert data file to json under the same directory but with suffix .json
data_json_path = Path(data_path).with_suffix(DataExtension.JSON.value)
_preprocess_data_file_to_json(data_path, data_json_path)
data_onefile = json.loads(open(data_json_path, "r").read())
data_onefile = json.loads(open(data_path, "r").read())
for col in col_array:
data = data_onefile[col]
data_tensor = torch.tensor(data, dtype = torch.float32)
@@ -451,4 +489,4 @@ def _get_commitment_for_column(column: list[float], scale: int) -> str:
res_poseidon_hash = ezkl.poseidon_hash(serialized_data)[0]
# res_hex = ezkl.vecu64_to_felt(res_poseidon_hash[0])
return res_poseidon_hash
return res_poseidon_hash