Files
zk-stats-lib/tests/test_core.py
2024-02-03 14:33:17 +08:00

68 lines
2.7 KiB
Python

import torch
from zkstats.core import get_data_commitment_maps
from zkstats.computation import computation_to_model
from .helpers import data_to_file, compute
def test_get_data_commitment_maps(tmp_path, column_0, column_1, scales):
    """Check the shape of the commitment maps produced for a data file.

    Writes two columns to a JSON file, then asks for commitments at every
    scale in ``scales``. The result must have one entry per scale; each entry
    must cover exactly the columns written, and every commitment must parse
    as a hexadecimal number.
    """
    json_file = tmp_path / "data.json"

    # Column name -> column values, e.g.
    #   {"columns_0": [1, 2, 3, 4, 5], "columns_1": [6, 7, 8, 9, 10]}
    written_columns = data_to_file(json_file, [column_0, column_1])

    # Scale -> (column name -> commitment hex string), e.g.
    #   {scale_0: {"columns_0": "0x...", "columns_1": "0x..."},
    #    scale_1: {"columns_0": "0x...", "columns_1": "0x..."}}
    maps_by_scale = get_data_commitment_maps(json_file, scales)

    assert len(maps_by_scale) == len(scales)
    for scale_key, per_column in maps_by_scale.items():
        # Scale keys may come back stringified, so normalise before comparing.
        assert int(scale_key) in scales
        assert len(per_column) == len(written_columns)
        for name, digest in per_column.items():
            assert name in written_columns
            # Parsing with base 16 raises ValueError if the digest is not hex,
            # which fails the test.
            int(digest, 16)
def test_get_data_commitment_maps_hardcoded(tmp_path):
    """
    Regression guard: pin the exact commitments for a fixed dataset.

    If the data commitment scheme changes, these digests stop matching and
    the test fails.
    """
    # Scale (as a string key) -> column name -> commitment hex.
    expected_commitments = {
        "2": {
            'columns_0': '0x28b5eeb5aeee399c8c50c5b323def9a1aec1deee5b9ae193463d4f9b8893a9a3',
            'columns_1': '0x0523c85a86dddd810418e8376ce6d9d21b1b7363764c9c31b575b8ffbad82987',
        },
        "3": {
            'columns_0': '0x0a2906522d3f902ff4a63ee8aed4d2eaec0b14f71c51eb9557bd693a4e7d77ad',
            'columns_1': '0x2dac7fee1efb9eb955f52494a26a3fba6d1fa28cc819e598cb0af31a47b29d08',
        },
    }

    json_file = tmp_path / "data.json"
    first = torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5])
    second = torch.tensor([2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4])
    data_to_file(json_file, [first, second])

    actual_commitments = get_data_commitment_maps(json_file, [2, 3])
    assert actual_commitments == expected_commitments
def test_integration_select_partial_columns(tmp_path, column_0, column_1, error, scales):
    """Run the full pipeline when only a subset of the data columns is used.

    Two columns are written to the data file, but only the first is selected
    and fed to a mean computation.
    """
    json_file = tmp_path / "data.json"
    written_columns = data_to_file(json_file, [column_0, column_1])

    column_names = list(written_columns)
    assert len(column_names) == 2
    first_name, _unused_name = column_names

    def simple_computation(state, x):
        # x holds only the selected columns, so x[0] is the first one.
        return state.mean(x[0])

    _, model = computation_to_model(simple_computation, error)

    # Per the original comment, `compute` performs gen settings, setup,
    # prove and verify.
    compute(tmp_path, [column_0, column_1], model, scales, [first_name])