Mirror of https://github.com/MPCStats/zk-stats-lib.git (synced 2026-01-09 13:38:02 -05:00)
Merge pull request #57 from ZKStats/auto_gen_selected_columns
Detect selected columns in `computation_to_model`
README.md (21 lines changed)
@@ -30,22 +30,9 @@ poetry install

### Define Your Computation

User computation must be defined as **a function** using ZKStats operations and PyTorch functions. The function signature must be `Callable[[State, list[torch.Tensor]], torch.Tensor]`:

```python
import torch

from zkstats.computation import State

# User-defined computation
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
    # Define your computation here
    ...

```

User computation must be defined as **a function** using ZKStats operations and PyTorch functions. The function signature must be `Callable[[State, Args], torch.Tensor]`:
- first argument is a `State` object, which contains the statistical functions that ZKStats supports.
- second argument is a list of PyTorch tensors, the input data. `data[0]` is the first column, `data[1]` is the second column, and so on.
- second argument is an `Args` object, which is a dictionary of PyTorch tensors, the input data. `Args['column1']` is the first column, `Args['column2']` is the second column, and so on.

For example, we have two columns of data and we want to compute the mean of the medians of the two columns:
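The full README example sits outside this hunk; as a rough sketch of what such a computation looks like with the new dict-style `args` (column names `x` and `y` are assumed here, not taken from the diff):

```python
import torch

from zkstats.computation import State, Args


def user_computation(state: State, args: Args) -> torch.Tensor:
    # Median of each assumed column, then the mean of the two medians
    out_0 = state.median(args["x"])
    out_1 = state.median(args["y"])
    return state.mean(torch.tensor([out_0, out_1]).reshape(1, -1, 1))
```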
@@ -116,9 +103,9 @@ Note here, that we can also just let prover generate model, and then send that m
```python
from zkstats.core import computation_to_model
# For prover: generate prover_model, and write to precal_witness file
_, prover_model = computation_to_model(user_computation, precal_witness_path, True, selected_columns, error)
selected_columns, _, prover_model = computation_to_model(user_computation, precal_witness_path, data_shape, True, error)
# For verifier, generate verifier model (which is same as prover_model) by reading precal_witness file
_, verifier_model = computation_to_model(user_computation, precal_witness_path, False, selected_columns, error)
selected_columns, _, verifier_model = computation_to_model(user_computation, precal_witness_path, data_shape, False, error)
```
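With this change, `computation_to_model` takes a `data_shape` mapping (column name to length) and infers the selected columns itself by statically analyzing which `args["..."]` keys the computation touches, returning them as the first element of the tuple. A rough sketch of that behaviour, based on the tests further below (column names here are placeholders):

```python
def my_computation(state, args):
    return state.mean(args["columns_0"])  # only this key appears as a string literal

data_shape = {"columns_0": 10, "columns_1": 10}
selected_columns, state, model = computation_to_model(
    my_computation, precal_witness_path, data_shape, True, error
)
# selected_columns == ["columns_0"]; keys built dynamically (e.g. args[col] in a loop)
# are not detected by the static analysis.
```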
#### Data Provider: generate settings
poetry.lock (generated, 1551 lines changed): diff suppressed because it is too large.
@@ -7,7 +7,8 @@ authors = ["Jern Kunpittaya", "Kevin Chia"]
[tool.poetry.dependencies]
python = "^3.9"
ezkl = "9.1.0"
torch = "^2.1.1"
# fix torch version to 2.2.0 due to a weird issue when upgrading to 2.4.1
torch = "2.2.0"
requests = "^2.31.0"
scipy = "^1.11.4"
numpy = "^1.26.2"

@@ -16,23 +16,21 @@ ERROR_CIRCUIT_STRICT = 0.0001
ERROR_CIRCUIT_RELAXED = 0.1


def data_to_json_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, list]:
    column_names = [f"columns_{i}" for i in range(len(data))]
def data_to_json_file(data_path: Path, data: dict[str, torch.Tensor]) -> dict[str, list]:
    column_to_data = {
        column: d.tolist()
        for column, d in zip(column_names, data)
        for column, d in data.items()
    }
    print('columnnnn: ', column_to_data)
    with open(data_path, "w") as f:
        json.dump(column_to_data, f)
    return column_to_data


def compute_model(
def compute(
    basepath: Path,
    data: list[torch.Tensor],
    model: IModel,
    data: dict[str, torch.Tensor],
    model: Type[IModel],
    # computation: TComputation,
    scales_params: Optional[Sequence[int]] = None,
    selected_columns_params: Optional[list[str]] = None,
):
@@ -47,10 +45,10 @@ def compute_model(
    data_path = basepath / "data.json"
    data_commitment_path = basepath / "commitments.json"

    column_to_data = data_to_json_file(data_path, data)
    data_to_json_file(data_path, data)
    # If selected_columns_params is None, select all columns
    if selected_columns_params is None:
        selected_columns = list(column_to_data.keys())
        selected_columns = list(data.keys())
    else:
        selected_columns = selected_columns_params

@@ -62,44 +60,17 @@ def compute_model(
    else:
        scales = scales_params
    scales_for_commitments = scales_params
    # create_dummy((data_path), (dummy_data_path))
    generate_data_commitment((data_path), scales_for_commitments, (data_commitment_path))
    # _, prover_model = computation_to_model(computation, (precal_witness_path), True, selected_columns, error)

    prover_gen_settings((data_path), selected_columns, (sel_data_path), model, (model_path), scales, "resources", (settings_path))

    # No need, since verifier & prover share the same onnx
    # _, verifier_model = computation_to_model(computation, (precal_witness_path), False, selected_columns, error)
    # verifier_define_calculation((dummy_data_path), selected_columns, (sel_dummy_data_path),verifier_model, (verifier_model_path))

    setup((model_path), (compiled_model_path), (settings_path),(vk_path), (pk_path ))

    prover_gen_proof((model_path), (sel_data_path), (witness_path), (compiled_model_path), (settings_path), (proof_path), (pk_path))
    # print('slett col: ', selected_columns)
    verifier_verify((proof_path), (settings_path), (vk_path), selected_columns, (data_commitment_path))


def compute(
    basepath: Path,
    data: list[torch.Tensor],
    computation: TComputation,
    scales_params: Optional[Sequence[int]] = None,
    selected_columns_params: Optional[list[str]] = None,
) -> State:
    data_path = basepath / "data.json"
    precal_witness_path = basepath / "precal_witness_path.json"

    column_to_data = data_to_json_file(data_path, data)
    # If selected_columns_params is None, select all columns
    if selected_columns_params is None:
        selected_columns = list(column_to_data.keys())
    else:
        selected_columns = selected_columns_params

    state, model = computation_to_model(computation, precal_witness_path, True, selected_columns)
    compute_model(basepath, data, model, scales_params, selected_columns_params)
    return state


# Error tolerance between zkstats python implementation and python statistics module

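For orientation only (the values, paths and fixtures below are placeholders, not part of the diff): with the reworked helper, callers now build the model first and pass the returned `selected_columns` together with a dict of named columns, roughly like this:

```python
import torch

# Hypothetical inputs for illustration
data = {"columns_0": torch.tensor([1.0, 2.0, 3.0]),
        "columns_1": torch.tensor([4.0, 5.0, 6.0])}
data_shape = {name: len(t) for name, t in data.items()}

def simple_computation(state, args):
    return state.mean(args["columns_0"])

selected_columns, state, model = computation_to_model(
    simple_computation, tmp_path / "precal_witness_path.json", data_shape, True, error
)
compute(tmp_path, data, model, scales, selected_columns)
```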
@@ -4,7 +4,7 @@ import torch

import pytest

from zkstats.computation import State, Args, computation_to_model
from zkstats.computation import State, computation_to_model, analyze_computation, TComputation, Args
from zkstats.ops import (
    Mean,
    Median,
@@ -25,9 +25,9 @@ from .helpers import assert_result, compute, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUI


def nested_computation(state: State, args: Args):
    x = args['columns_0']
    y = args['columns_1']
    z = args['columns_2']
    x = args["x"]
    y = args["y"]
    z = args["z"]
    out_0 = state.median(x)
    out_1 = state.geometric_mean(y)
    out_2 = state.harmonic_mean(x)
@@ -63,8 +63,14 @@ def nested_computation(state: State, args: Args):
    [ERROR_CIRCUIT_DEFAULT],
)
def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, column_2: torch.Tensor, error, scales):
    precal_witness_path = tmp_path / "precal_witness_path.json"
    x, y, z = column_0, column_1, column_2
    state = compute(tmp_path, [x, y, z], nested_computation, scales)
    data_shape = {"x": len(x), "y": len(y), "z": len(z)}
    data = {"x": x, "y": y, "z": z}
    selected_columns, state, model = computation_to_model(nested_computation, precal_witness_path, data_shape, True, error)
    compute(tmp_path, data, model, scales, selected_columns)
    # There are 11 ops in the computation

    assert state.current_op_index == 12

    ops = state.ops
@@ -152,10 +158,14 @@ def test_computation_with_where_1d(tmp_path, error, column_0, op_type: Callable[
    def condition(_x: torch.Tensor):
        return _x < 4

    def where_and_op(state: State, args: Args):
        x = args['columns_0']
    column_name = "x"

    def where_and_op(state, args):
        x = args[column_name]
        return op_type(state, state.where(condition(x), x))
    state = compute(tmp_path, [column], where_and_op, scales)
    precal_witness_path = tmp_path / "precal_witness_path.json"
    _, state, model = computation_to_model(where_and_op, precal_witness_path, {column_name: column.shape}, True, error)
    compute(tmp_path, {column_name: column}, model, scales)

    res_op = state.ops[-1]
    filtered = column[condition(column)]
@@ -174,14 +184,18 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
    def condition_0(_x: torch.Tensor):
        return _x > 4

    def where_and_op(state: State, args: Args):
        x = args['columns_0']
        y = args['columns_1']
    def where_and_op(state: State, args: list[torch.Tensor]):
        x = args["x"]
        y = args["y"]
        condition_x = condition_0(x)
        filtered_x = state.where(condition_x, x)
        filtered_y = state.where(condition_x, y)
        return op_type(state, filtered_x, filtered_y)
    state = compute(tmp_path, [column_0, column_1], where_and_op, scales)
    precal_witness_path = tmp_path / "precal_witness_path.json"
    data_shape = {"x": len(column_0), "y": len(column_1)}
    data = {"x": column_0, "y": column_1}
    selected_columns, state, model = computation_to_model(where_and_op, precal_witness_path, data_shape, True ,error)
    compute(tmp_path, data, model, scales, selected_columns)

    res_op = state.ops[-1]
    condition_x = condition_0(column_0)
@@ -189,3 +203,68 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
    filtered_y = column_1[condition_x]
    expected_res = expected_func(filtered_x.tolist(), filtered_y.tolist())
    assert_result(res_op.result.data, expected_res)


def test_analyze_computation_success():
    def valid_computation(state, args):
        x = args["column1"]
        y = args["column2"]
        return state.mean(x) + state.median(y)

    result = analyze_computation(valid_computation)
    assert set(result) == {"column1", "column2"}

def test_analyze_computation_no_columns():
    def no_columns_computation(state, args):
        return state.mean(state.median([1, 2, 3]))

    result = analyze_computation(no_columns_computation)
    assert result == []

def test_analyze_computation_multiple_uses():
    def multiple_uses_computation(state, args):
        x = args["column1"]
        y = args["column2"]
        z = args["column1"] # Using column1 twice
        return state.mean(x) + state.median(y) + state.sum(z)

    result = analyze_computation(multiple_uses_computation)
    assert set(result) == {"column1", "column2"}

def test_analyze_computation_nested_args():
    def nested_args_computation(state, args):
        x = args["column1"]["nested"]
        y = args["column2"]
        return state.mean(x) + state.median(y)

    result = analyze_computation(nested_args_computation)
    assert set(result) == {"column1", "column2"}

def test_analyze_computation_invalid_params():
    def invalid_params_computation(invalid1, invalid2):
        return invalid1.mean(invalid2["column"])

    with pytest.raises(ValueError, match="The computation function must have two parameters named 'state' and 'args'"):
        analyze_computation(invalid_params_computation)

def test_analyze_computation_wrong_param_names():
    def wrong_param_names(state, wrong_name):
        return state.mean(wrong_name["column"])

    with pytest.raises(ValueError, match="The computation function must have two parameters named 'state' and 'args'"):
        analyze_computation(wrong_param_names)

def test_analyze_computation_dynamic_column_access():
    def dynamic_column_access(state, args):
        columns = ["column1", "column2"]
        return sum(state.mean(args[col]) for col in columns)

    # This won't catch dynamically accessed columns
    result = analyze_computation(dynamic_column_access)
    assert result == []

def test_analyze_computation_lambda():
    lambda_computation = lambda state, args: state.mean(args["column"])

    with pytest.raises(ValueError, match="Lambda functions are not supported in analyze_computation"):
        analyze_computation(lambda_computation)

@@ -16,7 +16,8 @@ def test_get_data_commitment_maps(tmp_path, column_0, column_1, scales):
    #     "columns_0": [1, 2, 3, 4, 5],
    #     "columns_1": [6, 7, 8, 9, 10],
    # }
    data_json = data_to_json_file(data_path, [column_0, column_1])
    data_json = {"columns_0": column_0, "columns_1": column_1}
    data_to_json_file(data_path, data_json)
    # data_commitment is a mapping[scale -> mapping[column_name, commitment_hex]]
    # {
    #   scale_0: {
@@ -51,7 +52,8 @@ def test_get_data_commitment_maps_hardcoded(tmp_path):
    data_commitment_path = tmp_path / "commitments.json"
    column_0 = torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5])
    column_1 = torch.tensor([2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4])
    data_to_json_file(data_path, [column_0, column_1])
    data_json = {"columns_0": column_0, "columns_1": column_1}
    data_to_json_file(data_path, data_json)
    scales = [2, 3]
    generate_data_commitment(data_path, scales, data_commitment_path)
    with open(data_commitment_path, "r") as f:
@@ -63,30 +65,28 @@ def test_get_data_commitment_maps_hardcoded(tmp_path):

def test_integration_select_partial_columns(tmp_path, column_0, column_1, error, scales):
    data_path = tmp_path / "data.json"
    data_json = data_to_json_file(data_path, [column_0, column_1])
    columns = list(data_json.keys())
    assert len(columns) == 2
    # Select only the first column from two columns
    selected_columns = [columns[0]]
    data_json = {"columns_0": column_0, "columns_1": column_1}
    data_shape = {"columns_0": len(column_0), "columns_1": len(column_1)}
    data_to_json_file(data_path, data_json)

    def simple_computation(state, args):
        x = args['columns_0']
        return state.mean(x)
        return state.mean(args["columns_0"])
    precal_witness_path = tmp_path / "precal_witness_path.json"
    selected_columns, _, model = computation_to_model(simple_computation, precal_witness_path, data_shape, True, error)
    # gen settings, setup, prove, verify
    compute(tmp_path, [column_0, column_1], simple_computation, scales, selected_columns)
    compute(tmp_path, data_json, model, scales, selected_columns)


def test_csv_data(tmp_path, column_0, column_1, error, scales):
    data_json_path = tmp_path / "data.json"
    data_csv_path = tmp_path / "data.csv"
    data_json = data_to_json_file(data_json_path, [column_0, column_1])
    data_json = {"columns_0": column_0, "columns_1": column_1}
    data_shape = {"columns_0": len(column_0), "columns_1": len(column_1)}
    data_to_json_file(data_json_path, data_json)
    json_file_to_csv(data_json_path, data_csv_path)

    selected_columns = list(data_json.keys())

    def simple_computation(state, args):
        x = args['columns_0']
        return state.mean(x)
        return state.mean(args["columns_0"])

    sel_data_path = tmp_path / "comb_data.json"
    model_path = tmp_path / "model.onnx"
@@ -98,7 +98,7 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
    generate_data_commitment(data_csv_path, scales, data_commitment_path)

    # Test: `prover_gen_settings` works with csv
    _, model_for_proving = computation_to_model(simple_computation, precal_witness_path, True, selected_columns, error)
    selected_columns, _, model_for_proving = computation_to_model(simple_computation, precal_witness_path, data_shape, True, error)
    prover_gen_settings(
        data_path=data_csv_path,
        selected_columns=selected_columns,
@@ -112,7 +112,7 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):

    # Test: `prover_gen_settings` works with csv
    # Instantiate the model for verification since the state of `model_for_proving` is changed after `prover_gen_settings`
    _, model_for_verification = computation_to_model(simple_computation, precal_witness_path, False, selected_columns, error)
    selected_columns, _, model_for_verification = computation_to_model(simple_computation, precal_witness_path, data_shape, False, error)
    verifier_define_calculation(data_csv_path, selected_columns, str(sel_data_path), model_for_verification, str(model_path))

def json_file_to_csv(data_json_path, data_csv_path):
@@ -135,7 +135,8 @@ def json_file_to_csv(data_json_path, data_csv_path):

def test__preprocess_data_file_to_json(tmp_path, column_0, column_1):
    data_json_path = tmp_path / "data.json"
    data_from_json = data_to_json_file(data_json_path, [column_0, column_1])
    data_json = {"columns_0": column_0, "columns_1": column_1}
    data_from_json = data_to_json_file(data_json_path, data_json)

    # Test: csv can be converted to json
    # 1. Generate a csv file from json

@@ -5,9 +5,9 @@ import pytest

import torch
from zkstats.ops import Mean, Median, GeometricMean, HarmonicMean, Mode, PStdev, PVariance, Stdev, Variance, Covariance, Correlation, Operation, Regression
from zkstats.computation import IModel, IsResultPrecise, State, computation_to_model
from zkstats.computation import IModel, IsResultPrecise

from .helpers import compute_model, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED


@pytest.mark.parametrize(
@@ -49,8 +49,8 @@ def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tens
)
def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, scales: list[float]):
    expected_res = statistics.linear_regression(column_0.tolist(), column_1.tolist())
    columns = [column_0, column_1]
    regression = Regression.create(columns, error)
    columns = {"columns_0": column_0, "columns_1": column_1}
    regression = Regression.create(list(columns.values()), error)
    # shape = [2, 1]
    actual_res = regression.result
    assert_result(expected_res.slope, actual_res[0][0])
@@ -58,8 +58,7 @@ def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Ten
    class Model(IModel):
        def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
            return regression.ezkl(x), regression.result
    compute_model(tmp_path, columns, Model, scales)

    compute(tmp_path, columns, Model, scales)


def run_test_ops(tmp_path, op_type: Type[Operation], expected_func: Callable[[list[float]], float], error: float, scales: list[float], columns: list[torch.Tensor]):
@@ -70,4 +69,5 @@ def run_test_ops(tmp_path, op_type: Type[Operation], expected_func: Callable[[li

    class Model(IModel):
        def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
            return op.ezkl(x), op.result
    compute_model(tmp_path, columns, Model, scales)
    data = {f"columns_{i}": column for i, column in enumerate(columns)}
    compute(tmp_path, data, Model, scales)
@@ -264,30 +264,81 @@ class IModel(nn.Module):


# An computation function. Example:
# def computation(state: State, x: list[torch.Tensor]):
#     out_0 = state.median(x[0])
#     out_1 = state.median(x[1])
# def computation(state: State, args: dict[str, torch.Tensor]):
#     out_0 = state.median(args["x"])
#     out_1 = state.median(args["y"])
#     return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
TComputation = Callable[[State, dict[str, torch.Tensor]], torch.Tensor]


import ast
import inspect


def analyze_computation(computation: TComputation):
    source = inspect.getsource(computation)

    # Check if it's a lambda function
    if source.strip().startswith('lambda'):
        raise ValueError("Lambda functions are not supported in analyze_computation. Please use a regular function definition instead.")

    # Existing code for regular functions
    # Correct indentation
    lines = source.splitlines()
    min_indent = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
    corrected_source = '\n'.join(line[min_indent:] for line in lines)

    tree = ast.parse(corrected_source)
    column_names = set()

    class ComputationVisitor(ast.NodeVisitor):
        def __init__(self):
            self.valid_params = False

        def visit_FunctionDef(self, node):
            if len(node.args.args) == 2:
                if node.args.args[0].arg == 'state' and node.args.args[1].arg == 'args':
                    self.valid_params = True
            self.generic_visit(node)

        def visit_Subscript(self, node):
            if isinstance(node.value, ast.Name) and node.value.id == 'args':
                if isinstance(node.slice, ast.Constant):
                    column_names.add(node.slice.value)
            self.generic_visit(node)

    visitor = ComputationVisitor()
    visitor.visit(tree)

    if not visitor.valid_params:
        raise ValueError("The computation function must have two parameters named 'state' and 'args'")

    return list(column_names)
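As a quick illustration of what this analysis returns (the function and column names below are made up, not taken from the diff; the result order is unspecified because the names are collected in a set):

```python
def example_computation(state, args):
    # Only string-literal subscripts of `args` are recorded
    return state.mean(args["age"]) + state.median(args["salary"])

print(analyze_computation(example_computation))
# A list containing "age" and "salary"; the tests above compare with set(result).
```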
class Args:
    def __init__(
        self,
        columns: list[str],
        data_shape: dict[str, int],
        data: list[torch.Tensor],
    ):
        if len(columns) != len(data):
            raise ValueError("columns and data must have the same length")
        column_names = list(data_shape.keys())[:len(data)]
        self.data_dict = {
            column_name: d
            for column_name, d in zip(columns, data)
            for column_name, d in zip(column_names, data)
        }

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.data_dict[key]
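A small sketch (values are hypothetical) of how `Args` pairs the positional tensors received during ONNX export with the column names from `data_shape`, taking the keys in insertion order:

```python
import torch

data_shape = {"x": 3, "y": 3}
tensors = (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]))  # as passed to forward(*x)
args = Args(data_shape, tensors)
args["y"]  # the second tensor, since "y" is the second key of data_shape
```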
def computation_to_model(computation: TComputation, precal_witness_path: str, isProver:bool, selected_columns: list[str], error: float = DEFAULT_ERROR ) -> tuple[State, Type[IModel]]:
def computation_to_model(
    computation: TComputation,
    precal_witness_path:str,
    data_shape: dict[str, int],
    isProver: bool,
    error: float = DEFAULT_ERROR,
) -> tuple[list[str], State, Type[IModel]]:
    """
    Create a torch model from a `computation` function defined by user
    :param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
@@ -295,6 +346,16 @@ def computation_to_model(computation: TComputation, precal_witness_path: str, is
    :return: A tuple of State and Model. The Model is a torch model that can be used for exporting to onnx.
        State is a container for intermediate results of computation, which can be useful when debugging.
    """

    selected_columns_unordered = analyze_computation(computation)
    # Preserve the order from data_shape
    selected_columns = [col for col in data_shape.keys() if col in selected_columns_unordered]
    assert len(selected_columns) == len(selected_columns_unordered), "Selected columns must match"

    invalid_columns = set(selected_columns) - set(data_shape.keys())
    if invalid_columns:
        raise ValueError(f"Computation uses columns not present in data: {invalid_columns}")

    state = State(error)

    state.precal_witness_path = precal_witness_path
@@ -302,25 +363,17 @@ def computation_to_model(computation: TComputation, precal_witness_path: str, is

    class Model(IModel):
        def preprocess(self, x: list[torch.Tensor]) -> None:
            """
            Calculate the witnesses of the computation and store them in the state.
            """
            # In the preprocess step, the operations are calculated and the results are stored in the state.
            # So we don't need to get the returned result
            args = Args(selected_columns, x)
            args = Args(data_shape, x)
            computation(state, args)
            state.set_ready_for_exporting_onnx()

        def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
            """
            Called by torch.onnx.export.
            """
            args = Args(selected_columns, x)
            args = Args(data_shape, x)
            result = computation(state, args)
            is_computation_result_accurate = state.bools[0]()
            for op_precise_check in state.bools[1:]:
                is_op_result_accurate = op_precise_check()
                is_computation_result_accurate = torch.logical_and(is_computation_result_accurate, is_op_result_accurate)
            return is_computation_result_accurate, result
    return state, Model
    return selected_columns, state, Model