Merge pull request #57 from ZKStats/auto_gen_selected_columns

Detect selected columns in `computation_to_model`
This commit is contained in:
Kevin Mai-Husan Chia
2024-09-13 16:18:15 +08:00
committed by GitHub
8 changed files with 1035 additions and 832 deletions

View File

@@ -30,22 +30,9 @@ poetry install
### Define Your Computation
User computation must be defined as **a function** using ZKStats operations and PyTorch functions. The function signature must be `Callable[[State, list[torch.Tensor]], torch.Tensor]`:
```python
import torch
from zkstats.computation import State
# User-defined computation
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
# Define your computation here
...
```
User computation must be defined as **a function** using ZKStats operations and PyTorch functions. The function signature must be `Callable[[State, Args], torch.Tensor]`:
- first argument is a `State` object, which contains the statistical functions that ZKStats supports.
- second argument is a list of PyTorch tensors, the input data. `data[0]` is the first column, `data[1]` is the second column, and so on.
- The second argument is an `Args` object, a dictionary-like mapping from column names to PyTorch tensors holding the input data. For example, `args['column1']` is the column named `column1`, `args['column2']` is the column named `column2`, and so on.
For example, we have two columns of data and we want to compute the mean of the medians of the two columns:
@@ -116,9 +103,9 @@ Note here, that we can also just let prover generate model, and then send that m
```python
from zkstats.core import computation_to_model
# For prover: generate prover_model, and write to precal_witness file
_, prover_model = computation_to_model(user_computation, precal_witness_path, True, selected_columns, error)
selected_columns, _, prover_model = computation_to_model(user_computation, precal_witness_path, data_shape, True, error)
# For verifier: generate verifier_model (which is the same as prover_model) by reading the precal_witness file
_, verifier_model = computation_to_model(user_computation, precal_witness_path, False, selected_columns, error)
selected_columns, _, verifier_model = computation_to_model(user_computation, precal_witness_path, data_shape, False, error)
```
#### Data Provider: generate settings

1551
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,8 @@ authors = ["Jern Kunpittaya", "Kevin Chia"]
[tool.poetry.dependencies]
python = "^3.9"
ezkl = "9.1.0"
torch = "^2.1.1"
# Pin torch to 2.2.0 because of an issue encountered when upgrading to 2.4.1
torch = "2.2.0"
requests = "^2.31.0"
scipy = "^1.11.4"
numpy = "^1.26.2"

View File

