zk-stats-lib/tests/test_core.py

import json
import torch
from zkstats.core import generate_data_commitment, prover_gen_settings, _preprocess_data_file_to_json, verifier_define_calculation
from zkstats.computation import computation_to_model
from .helpers import data_to_json_file, compute
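
# NOTE: `column_0`, `column_1`, `scales`, and `error` used below are pytest
# fixtures, assumed to be provided by this suite's conftest.py. A hypothetical
# sketch of their shape (values here are illustrative only):
#
#   @pytest.fixture
#   def column_0():
#       return torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5])
#
#   @pytest.fixture
#   def scales():
#       return [2, 3]
#
#   @pytest.fixture
#   def error():
#       return 0.01
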
def test_get_data_commitment_maps(tmp_path, column_0, column_1, scales):
data_path = tmp_path / "data.json"
data_commitment_path = tmp_path / "commitments.json"
# data_json is a mapping[column_name, column_data]
# {
# "columns_0": [1, 2, 3, 4, 5],
# "columns_1": [6, 7, 8, 9, 10],
# }
data_json = {"columns_0": column_0, "columns_1": column_1}
data_to_json_file(data_path, data_json)
# data_commitment is a mapping[scale -> mapping[column_name, commitment_hex]]
# {
# scale_0: {
# "columns_0": "0x...",
# "columns_1": "0x...",
# },
# scale_1: {
# "columns_0": "0x...",
# "columns_1": "0x...",
# }
# }
generate_data_commitment(data_path, scales, data_commitment_path)
with open(data_commitment_path, "r") as f:
data_commitment = json.load(f)
assert len(data_commitment) == len(scales)
for scale, commitment_map in data_commitment.items():
assert int(scale) in scales
assert len(commitment_map) == len(data_json)
for column_name, commitment_hex in commitment_map.items():
assert column_name in data_json
            # Raises ValueError if the commitment is not a valid hex string
            int(commitment_hex, 16)

def test_get_data_commitment_maps_hardcoded(tmp_path):
    """
    Regression test: ensure the data commitment scheme does not change.
    """
data_path = tmp_path / "data.json"
data_commitment_path = tmp_path / "commitments.json"
column_0 = torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5])
column_1 = torch.tensor([2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_to_json_file(data_path, data_json)
scales = [2, 3]
generate_data_commitment(data_path, scales, data_commitment_path)
with open(data_commitment_path, "r") as f:
data_commitment = json.load(f)
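    # Regression anchors: these digests were produced by an earlier run and pin
    # the commitment scheme. The commented-out mapping below holds the same
    # commitments in the previous "0x"-prefixed encoding; the current expected
    # values are its byte-reversed (opposite-endian) hex form.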
# expected = {"2": {'columns_0': '0x28b5eeb5aeee399c8c50c5b323def9a1aec1deee5b9ae193463d4f9b8893a9a3', 'columns_1': '0x0523c85a86dddd810418e8376ce6d9d21b1b7363764c9c31b575b8ffbad82987'}, "3": {'columns_0': '0x0a2906522d3f902ff4a63ee8aed4d2eaec0b14f71c51eb9557bd693a4e7d77ad', 'columns_1': '0x2dac7fee1efb9eb955f52494a26a3fba6d1fa28cc819e598cb0af31a47b29d08'}}
expected = {"2": {'columns_0': 'a3a993889b4f3d4693e19a5beedec1aea1f9de23b3c5508c9c39eeaeb5eeb528', 'columns_1': '8729d8baffb875b5319c4c7663731b1bd2d9e66c37e8180481dddd865ac82305'}, "3": {'columns_0': 'ad777d4e3a69bd5795eb511cf7140becead2d4aee83ea6f42f903f2d5206290a', 'columns_1': '089db2471af30acb98e519c88ca21f6dba3f6aa29424f555b99efb1eee7fac2d'}}
assert data_commitment == expected

def test_integration_select_partial_columns(tmp_path, column_0, column_1, error, scales):
data_path = tmp_path / "data.json"
data_json = {"columns_0": column_0, "columns_1": column_1}
data_shape = {"columns_0": len(column_0), "columns_1": len(column_1)}
data_to_json_file(data_path, data_json)
def simple_computation(state, args):
m_0 = state.mean(args["columns_0"])
m_1 = state.mean(args["columns_1"])
return m_0, m_1
precal_witness_path = tmp_path / "precal_witness_path.json"
selected_columns, _, model = computation_to_model(simple_computation, precal_witness_path, data_shape, True, error)
    # `compute` (test helper) runs the full pipeline: gen settings, setup, prove, verify
compute(tmp_path, data_json, model, scales, selected_columns)

def test_csv_data(tmp_path, column_0, column_1, error, scales):
data_json_path = tmp_path / "data.json"
data_csv_path = tmp_path / "data.csv"
data_json = {"columns_0": column_0, "columns_1": column_1}
data_shape = {"columns_0": len(column_0), "columns_1": len(column_1)}
data_to_json_file(data_json_path, data_json)
json_file_to_csv(data_json_path, data_csv_path)
def simple_computation(state, args):
return state.mean(args["columns_0"])
sel_data_path = tmp_path / "comb_data.json"
model_path = tmp_path / "model.onnx"
settings_path = tmp_path / "settings.json"
data_commitment_path = tmp_path / "commitments.json"
precal_witness_path = tmp_path / "precal_witness.json"
# Test: `generate_data_commitment` works with csv
generate_data_commitment(data_csv_path, scales, data_commitment_path)
# Test: `prover_gen_settings` works with csv
selected_columns, _, model_for_proving = computation_to_model(simple_computation, precal_witness_path, data_shape, True, error)
prover_gen_settings(
data_path=data_csv_path,
selected_columns=selected_columns,
sel_data_path=str(sel_data_path),
prover_model=model_for_proving,
prover_model_path=str(model_path),
scale=scales,
mode="resources",
settings_path=str(settings_path),
)
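    # At this point `sel_data_path` and `settings_path` should have been
    # written from the csv input (assumed behavior of `prover_gen_settings`).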
    # Test: `verifier_define_calculation` works with csv
# Instantiate the model for verification since the state of `model_for_proving` is changed after `prover_gen_settings`
selected_columns, _, model_for_verification = computation_to_model(simple_computation, precal_witness_path, data_shape, False, error)
verifier_define_calculation(data_csv_path, selected_columns, str(sel_data_path), model_for_verification, str(model_path))

def json_file_to_csv(data_json_path, data_csv_path):
with open(data_json_path, "r") as f:
data_from_json = json.load(f)
# Generate csv file from json
column_names = list(data_from_json.keys())
len_columns = len(data_from_json[column_names[0]])
for column in column_names:
assert len(data_from_json[column]) == len_columns, "All columns should have the same length"
rows = [
[str(data_from_json[column][i]) for column in column_names]
for i in range(len_columns)
]
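    # Values are plain numbers, so joining with "," is safe here; no csv-module
    # quoting or escaping is needed.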
with open(data_csv_path, "w") as f:
f.write(",".join(column_names) + "\n")
for row in rows:
f.write(",".join(row) + "\n")

def test__preprocess_data_file_to_json(tmp_path, column_0, column_1):
data_json_path = tmp_path / "data.json"
data_json = {"columns_0": column_0, "columns_1": column_1}
    # `data_to_json_file` is assumed to return the JSON-serializable dict it writes
    data_from_json = data_to_json_file(data_json_path, data_json)
# Test: csv can be converted to json
# 1. Generate a csv file from json
data_csv_path = tmp_path / "data.csv"
json_file_to_csv(data_json_path, data_csv_path)
# 2. Convert csv to json
data_from_csv_json_path = tmp_path / "data_from_csv.json"
_preprocess_data_file_to_json(data_csv_path, data_from_csv_json_path)
with open(data_from_csv_json_path, "r") as f:
data_from_csv = json.load(f)
# 3. Compare the two json files
assert data_from_csv == data_from_json
# Test: this function can also handle json format by just copying the file
new_data_json_path = tmp_path / "new_data.json"
_preprocess_data_file_to_json(data_json_path, new_data_json_path)
with open(new_data_json_path, "r") as f:
new_data_from_json = json.load(f)
assert new_data_from_json == data_from_json
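
# To run just this module (path assumed from the repository layout):
#   pytest tests/test_core.py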