import torch

from zkstats.core import get_data_commitment_maps
from zkstats.computation import computation_to_model

from .helpers import data_to_file, compute


def test_get_data_commitment_maps(tmp_path, column_0, column_1, scales):
    data_path = tmp_path / "data.json"
    # data_json is a mapping[column_name, column_data]
    # {
    #     "columns_0": [1, 2, 3, 4, 5],
    #     "columns_1": [6, 7, 8, 9, 10],
    # }
    data_json = data_to_file(data_path, [column_0, column_1])
    # commitment_maps is a mapping[scale -> mapping[column_name, commitment_hex]]
    # {
    #     scale_0: {
    #         "columns_0": "0x...",
    #         "columns_1": "0x...",
    #     },
    #     scale_1: {
    #         "columns_0": "0x...",
    #         "columns_1": "0x...",
    #     }
    # }
    commitment_maps = get_data_commitment_maps(data_path, scales)

    assert len(commitment_maps) == len(scales)
    for scale, commitment_map in commitment_maps.items():
        assert int(scale) in scales
        assert len(commitment_map) == len(data_json)
        for column_name, commitment_hex in commitment_map.items():
            assert column_name in data_json
            # Check if the commitment is a valid hex number
            int(commitment_hex, 16)
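

# Illustrative sketch, not part of the original test suite: given the nested
# structure documented above, a single commitment can be looked up by scale
# and column name. The str() conversion reflects the assumption, matching the
# hardcoded test below, that scale keys are stored as strings.
def _lookup_commitment(commitment_maps: dict, scale: int, column_name: str) -> str:
    """E.g. _lookup_commitment(commitment_maps, 2, "columns_0") -> "0x..."."""
    return commitment_maps[str(scale)][column_name]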


def test_get_data_commitment_maps_hardcoded(tmp_path):
    """
    Guard against unintended changes: the data commitment scheme must keep
    producing these hardcoded commitments for the same input data and scales.
    """
    data_path = tmp_path / "data.json"
    column_0 = torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5])
    column_1 = torch.tensor([2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4])
    data_to_file(data_path, [column_0, column_1])
    scales = [2, 3]
    commitment_maps = get_data_commitment_maps(data_path, scales)
    expected = {
        "2": {
            "columns_0": "0x28b5eeb5aeee399c8c50c5b323def9a1aec1deee5b9ae193463d4f9b8893a9a3",
            "columns_1": "0x0523c85a86dddd810418e8376ce6d9d21b1b7363764c9c31b575b8ffbad82987",
        },
        "3": {
            "columns_0": "0x0a2906522d3f902ff4a63ee8aed4d2eaec0b14f71c51eb9557bd693a4e7d77ad",
            "columns_1": "0x2dac7fee1efb9eb955f52494a26a3fba6d1fa28cc819e598cb0af31a47b29d08",
        },
    }
    assert commitment_maps == expected
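

# Background sketch (an assumption about the encoding, not something asserted
# by this test): a scale of n typically acts as a fixed-point exponent, i.e. a
# float v is quantized roughly as round(v * 2**n) before being committed,
# which is why scales 2 and 3 yield different commitments for identical data.
def _quantize_fixed_point(value: float, scale: int) -> int:
    """Illustrative only, e.g. _quantize_fixed_point(4.5, 2) == 18."""
    return round(value * 2**scale)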


def test_integration_select_partial_columns(tmp_path, column_0, column_1, error, scales):
    data_path = tmp_path / "data.json"
    data_json = data_to_file(data_path, [column_0, column_1])
    columns = list(data_json.keys())
    assert len(columns) == 2
    # Select only the first column from two columns
    selected_columns = [columns[0]]

    def simple_computation(state, x):
        return state.mean(x[0])

    _, model = computation_to_model(simple_computation, error)
    # gen settings, setup, prove, verify
    compute(tmp_path, [column_0, column_1], model, scales, selected_columns)
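

# Illustrative contrast, a sketch rather than part of the original tests: a
# computation that also read the second column would index x[1], in which
# case one would presumably need to include both columns in selected_columns
# so the proof commits to all of the data the computation touches.
def _computation_on_second_column(state, x):
    # Hypothetical example mirroring simple_computation above.
    return state.mean(x[1])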