mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
fix mean example
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -43,15 +43,20 @@ class State:
|
||||
self.error: float = error
|
||||
# Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3.
|
||||
self.current_op_index: Optional[int] = None
|
||||
|
||||
self.witness_array: Optional[list[torch.Tensor]] = None
|
||||
def set_ready_for_exporting_onnx(self) -> None:
    """
    Mark the state as ready for ONNX export.

    Points `current_op_index` at the first operation. While the index is None the
    state is in stage 1; once it is not None the state is in stage 3.
    """
    first_op_index = 0
    self.current_op_index = first_op_index
|
||||
|
||||
def set_witness(self, witness_array: Optional[list[torch.Tensor]]) -> None:
    """
    Store the witness array on the state.

    :param witness_array: values kept in `self.witness_array`; when it is not None,
        `_call_op` passes it on to the operation it creates.
    """
    self.witness_array = witness_array
|
||||
def mean(self, x: torch.Tensor) -> torch.Tensor:
    """
    Calculate the mean of the input tensor. The behavior should conform to
    [statistics.mean](https://docs.python.org/3/library/statistics.html#statistics.mean) in Python standard library.

    :param x: input tensor, wrapped in a one-element list and handed to `_call_op`
    :return: whatever `_call_op` returns for the `Mean` operation
    """
    # No witness special-casing is needed here: `_call_op` already forwards
    # `self.witness_array` to the operation when one has been set.
    return self._call_op([x], Mean)
|
||||
|
||||
def median(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -145,7 +150,10 @@ class State:
|
||||
|
||||
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
|
||||
if self.current_op_index is None:
|
||||
op = op_type.create(x, self.error)
|
||||
if self.witness_array is not None:
|
||||
op = op_type.create(x, self.error, self.witness_array)
|
||||
else:
|
||||
op = op_type.create(x, self.error)
|
||||
self.ops.append(op)
|
||||
return op.result
|
||||
else:
|
||||
@@ -168,10 +176,14 @@ class State:
|
||||
def is_precise() -> IsResultPrecise:
|
||||
return op.ezkl(x)
|
||||
self.bools.append(is_precise)
|
||||
|
||||
# self.x.append(x)
|
||||
|
||||
# If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result)
|
||||
# else, return only result
|
||||
# print('all ops:', self.ops)
|
||||
if current_op_index == len_ops - 1:
|
||||
print('final op: ', op)
|
||||
# Sanity check for length of self.ops and self.bools
|
||||
len_bools = len(self.bools)
|
||||
if len_ops != len_bools:
|
||||
@@ -179,14 +191,29 @@ class State:
|
||||
is_precise_aggregated = torch.tensor(1.0)
|
||||
for i in range(len_bools):
|
||||
res = self.bools[i]()
|
||||
# print("hey computation: ", i)
|
||||
# print('self.ops: ', self.ops[i])
|
||||
# print('res: ', res)
|
||||
is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
|
||||
return is_precise_aggregated, op.result
|
||||
if isinstance(op, Where):
|
||||
# return as where result
|
||||
return is_precise_aggregated, op.result+x[1]-x[1]
|
||||
else:
|
||||
# return as a single number
|
||||
# return is_precise_aggregated, torch.tensor(40.0)+(x[0]-x[0])[0][0][0]
|
||||
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
|
||||
|
||||
elif current_op_index > len_ops - 1:
|
||||
# Sanity check that current op index does not exceed the length of ops
|
||||
raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
|
||||
else:
|
||||
# It's not the last operation, just return the result
|
||||
return op.result
|
||||
# for where
|
||||
if isinstance(op, Where):
|
||||
return (op.result+x[1]-x[1])
|
||||
else:
|
||||
# return single float number
|
||||
# return torch.where(x[0], x[1], 9999999)
|
||||
return op.result+(x[0]-x[0])[0][0][0]
|
||||
|
||||
|
||||
class IModel(nn.Module):
|
||||
@@ -207,7 +234,7 @@ class IModel(nn.Module):
|
||||
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
|
||||
|
||||
|
||||
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
|
||||
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR, witness_array: Optional[list[torch.Tensor]] = None ) -> tuple[State, Type[IModel]]:
|
||||
"""
|
||||
Create a torch model from a `computation` function defined by user
|
||||
:param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
|
||||
@@ -216,13 +243,17 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
|
||||
State is a container for intermediate results of computation, which can be useful when debugging.
|
||||
"""
|
||||
state = State(error)
|
||||
|
||||
# if it's verifier
|
||||
if witness_array is not None:
|
||||
state.set_witness(witness_array)
|
||||
|
||||
class Model(IModel):
|
||||
def preprocess(self, x: list[torch.Tensor]) -> None:
|
||||
computation(state, x)
|
||||
state.set_ready_for_exporting_onnx()
|
||||
|
||||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
print('x sy: ')
|
||||
return computation(state, x)
|
||||
return state, Model
|
||||
|
||||
|
||||
@@ -54,6 +54,25 @@ def create_dummy(data_path: str, dummy_data_path: str) -> None:
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
def prover_gen_witness_array(
    data_path: str,
    selected_columns: list[str],
    sel_data_path: list[str],
    prover_model: Type[IModel],
    witness_array_path: str,
) -> None:
    """
    Run the prover's model once and save the resulting value as a witness file.

    The data is preprocessed with `_process_data`, the model is instantiated and put
    in eval mode, and its forward pass is run on the first preprocessed tensor. The
    result is converted to a Python scalar and written to `witness_array_path` as
    JSON of the form `{"value": [<scalar>]}`.

    :param data_path: forwarded to `_process_data`
    :param selected_columns: forwarded to `_process_data`
    :param sel_data_path: forwarded to `_process_data`
    :param prover_model: model class; instantiated with no arguments
    :param witness_array_path: path of the JSON file to write
    """
    data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)

    circuit = prover_model()
    circuit.eval()

    # NOTE(review): only the first tensor is passed, and the result is treated as a
    # single tensor. If forward returns a tuple, or the model takes several inputs,
    # this needs the same unpacking that the ONNX export path uses -- TODO confirm.
    one_witness = circuit.forward(data_tensor_array[0]).data.item()
    print('one witness: ', one_witness)

    data = {'value': [one_witness]}
    # Use a context manager so the file is flushed and closed deterministically.
    with open(witness_array_path, 'w') as witness_file:
        json.dump(data, witness_file)
