mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-10 05:57:55 -05:00
380 lines
17 KiB
Python
380 lines
17 KiB
Python
from abc import abstractmethod
|
|
from typing import Callable, Type, Optional, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
import json
|
|
|
|
from .ops import (
|
|
Operation,
|
|
Mean,
|
|
Median,
|
|
GeometricMean,
|
|
HarmonicMean,
|
|
Mode,
|
|
PStdev,
|
|
PVariance,
|
|
Stdev,
|
|
Variance,
|
|
Covariance,
|
|
Correlation,
|
|
Regression,
|
|
IsResultPrecise,
|
|
)
|
|
|
|
|
|
# Error tolerance used by `computation_to_model` when the caller does not pass `error`.
DEFAULT_ERROR = 0.01
# Value that `State.where` writes into every position rejected by the filter.
# NOTE(review): presumably the operations in `.ops` recognise this value as "filtered out" — confirm there.
MagicNumber = 99.999
|
|
|
|
|
|
class State:
    """
    State is a container for intermediate results of computation.

    Stage 1 (current_op_index is None): for every call to State (mean, median, etc.), the
    operation is created, its result is calculated, and the operation is stored in `ops`.
    Stage 2: all operations are calculated and results are ready to be used. Call
    `set_ready_for_exporting_onnx` to indicate it's ready to generate settings.
    Stage 3 (current_op_index is not None): when exporting to onnx, the same calls are made again
    in the same order. Each call picks the next stored operation from `ops`, registers its
    precision check in `bools`, and returns the stored result.

    When `isProver` is truthy, stage 1 additionally records the witness values of every operation
    in `precal_witness` (keyed "<OperationName>_<n>"), and stage 3 writes them to
    `precal_witness_path` as JSON. Otherwise (verifier), stage 1 reads that JSON file and hands it
    to `Operation.create`.
    """
    def __init__(self, error: float) -> None:
        self.ops: list[Operation] = []
        self.bools: list[Callable[[], torch.Tensor]] = []
        self.error: float = error
        # Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3.
        self.current_op_index: Optional[int] = None
        # Path of the JSON file holding the pre-calculated witness. Set by `computation_to_model`.
        self.precal_witness_path: Optional[str] = None
        # Witness values collected by the prover, keyed "<OperationName>_<n>".
        self.precal_witness: dict = {}
        # Truthy for the prover; None/False is treated as the verifier.
        self.isProver: Optional[bool] = None
        # Number of operations seen so far, per operation name.
        self.op_dict: dict = {}

    def set_ready_for_exporting_onnx(self) -> None:
        """Switch from stage 1 to stage 3 by pointing at the first stored operation."""
        self.current_op_index = 0

    def mean(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the mean of the input tensor. The behavior should conform to
        [statistics.mean](https://docs.python.org/3/library/statistics.html#statistics.mean) in Python standard library.
        """
        return self._call_op([x], Mean)

    def median(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the median of the input tensor. The behavior should conform to
        [statistics.median](https://docs.python.org/3/library/statistics.html#statistics.median) in Python standard library.
        """
        return self._call_op([x], Median)

    def geometric_mean(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the geometric mean of the input tensor. The behavior should conform to
        [statistics.geometric_mean](https://docs.python.org/3/library/statistics.html#statistics.geometric_mean) in Python standard library.
        """
        return self._call_op([x], GeometricMean)

    def harmonic_mean(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the harmonic mean of the input tensor. The behavior should conform to
        [statistics.harmonic_mean](https://docs.python.org/3/library/statistics.html#statistics.harmonic_mean) in Python standard library.
        """
        return self._call_op([x], HarmonicMean)

    def mode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the mode of the input tensor. The behavior should conform to
        [statistics.mode](https://docs.python.org/3/library/statistics.html#statistics.mode) in Python standard library.
        """
        return self._call_op([x], Mode)

    def pstdev(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the population standard deviation of the input tensor. The behavior should conform to
        [statistics.pstdev](https://docs.python.org/3/library/statistics.html#statistics.pstdev) in Python standard library.
        """
        return self._call_op([x], PStdev)

    def pvariance(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the population variance of the input tensor. The behavior should conform to
        [statistics.pvariance](https://docs.python.org/3/library/statistics.html#statistics.pvariance) in Python standard library.
        """
        return self._call_op([x], PVariance)

    def stdev(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the sample standard deviation of the input tensor. The behavior should conform to
        [statistics.stdev](https://docs.python.org/3/library/statistics.html#statistics.stdev) in Python standard library.
        """
        return self._call_op([x], Stdev)

    def variance(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the sample variance of the input tensor. The behavior should conform to
        [statistics.variance](https://docs.python.org/3/library/statistics.html#statistics.variance) in Python standard library.
        """
        return self._call_op([x], Variance)

    def covariance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Calculate the covariance of x and y. The behavior should conform to
        [statistics.covariance](https://docs.python.org/3/library/statistics.html#statistics.covariance) in Python standard library.
        """
        return self._call_op([x, y], Covariance)

    def correlation(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Calculate the correlation of x and y. The behavior should conform to
        [statistics.correlation](https://docs.python.org/3/library/statistics.html#statistics.correlation) in Python standard library.
        """
        return self._call_op([x, y], Correlation)

    def linear_regression(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Calculate the linear regression of x and y. The behavior should conform to
        [statistics.linear_regression](https://docs.python.org/3/library/statistics.html#statistics.linear_regression) in Python standard library.
        """
        # Only one x is supported for now.
        return self._call_op([x, y], Regression)

    # WHERE operation
    def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the where operation of x. The behavior should conform to `torch.where` in PyTorch.

        :param _filter: A boolean tensor serves as a filter
        :param x: A tensor to be filtered
        :return: A tensor equal to `x` where `_filter` is true and to `MagicNumber` elsewhere
        """
        # `x-x+MagicNumber` is a tensor of x's shape filled with MagicNumber.
        # NOTE(review): presumably built from x (rather than a constant) so it stays tied to the
        # input in the exported graph — confirm.
        return torch.where(_filter, x, x-x+MagicNumber)

    def _record_witness(self, op: Operation) -> None:
        """
        Store the witness values of `op` in `precal_witness` (prover, stage 1).

        The entry is keyed "<name>_<n>", where n is how many operations with that name were
        recorded before; `op_dict[name]` is incremented afterwards. Operations of a type not
        listed below are not recorded and do not touch `op_dict`.
        """
        def scalar(tensor: torch.Tensor):
            return tensor.data.item()

        # Single witness aka result
        if isinstance(op, (Mean, GeometricMean, HarmonicMean, Mode)):
            name = type(op).__name__
            witness = [scalar(op.result)]
        elif isinstance(op, Median):
            name = 'Median'
            witness = [scalar(op.result), scalar(op.lower), scalar(op.upper)]
        # std + variance stuffs
        elif isinstance(op, (PStdev, PVariance, Stdev, Variance)):
            name = type(op).__name__
            witness = [scalar(op.result), scalar(op.data_mean)]
        elif isinstance(op, Covariance):
            name = 'Covariance'
            witness = [scalar(op.result), scalar(op.x_mean), scalar(op.y_mean)]
        elif isinstance(op, Correlation):
            name = 'Correlation'
            witness = [
                scalar(op.result), scalar(op.x_mean), scalar(op.y_mean),
                scalar(op.x_std), scalar(op.y_std), scalar(op.cov),
            ]
        elif isinstance(op, Regression):
            name = 'Regression'
            # The regression result holds several values; they are stored as one nested list.
            witness = [[ele[0].item() for ele in op.result.data]]
        else:
            return

        seen = self.op_dict.get(name, 0)
        self.precal_witness[name + "_" + str(seen)] = witness
        self.op_dict[name] = seen + 1

    def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
        """
        Run (stage 1) or replay (stage 3) one operation.

        :param x: Input tensors of the operation.
        :param op_type: The `Operation` subclass to create or to expect next.
        :return: The operation's result. In stage 3 a zero derived from `x[0]` is added to it.
        :raises Exception: In stage 3, if more operations are called than were stored in stage 1,
            or if the next stored operation is not an instance of `op_type`.
        """
        if self.current_op_index is None:
            # for prover
            if self.isProver:
                op = op_type.create(x, self.error)
                self._record_witness(op)
            # for verifier
            else:
                # The witness file is re-read on every call; the context manager makes sure the
                # handle is closed each time.
                with open(self.precal_witness_path, "r") as witness_file:
                    precal_witness = json.load(witness_file)
                op = op_type.create(x, self.error, precal_witness, self.op_dict)
                op_class_str = type(op).__name__
                self.op_dict[op_class_str] = self.op_dict.get(op_class_str, 0) + 1
            self.ops.append(op)
            return op.result

        # Copy the current op index to a local variable since self.current_op_index will be incremented.
        current_op_index = self.current_op_index
        # Sanity check that current op index is not out of bound
        len_ops = len(self.ops)
        if current_op_index >= len_ops:
            raise Exception(f"current_op_index out of bound: {current_op_index=} >= {len_ops=}")

        op = self.ops[current_op_index]
        # Sanity check that the operation type matches the current op type
        if not isinstance(op, op_type):
            raise Exception(f"operation type mismatch: {op_type=} != {type(op)=}")

        # Increment the current op index
        self.current_op_index += 1

        # Push the ezkl condition, which is checked only in the last operation
        def is_precise() -> IsResultPrecise:
            return op.ezkl(x)
        self.bools.append(is_precise)

        if self.isProver:
            # Rewritten on every stage-3 call; the file ends up holding the full witness.
            with open(self.precal_witness_path, 'w') as witness_file:
                json.dump(self.precal_witness, witness_file)
        # `(x[0]-x[0])[0][0]` is a zero taken from the input.
        # NOTE(review): presumably added so the returned value depends on x in the traced
        # graph — confirm.
        return op.result+(x[0]-x[0])[0][0]
|
|
|
|
|
|
class IModel(nn.Module):
    """
    Interface of the torch module class returned by `computation_to_model`.

    NOTE(review): `nn.Module` presumably does not use `ABCMeta`, in which case `@abstractmethod`
    only documents intent here and does not prevent instantiation — confirm.
    """
    @abstractmethod
    def preprocess(self, x: list[torch.Tensor]) -> None:
        """
        Prepare the model from the input tensors `x`; returns nothing.

        The implementation in `computation_to_model` runs the computation once and then calls
        `State.set_ready_for_exporting_onnx`.
        """
        ...

    @abstractmethod
    def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
        """
        Run the computation on the positional inputs `x`.

        :return: A pair of (whether every operation's result is precise, the computation result).
        """
        ...
|
|
|
|
|
|
# A computation function. Example:
# def computation(state: State, args: dict[str, torch.Tensor]):
#     out_0 = state.median(args["x"])
#     out_1 = state.median(args["y"])
#     return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
TComputation = Callable[[State, dict[str, torch.Tensor]], torch.Tensor]
|
|
|
|
|
|
import ast
|
|
import inspect
|
|
|
|
|
|
def analyze_computation(computation: "TComputation"):
    """
    Statically inspect a computation function and list the columns it reads.

    The source of `computation` is parsed with `ast`. Every subscript of the form
    `args[<constant>]` contributes its constant to the result; subscripts with a
    non-constant key (e.g. `args[name]`) cannot be resolved statically and are ignored.

    :param computation: A regular (`def`) function whose positional parameters are exactly
        `state` and `args`.
    :return: A list of the distinct column names used, in no particular order.
    :raises ValueError: If `computation` is a lambda, or if its outermost definition is not a
        function with two positional parameters named `state` and `args`.
    """
    # Detect lambdas by name. The source text is not a reliable signal: `inspect.getsource`
    # returns the whole line(s) containing the lambda, which normally start with an
    # assignment or a call rather than with the `lambda` keyword.
    if getattr(computation, "__name__", None) == "<lambda>":
        raise ValueError("Lambda functions are not supported in analyze_computation. Please use a regular function definition instead.")

    source = inspect.getsource(computation)

    # Remove the common indentation so that nested functions and methods parse on their own.
    lines = source.splitlines()
    min_indent = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
    corrected_source = '\n'.join(line[min_indent:] for line in lines)

    tree = ast.parse(corrected_source)

    # Validate the computation itself, i.e. the outermost definition. A helper nested inside
    # it that happens to take (state, args) must not make an invalid signature pass.
    func_def = tree.body[0] if tree.body else None
    param_names = (
        [param.arg for param in func_def.args.args]
        if isinstance(func_def, ast.FunctionDef)
        else None
    )
    if param_names != ['state', 'args']:
        raise ValueError("The computation function must have two parameters named 'state' and 'args'")

    column_names = set()
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)
            and node.value.id == 'args'
            and isinstance(node.slice, ast.Constant)
        ):
            column_names.add(node.slice.value)

    return list(column_names)
|
|
|
|
|
|
class Args:
    """
    Name-indexed lookup over positionally supplied tensors.

    The i-th element of `data` is stored under the i-th key of `data_shape`. If the two
    differ in length, the surplus keys or surplus elements are dropped.
    """
    def __init__(
        self,
        data_shape: dict[str, int],
        data: list[torch.Tensor],
    ):
        # Iterating a dict yields its keys in insertion order, and `zip` stops at the
        # shorter input, so this pairs the first len(data) keys with the data.
        self.data_dict = dict(zip(data_shape, data))

    def __getitem__(self, key: str) -> torch.Tensor:
        """Return the tensor stored under column name `key` (KeyError if absent)."""
        return self.data_dict[key]
|
|
|
|
|
|
def computation_to_model(
    computation: TComputation,
    precal_witness_path:str,
    data_shape: dict[str, int],
    isProver: bool,
    error: float = DEFAULT_ERROR,
) -> tuple[list[str], State, Type[IModel]]:
    """
    Create a torch model from a `computation` function defined by user

    :param computation: A function that takes a State and an `args` mapping from column name to
        torch.Tensor, and returns a torch.Tensor
    :param precal_witness_path: Path of the JSON file holding the pre-calculated witness. It is
        stored on the State, which writes it (prover) or reads it (verifier).
    :param data_shape: A dict keyed by column name. Its key order decides the order of the
        returned columns and how positional inputs are matched to names.
    :param isProver: Whether the model is built for the prover (True) or the verifier (False).
    :param error: The error tolerance for the computation.
    :return: A tuple of (selected columns, State, Model). The selected columns are the columns
        used by `computation`, ordered as in `data_shape`. The Model is a torch model class that
        can be used for exporting to onnx. State is a container for intermediate results of
        computation, which can be useful when debugging.
    :raises ValueError: If `computation` reads a column that is not a key of `data_shape`
        (or if `analyze_computation` rejects the function).
    """

    selected_columns_unordered = analyze_computation(computation)

    # Validate against the raw column names. Checking the list filtered by `data_shape`
    # could never find an unknown column.
    invalid_columns = set(selected_columns_unordered) - set(data_shape.keys())
    if invalid_columns:
        raise ValueError(f"Computation uses columns not present in data: {invalid_columns}")

    # Preserve the order from data_shape
    selected_columns = [col for col in data_shape.keys() if col in selected_columns_unordered]

    state = State(error)

    state.precal_witness_path = precal_witness_path
    state.isProver = isProver

    class Model(IModel):
        def preprocess(self, x: list[torch.Tensor]) -> None:
            # Stage 1: run the computation once so the state records every operation,
            # then switch the state to the exporting stage.
            args = Args(data_shape, x)
            computation(state, args)
            state.set_ready_for_exporting_onnx()

        def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
            args = Args(data_shape, x)
            result = computation(state, args)
            # AND together the precision checks registered by the operations.
            is_computation_result_accurate = state.bools[0]()
            for op_precise_check in state.bools[1:]:
                is_op_result_accurate = op_precise_check()
                is_computation_result_accurate = torch.logical_and(is_computation_result_accurate, is_op_result_accurate)
            return is_computation_result_accurate, result

    return selected_columns, state, Model
|
|
|