@@ -16,23 +16,21 @@ ERROR_CIRCUIT_STRICT = 0.0001
ERROR_CIRCUIT_RELAXED = 0.1
def data_to_json_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, list]:
column_names = [f"columns_{i}" for i in range(len(data))]
def data_to_json_file(data_path: Path, data: dict[str, torch.Tensor]) -> dict[str, list]:
column_to_data = {
column: d.tolist()
for column, d in zip(column_names, data)
for column, d in data.items()
}
print('columnnnn: ', column_to_data)
with open(data_path, "w") as f:
json.dump(column_to_data, f)
return column_to_data
def compute_model(
def compute(
basepath: Path,
data: list[torch.Tensor],
model: IModel,
data: dict[str, torch.Tensor],
model: Type[IModel],
# computation: TComputation,
scales_params: Optional[Sequence[int]] = None,
selected_columns_params: Optional[list[str]] = None,
):
@@ -47,10 +45,10 @@ def compute_model(
data_path = basepath / "data.json"
data_commitment_path = basepath / "commitments.json"
column_to_data = data_to_json_file(data_path, data)
data_to_json_file(data_path, data)
# If selected_columns_params is None, select all columns
if selected_columns_params is None:
selected_columns = list(column_to_data.keys())
selected_columns = list(data.keys())
else:
selected_columns = selected_columns_params
@@ -62,44 +60,17 @@ def compute_model(
else:
scales = scales_params
scales_for_commitments = scales_params
# create_dummy((data_path), (dummy_data_path))
generate_data_commitment((data_path), scales_for_commitments, (data_commitment_path))
# _, prover_model = computation_to_model(computation, (precal_witness_path), True, selected_columns, error)
prover_gen_settings((data_path), selected_columns, (sel_data_path), model, (model_path), scales, "resources", (settings_path))
# No need, since verifier & prover share the same onnx
# _, verifier_model = computation_to_model(computation, (precal_witness_path), False, selected_columns, error)
# verifier_define_calculation((dummy_data_path), selected_columns, (sel_dummy_data_path),verifier_model, (verifier_model_path))
setup((model_path), (compiled_model_path), (settings_path),(vk_path), (pk_path ))
prover_gen_proof((model_path), (sel_data_path), (witness_path), (compiled_model_path), (settings_path), (proof_path), (pk_path))
# print('slett col: ', selected_columns)
verifier_verify((proof_path), (settings_path), (vk_path), selected_columns, (data_commitment_path))
def compute(
basepath: Path,
data: list[torch.Tensor],
computation: TComputation,
scales_params: Optional[Sequence[int]] = None,
selected_columns_params: Optional[list[str]] = None,
) -> State:
data_path = basepath / "data.json"
precal_witness_path = basepath / "precal_witness_path.json"
column_to_data = data_to_json_file(data_path, data)
# If selected_columns_params is None, select all columns
if selected_columns_params is None:
selected_columns = list(column_to_data.keys())
else:
selected_columns = selected_columns_params
state, model = computation_to_model(computation, precal_witness_path, True, selected_columns)
compute_model(basepath, data, model, scales_params, selected_columns_params)
return state
# Error tolerance between zkstats python implementation and python statistics module

View File

@@ -4,7 +4,7 @@ import torch
import pytest
from zkstats.computation import State, Args, computation_to_model
from zkstats.computation import State, computation_to_model, analyze_computation, TComputation, Args
from zkstats.ops import (
Mean,
Median,
@@ -25,9 +25,9 @@ from .helpers import assert_result, compute, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUI
def nested_computation(state: State, args: Args):
x = args['columns_0']
y = args['columns_1']
z = args['columns_2']
x = args["x"]
y = args["y"]
z = args["z"]
out_0 = state.median(x)
out_1 = state.geometric_mean(y)
out_2 = state.harmonic_mean(x)
@@ -63,8 +63,14 @@ def nested_computation(state: State, args: Args):
[ERROR_CIRCUIT_DEFAULT],
)
def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, column_2: torch.Tensor, error, scales):
precal_witness_path = tmp_path / "precal_witness_path.json"
x, y, z = column_0, column_1, column_2
state = compute(tmp_path, [x, y, z], nested_computation, scales)
data_shape = {"x": len(x), "y": len(y), "z": len(z)}
data = {"x": x, "y": y, "z": z}
selected_columns, state, model = computation_to_model(nested_computation, precal_witness_path, data_shape, True, error)
compute(tmp_path, data, model, scales, selected_columns)
# There are 11 ops in the computation
assert state.current_op_index == 12
ops = state.ops
@@ -152,10 +158,14 @@ def test_computation_with_where_1d(tmp_path, error, column_0, op_type: Callable[
def condition(_x: torch.Tensor):
return _x < 4
def where_and_op(state: State, args: Args):
x = args['columns_0']
column_name = "x"
def where_and_op(state, args):
x = args[column_name]
return op_type(state, state.where(condition(x), x))
state = compute(tmp_path, [column], where_and_op, scales)
precal_witness_path = tmp_path / "precal_witness_path.json"
_, state, model = computation_to_model(where_and_op, precal_witness_path, {column_name: column.shape}, True, error)
compute(tmp_path, {column_name: column}, model, scales)
res_op = state.ops[-1]
filtered = column[condition(column)]
@@ -174,14 +184,18 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
def condition_0(_x: torch.Tensor):
return _x > 4
def where_and_op(state: State, args: Args):
x = args['columns_0']
y = args['columns_1']
def where_and_op(state: State, args: list[torch.Tensor]):
x = args["x"]
y = args["y"]
condition_x = condition_0(x)
filtered_x = state.where(condition_x, x)
filtered_y = state.where(condition_x, y)
return op_type(state, filtered_x, filtered_y)
state = compute(tmp_path, [column_0, column_1], where_and_op, scales)
precal_witness_path = tmp_path / "precal_witness_path.json"
data_shape = {"x": len(column_0), "y": len(column_1)}
data = {"x": column_0, "y": column_1}
selected_columns, state, model = computation_to_model(where_and_op, precal_witness_path, data_shape, True ,error)
compute(tmp_path, data, model, scales, selected_columns)
res_op = state.ops[-1]
condition_x = condition_0(column_0)
@@ -189,3 +203,68 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
filtered_y = column_1[condition_x]
expected_res = expected_func(filtered_x.tolist(), filtered_y.tolist())
assert_result(res_op.result.data, expected_res)
def test_analyze_computation_success():
def valid_computation(state, args):
x = args["column1"]
y = args["column2"]
return state.mean(x) + state.median(y)
result = analyze_computation(valid_computation)
assert set(result) == {"column1", "column2"}
def test_analyze_computation_no_columns():
def no_columns_computation(state, args):
return state.mean(state.median([1, 2, 3]))
result = analyze_computation(no_columns_computation)
assert result == []
def test_analyze_computation_multiple_uses():
def multiple_uses_computation(state, args):
x = args["column1"]
y = args["column2"]
z = args["column1"] # Using column1 twice
return state.mean(x) + state.median(y) + state.sum(z)
result = analyze_computation(multiple_uses_computation)
assert set(result) == {"column1", "column2"}
def test_analyze_computation_nested_args():
def nested_args_computation(state, args):
x = args["column1"]["nested"]
y = args["column2"]
return state.mean(x) + state.median(y)
result = analyze_computation(nested_args_computation)
assert set(result) == {"column1", "column2"}
def test_analyze_computation_invalid_params():
def invalid_params_computation(invalid1, invalid2):
return invalid1.mean(invalid2["column"])
with pytest.raises(ValueError, match="The computation function must have two parameters named 'state' and 'args'"):
analyze_computation(invalid_params_computation)
def test_analyze_computation_wrong_param_names():
def wrong_param_names(state, wrong_name):
return state.mean(wrong_name["column"])
with pytest.raises(ValueError, match="The computation function must have two parameters named 'state' and 'args'"):
analyze_computation(wrong_param_names)
def test_analyze_computation_dynamic_column_access():
def dynamic_column_access(state, args):
columns = ["column1", "column2"]
return sum(state.mean(args[col]) for col in columns)
# This won't catch dynamically accessed columns
result = analyze_computation(dynamic_column_access)
assert result == []
def test_analyze_computation_lambda():
lambda_computation = lambda state, args: state.mean(args["column"])
with pytest.raises(ValueError, match="Lambda functions are not supported in analyze_computation"):
analyze_computation(lambda_computation)

View File

@@ -16,7 +16,8 @@ def test_get_data_commitment_maps(tmp_path, column_0, column_1, scales):
# "columns_0": [1, 2, 3, 4, 5],
# "columns_1": [6, 7, 8, 9, 10],
# }
data_json = data_to_json_file(data_path, [column_0, column_1])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_to_json_file(data_path, data_json)
# data_commitment is a mapping[scale -> mapping[column_name, commitment_hex]]
# {
# scale_0: {
@@ -51,7 +52,8 @@ def test_get_data_commitment_maps_hardcoded(tmp_path):
data_commitment_path = tmp_path / "commitments.json"
column_0 = torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5])
column_1 = torch.tensor([2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4])
data_to_json_file(data_path, [column_0, column_1])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_to_json_file(data_path, data_json)
scales = [2, 3]
generate_data_commitment(data_path, scales, data_commitment_path)
with open(data_commitment_path, "r") as f:
@@ -63,30 +65,28 @@ def test_get_data_commitment_maps_hardcoded(tmp_path):
def test_integration_select_partial_columns(tmp_path, column_0, column_1, error, scales):
data_path = tmp_path / "data.json"
data_json = data_to_json_file(data_path, [column_0, column_1])
columns = list(data_json.keys())
assert len(columns) == 2
# Select only the first column from two columns
selected_columns = [columns[0]]
data_json = {"columns_0": column_0, "columns_1": column_1}
data_shape = {"columns_0": len(column_0), "columns_1": len(column_1)}
data_to_json_file(data_path, data_json)
def simple_computation(state, args):
x = args['columns_0']
return state.mean(x)
return state.mean(args["columns_0"])
precal_witness_path = tmp_path / "precal_witness_path.json"
selected_columns, _, model = computation_to_model(simple_computation, precal_witness_path, data_shape, True, error)
# gen settings, setup, prove, verify
compute(tmp_path, [column_0, column_1], simple_computation, scales, selected_columns)
compute(tmp_path, data_json, model, scales, selected_columns)
def test_csv_data(tmp_path, column_0, column_1, error, scales):
data_json_path = tmp_path / "data.json"
data_csv_path = tmp_path / "data.csv"
data_json = data_to_json_file(data_json_path, [column_0, column_1])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_shape = {"columns_0": len(column_0), "columns_1": len(column_1)}
data_to_json_file(data_json_path, data_json)
json_file_to_csv(data_json_path, data_csv_path)
selected_columns = list(data_json.keys())
def simple_computation(state, args):
x = args['columns_0']
return state.mean(x)
return state.mean(args["columns_0"])
sel_data_path = tmp_path / "comb_data.json"
model_path = tmp_path / "model.onnx"
@@ -98,7 +98,7 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
generate_data_commitment(data_csv_path, scales, data_commitment_path)
# Test: `prover_gen_settings` works with csv
_, model_for_proving = computation_to_model(simple_computation, precal_witness_path, True, selected_columns, error)
selected_columns, _, model_for_proving = computation_to_model(simple_computation, precal_witness_path, data_shape, True, error)
prover_gen_settings(
data_path=data_csv_path,
selected_columns=selected_columns,
@@ -112,7 +112,7 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
# Test: `prover_gen_settings` works with csv
# Instantiate the model for verification since the state of `model_for_proving` is changed after `prover_gen_settings`
_, model_for_verification = computation_to_model(simple_computation, precal_witness_path, False, selected_columns, error)
selected_columns, _, model_for_verification = computation_to_model(simple_computation, precal_witness_path, data_shape, False, error)
verifier_define_calculation(data_csv_path, selected_columns, str(sel_data_path), model_for_verification, str(model_path))
def json_file_to_csv(data_json_path, data_csv_path):
@@ -135,7 +135,8 @@ def json_file_to_csv(data_json_path, data_csv_path):
def test__preprocess_data_file_to_json(tmp_path, column_0, column_1):
data_json_path = tmp_path / "data.json"
data_from_json = data_to_json_file(data_json_path, [column_0, column_1])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_from_json = data_to_json_file(data_json_path, data_json)
# Test: csv can be converted to json
# 1. Generate a csv file from json

View File

@@ -5,9 +5,9 @@ import pytest
import torch
from zkstats.ops import Mean, Median, GeometricMean, HarmonicMean, Mode, PStdev, PVariance, Stdev, Variance, Covariance, Correlation, Operation, Regression
from zkstats.computation import IModel, IsResultPrecise, State, computation_to_model
from zkstats.computation import IModel, IsResultPrecise
from .helpers import compute_model, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
@pytest.mark.parametrize(
@@ -49,8 +49,8 @@ def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tens
)
def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, scales: list[float]):
expected_res = statistics.linear_regression(column_0.tolist(), column_1.tolist())
columns = [column_0, column_1]
regression = Regression.create(columns, error)
columns = {"columns_0": column_0, "columns_1": column_1}
regression = Regression.create(list(columns.values()), error)
# shape = [2, 1]
actual_res = regression.result
assert_result(expected_res.slope, actual_res[0][0])
@@ -58,8 +58,7 @@ def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Ten
class Model(IModel):
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return regression.ezkl(x), regression.result
compute_model(tmp_path, columns, Model, scales)
compute(tmp_path, columns, Model, scales)
def run_test_ops(tmp_path, op_type: Type[Operation], expected_func: Callable[[list[float]], float], error: float, scales: list[float], columns: list[torch.Tensor]):
@@ -70,4 +69,5 @@ def run_test_ops(tmp_path, op_type: Type[Operation], expected_func: Callable[[li
class Model(IModel):
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return op.ezkl(x), op.result
compute_model(tmp_path, columns, Model, scales)
data = {f"columns_{i}": column for i, column in enumerate(columns)}
compute(tmp_path, data, Model, scales)

View File

@@ -264,30 +264,81 @@ class IModel(nn.Module):
# A computation function. Example:
# def computation(state: State, x: list[torch.Tensor]):
# out_0 = state.median(x[0])
# out_1 = state.median(x[1])
# def computation(state: State, args: dict[str, torch.Tensor]):
# out_0 = state.median(args["x"])
# out_1 = state.median(args["y"])
# return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
TComputation = Callable[[State, dict[str, torch.Tensor]], torch.Tensor]
import ast
import inspect
def analyze_computation(computation: "TComputation") -> list:
    """
    Statically detect which columns a computation reads from ``args``.

    The source of ``computation`` is parsed with ``ast`` and every subscript of
    the form ``args[<constant>]`` is collected. Only constant keys are
    detected: columns accessed through a variable (e.g. ``args[col]``) are not
    reported.

    :param computation: A regular (``def``) function whose two parameters are
        named ``state`` and ``args``.
    :return: The constant keys used to index ``args``, without duplicates, in
        the order they are first encountered while traversing the function.
    :raises ValueError: If ``computation`` is a lambda, or if the function does
        not have exactly the two parameters ``state`` and ``args``.
    """
    # Local import: `textwrap` is not among the module-level imports.
    import textwrap

    # Detect lambdas from the function object itself. Inspecting the source
    # text is unreliable: for `f = lambda state, args: ...` the source starts
    # with the variable name, not with the `lambda` keyword.
    if getattr(computation, "__name__", None) == "<lambda>":
        raise ValueError("Lambda functions are not supported in analyze_computation. Please use a regular function definition instead.")

    # Functions defined inside another scope are indented as a whole, which
    # `ast.parse` rejects, so strip the common leading whitespace first.
    source = inspect.getsource(computation)
    tree = ast.parse(textwrap.dedent(source))

    # Validate the signature of the computation itself only. Nested helper
    # functions must neither satisfy nor break this check.
    function_node = next(
        (node for node in tree.body if isinstance(node, ast.FunctionDef)),
        None,
    )
    param_names = (
        [param.arg for param in function_node.args.args]
        if function_node is not None
        else None
    )
    if param_names != ["state", "args"]:
        raise ValueError("The computation function must have two parameters named 'state' and 'args'")

    # A dict is used as an ordered set so the result is deterministic.
    column_names = {}

    class ColumnCollector(ast.NodeVisitor):
        def visit_Subscript(self, node):
            target = node.value
            if (
                isinstance(target, ast.Name)
                and target.id == "args"
                and isinstance(node.slice, ast.Constant)
            ):
                column_names.setdefault(node.slice.value, None)
            # Keep descending so that `args["a"]["b"]` still reports "a".
            self.generic_visit(node)

    ColumnCollector().visit(function_node)
    return list(column_names)
class Args:
def __init__(
self,
columns: list[str],
data_shape: dict[str, int],
data: list[torch.Tensor],
):
if len(columns) != len(data):
raise ValueError("columns and data must have the same length")
column_names = list(data_shape.keys())[:len(data)]
self.data_dict = {
column_name: d
for column_name, d in zip(columns, data)
for column_name, d in zip(column_names, data)
}
def __getitem__(self, key: str) -> torch.Tensor:
return self.data_dict[key]
def computation_to_model(computation: TComputation, precal_witness_path: str, isProver:bool, selected_columns: list[str], error: float = DEFAULT_ERROR ) -> tuple[State, Type[IModel]]:
def computation_to_model(
computation: TComputation,
precal_witness_path:str,
data_shape: dict[str, int],
isProver: bool,
error: float = DEFAULT_ERROR,
) -> tuple[list[str], State, Type[IModel]]:
"""
Create a torch model from a `computation` function defined by user
:param computation: A function that takes a State and an Args (a mapping from column names to torch.Tensor), and returns a torch.Tensor
@@ -295,6 +346,16 @@ def computation_to_model(computation: TComputation, precal_witness_path: str, is
:return: A tuple of the selected column names (ordered as in `data_shape`), State, and Model. The Model is a torch model that can be used for exporting to onnx.
State is a container for intermediate results of computation, which can be useful when debugging.
"""
selected_columns_unordered = analyze_computation(computation)
# Preserve the order from data_shape
selected_columns = [col for col in data_shape.keys() if col in selected_columns_unordered]
assert len(selected_columns) == len(selected_columns_unordered), "Selected columns must match"
invalid_columns = set(selected_columns) - set(data_shape.keys())
if invalid_columns:
raise ValueError(f"Computation uses columns not present in data: {invalid_columns}")
state = State(error)
state.precal_witness_path = precal_witness_path
@@ -302,25 +363,17 @@ def computation_to_model(computation: TComputation, precal_witness_path: str, is
class Model(IModel):
def preprocess(self, x: list[torch.Tensor]) -> None:
"""
Calculate the witnesses of the computation and store them in the state.
"""
# In the preprocess step, the operations are calculated and the results are stored in the state.
# So we don't need to get the returned result
args = Args(selected_columns, x)
args = Args(data_shape, x)
computation(state, args)
state.set_ready_for_exporting_onnx()
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
"""
Called by torch.onnx.export.
"""
args = Args(selected_columns, x)
args = Args(data_shape, x)
result = computation(state, args)
is_computation_result_accurate = state.bools[0]()
for op_precise_check in state.bools[1:]:
is_op_result_accurate = op_precise_check()
is_computation_result_accurate = torch.logical_and(is_computation_result_accurate, is_op_result_accurate)
return is_computation_result_accurate, result
return state, Model
return selected_columns, state, Model