|
||||
|
||||
|
||||
def prover_gen_settings(
|
||||
data_path: str,
|
||||
@@ -78,9 +97,16 @@ def prover_gen_settings(
|
||||
:param settings_path: path to store the generated settings file
|
||||
"""
|
||||
data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)
|
||||
|
||||
|
||||
# circuit = prover_model()
|
||||
# circuit.eval()
|
||||
# # be careful of tuple here --> array --> tuple need something like in export_onnx
|
||||
# one_witness = circuit.forward(data_tensor_array[0]).data.item()
|
||||
# print('one witness: ', one_witness)
|
||||
# export onnx file
|
||||
_export_onnx(prover_model, data_tensor_array, prover_model_path)
|
||||
# print("data tensor: ", data_tensor_array)
|
||||
|
||||
# gen + calibrate setting
|
||||
_gen_settings(sel_data_path, prover_model_path, scale, mode, settings_path)
|
||||
|
||||
@@ -346,7 +372,7 @@ def _gen_settings(
|
||||
# Poseidon is not additively homomorphic; consider a Pedersen or Dory commitment instead.
|
||||
gip_run_args = ezkl.PyRunArgs()
|
||||
gip_run_args.input_visibility = "hashed" # one commitment (values hashed) for each column
|
||||
gip_run_args.param_visibility = "private" # no parameters shown
|
||||
gip_run_args.param_visibility = "fixed" # no parameters shown
|
||||
gip_run_args.output_visibility = "public" # should be `(torch.Tensor(1.0), output)`
|
||||
|
||||
# generate settings
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from abc import ABC, abstractmethod, abstractclassmethod
|
||||
import statistics
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# boolean: either 1.0 or 0.0
|
||||
IsResultPrecise = torch.Tensor
|
||||
MagicNumber = 9999999
|
||||
MagicNumber = 9999999.0
|
||||
|
||||
|
||||
class Operation(ABC):
|
||||
@@ -29,16 +30,23 @@ class Where(Operation):
|
||||
# error is not meaningful for this operation; it is accepted only to conform to the other operations
|
||||
return cls(torch.where(x[0],x[1], MagicNumber ),error)
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check that `self.result` is a valid `where` output for the given inputs.

    A position is accepted when either
    - the condition `x[0]` holds and the result equals `x[1]`, or
    - the condition does not hold and the result equals `MagicNumber`.

    :param x: `[condition, values]`, the same inputs given to `create`
    :return: whether the number of accepted positions equals `x[1].size()[1]`
    """
    condition, values = x[0], x[1]
    # Matching `values` only counts where the condition is true; otherwise a filtered
    # position whose value happens to equal the result would be accepted.
    kept_ok = torch.logical_and(condition, values == self.result)
    filtered_ok = torch.logical_and(torch.logical_not(condition), self.result == MagicNumber)
    bool_array = torch.logical_or(kept_ok, filtered_ok)
    # assumes dimension 1 of `values` is the number of data points -- TODO confirm
    return torch.sum(bool_array.float()) == values.size()[1]
|
||||
|
||||
|
||||
class Mean(Operation):
|
||||
@classmethod
def create(cls, x: list[torch.Tensor], error: float, witness_array: Optional[list[torch.Tensor]] = None) -> 'Mean':
    """
    Build a `Mean` operation.

    :param x: list whose first element is the input tensor
    :param error: forwarded unchanged to the constructor
    :param witness_array: when None (prover), the mean is computed from `x[0]`;
        otherwise (verifier) `witness_array[0]` is used as the result
    :return: the new `Mean` instance
    """
    if witness_array is None:
        # Prover: compute the mean, ignoring entries equal to MagicNumber. This is what
        # supports `where` filtering; 'nan' could hopefully replace the placeholder once
        # onnx.isnan() is supported.
        return cls(torch.mean(x[0][x[0] != MagicNumber]), error)
    # Verifier: take the result from the supplied witness instead of computing it.
    return cls(witness_array[0], error)
|
||||
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
@@ -200,8 +208,13 @@ class Mode(Operation):
|
||||
old_size = x.size()[1]
|
||||
x = torch.where(x==MagicNumber, min_x-1, x)
|
||||
count_equal = torch.sum((x==self.result).float())
|
||||
result = torch.tensor([torch.logical_or(torch.sum((x==ele[0]).float())<=count_equal, min_x-1 ==ele[0]) for ele in x[0]])
|
||||
return torch.sum(result) == old_size
|
||||
|
||||
count_check = 0
|
||||
for ele in x[0]:
|
||||
bool1 = torch.sum((x==ele[0]).float())<=count_equal
|
||||
bool2 = ele[0]==min_x-1
|
||||
count_check += torch.logical_or(bool1, bool2)
|
||||
return count_check ==old_size
|
||||
|
||||
|
||||
class PStdev(Operation):
|
||||
|
||||
Reference in New Issue
Block a user