add simple tests for core functions which consumes csv files

This commit is contained in:
mhchia
2024-04-25 22:26:56 +08:00
parent 7704bf6e6e
commit ae6277a44a

View File

@@ -2,7 +2,7 @@ import json
import torch
from zkstats.core import generate_data_commitment, _preprocess_data_file_to_json
from zkstats.core import generate_data_commitment, prover_gen_settings, _preprocess_data_file_to_json, verifier_define_calculation
from zkstats.computation import computation_to_model
from .helpers import data_to_json_file, compute
@@ -77,6 +77,44 @@ def test_integration_select_partial_columns(tmp_path, column_0, column_1, error,
compute(tmp_path, [column_0, column_1], model, scales, selected_columns)
def test_csv_data(tmp_path, column_0, column_1, error, scales):
data_json_path = tmp_path / "data.csv"
data_csv_path = tmp_path / "data.csv"
data_json = data_to_json_file(data_json_path, [column_0, column_1])
json_file_to_csv(data_json_path, data_csv_path)
selected_columns = list(data_json.keys())
def simple_computation(state, x):
return state.mean(x[0])
sel_data_path = tmp_path / "comb_data.json"
model_path = tmp_path / "model.onnx"
settings_path = tmp_path / "settings.json"
data_commitment_path = tmp_path / "commitments.json"
# Test: `generate_data_commitment` works with csv
generate_data_commitment(data_csv_path, scales, data_commitment_path)
# Test: `prover_gen_settings` works with csv
_, model_for_proving = computation_to_model(simple_computation, error)
prover_gen_settings(
data_path=data_csv_path,
selected_columns=selected_columns,
sel_data_path=str(sel_data_path),
prover_model=model_for_proving,
prover_model_path=str(model_path),
scale=scales,
mode="resources",
settings_path=str(settings_path),
)
# Test: `prover_gen_settings` works with csv
# Instantiate the model for verification since the state of `model_for_proving` is changed after `prover_gen_settings`
_, model_for_verification = computation_to_model(simple_computation, error)
verifier_define_calculation(data_csv_path, selected_columns, str(sel_data_path), model_for_verification, str(model_path))
def json_file_to_csv(data_json_path, data_csv_path):
with open(data_json_path, "r") as f:
data_from_json = json.load(f)