fix mean example

This commit is contained in:
JernKunpittaya
2024-05-03 21:29:53 +07:00
parent a35c0af4b5
commit 57ea3ed3c6
4 changed files with 163 additions and 68 deletions

File diff suppressed because one or more lines are too long

View File

@@ -43,15 +43,20 @@ class State:
self.error: float = error
# Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3.
self.current_op_index: Optional[int] = None
self.witness_array: Optional[list[torch.Tensor]] = None
def set_ready_for_exporting_onnx(self) -> None:
self.current_op_index = 0
def set_witness(self, witness_array: list[torch.Tensor]) -> None:
self.witness_array = witness_array
def mean(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the mean of the input tensor. The behavior should conform to
[statistics.mean](https://docs.python.org/3/library/statistics.html#statistics.mean) in Python standard library.
"""
return self._call_op([x], Mean)
def median(self, x: torch.Tensor) -> torch.Tensor:
@@ -145,7 +150,10 @@ class State:
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
if self.current_op_index is None:
-            op = op_type.create(x, self.error)
if self.witness_array is not None:
op = op_type.create(x, self.error, self.witness_array)
else:
op = op_type.create(x, self.error)
self.ops.append(op)
return op.result
else:
@@ -168,10 +176,14 @@ class State:
def is_precise() -> IsResultPrecise:
return op.ezkl(x)
self.bools.append(is_precise)
# If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result)
# else, return only result
if current_op_index == len_ops - 1:
# Sanity check for length of self.ops and self.bools
len_bools = len(self.bools)
if len_ops != len_bools:
@@ -179,14 +191,29 @@ class State:
is_precise_aggregated = torch.tensor(1.0)
for i in range(len_bools):
res = self.bools[i]()
# print("hey computation: ", i)
# print('self.ops: ', self.ops[i])
# print('res: ', res)
is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
-                return is_precise_aggregated, op.result
if isinstance(op, Where):
# return as where result
return is_precise_aggregated, op.result+x[1]-x[1]
else:
# return as a single number
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
elif current_op_index > len_ops - 1:
# Sanity check that current op index does not exceed the length of ops
raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
else:
# It's not the last operation, just return the result
-                return op.result
# for where
if isinstance(op, Where):
return (op.result+x[1]-x[1])
else:
# return single float number
return op.result+(x[0]-x[0])[0][0][0]
class IModel(nn.Module):
@@ -207,7 +234,7 @@ class IModel(nn.Module):
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
-def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR, witness_array: Optional[list[torch.Tensor]] = None) -> tuple[State, Type[IModel]]:
"""
Create a torch model from a `computation` function defined by user
:param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
@@ -216,13 +243,17 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
State is a container for intermediate results of computation, which can be useful when debugging.
"""
state = State(error)
# witness_array is only provided on the verifier side
if witness_array is not None:
state.set_witness(witness_array)
class Model(IModel):
def preprocess(self, x: list[torch.Tensor]) -> None:
computation(state, x)
state.set_ready_for_exporting_onnx()
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return computation(state, x)
return state, Model
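
To make the new parameter concrete, here is a minimal usage sketch (the computation function, import path, and witness value below are illustrative assumptions, not part of this commit):

import torch
from zkstats.computation import State, computation_to_model  # assumed import path

def computation(state: State, x: list[torch.Tensor]) -> torch.Tensor:
    # statistics.mean-style mean over the first input column
    return state.mean(x[0])

# Prover side: no witness_array, so Mean is computed from the data itself.
prover_state, ProverModel = computation_to_model(computation)

# Verifier side: the value produced by the prover is injected as the witness.
witness_array = [torch.tensor(3.0)]  # e.g. loaded from witness_array_path
verifier_state, VerifierModel = computation_to_model(computation, witness_array=witness_array)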

View File

@@ -54,6 +54,25 @@ def create_dummy(data_path: str, dummy_data_path: str) -> None:
# ===================================================================================================
# ===================================================================================================
def prover_gen_witness_array(
    data_path: str,
    selected_columns: list[str],
    sel_data_path: list[str],
    prover_model: Type[IModel],
    witness_array_path: str,
):
data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)
circuit = prover_model()
circuit.eval()
# Be careful with tuple vs. array inputs here; forward may need the same handling as in _export_onnx
one_witness = circuit.forward(data_tensor_array[0]).data.item()
    data = {'value': [one_witness]}
    with open(witness_array_path, 'w') as f:
        json.dump(data, f)
def prover_gen_settings(
data_path: str,
@@ -78,9 +97,16 @@ def prover_gen_settings(
:param settings_path: path to store the generated settings file
"""
data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)
# export onnx file
_export_onnx(prover_model, data_tensor_array, prover_model_path)
# print("data tensor: ", data_tensor_array)
# gen + calibrate setting
_gen_settings(sel_data_path, prover_model_path, scale, mode, settings_path)
@@ -346,7 +372,7 @@ def _gen_settings(
# Poseidon is not additively homomorphic; consider Pedersen or Dory commitments instead.
gip_run_args = ezkl.PyRunArgs()
gip_run_args.input_visibility = "hashed" # one commitment (values hashed) for each column
gip_run_args.param_visibility = "private" # no parameters shown
gip_run_args.param_visibility = "fixed" # no parameters shown
gip_run_args.output_visibility = "public" # should be `(torch.Tensor(1.0), output)`
# generate settings
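
Continuing the sketch above, prover_gen_witness_array would slot into the flow roughly as follows (file paths, column names, and the sel_data_path handling are illustrative assumptions):

import json
import torch

# Prover: evaluate the computation once on the real data and dump the claimed result.
prover_gen_witness_array(
    "data.json",
    ["col_a"],
    "sel_data.json",
    ProverModel,
    "witness.json",
)

# Verifier: load the claimed value and build its own model around it.
with open("witness.json") as f:
    witness_values = json.load(f)["value"]
witness_array = [torch.tensor(v) for v in witness_values]
verifier_state, VerifierModel = computation_to_model(computation, witness_array=witness_array)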

View File

@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod, abstractclassmethod
import statistics
from typing import Optional
import numpy as np
import torch
# boolean: either 1.0 or 0.0
IsResultPrecise = torch.Tensor
-MagicNumber = 9999999
MagicNumber = 9999999.0
class Operation(ABC):
@@ -29,16 +30,23 @@ class Where(Operation):
        # error is unused here; it is kept only to conform to the other operations' signatures
        return cls(torch.where(x[0], x[1], MagicNumber), error)
def ezkl(self, x:list[torch.Tensor]) -> IsResultPrecise:
-        bool_array = torch.logical_or(x[1]==self.result, torch.logical_and(torch.logical_not(x[0]), self.result==MagicNumber))
bool_array = torch.logical_or(torch.logical_and(x[0], x[1]==self.result), torch.logical_and(torch.logical_not(x[0]), self.result==MagicNumber))
return torch.sum(bool_array.float())==x[1].size()[1]
class Mean(Operation):
@classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'Mean':
    def create(cls, x: list[torch.Tensor], error: float, witness_array: Optional[list[torch.Tensor]] = None) -> 'Mean':
# support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
-        return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
        if witness_array is None:
            # prover side: compute the mean directly from the (masked) data
            return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
        else:
            # verifier side: use the value supplied by the prover as the witness
            return cls(witness_array[0], error)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]
@@ -200,8 +208,13 @@ class Mode(Operation):
old_size = x.size()[1]
x = torch.where(x==MagicNumber, min_x-1, x)
count_equal = torch.sum((x==self.result).float())
-        result = torch.tensor([torch.logical_or(torch.sum((x==ele[0]).float())<=count_equal, min_x-1 ==ele[0]) for ele in x[0]])
-        return torch.sum(result) == old_size
count_check = 0
for ele in x[0]:
bool1 = torch.sum((x==ele[0]).float())<=count_equal
bool2 = ele[0]==min_x-1
count_check += torch.logical_or(bool1, bool2)
        return count_check == old_size
class PStdev(Operation):
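
For intuition, this is how the MagicNumber convention that Where and Mean rely on behaves on a toy tensor (a standalone sketch, not code from this commit):

import torch

MagicNumber = 9999999.0

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
cond = x > 1.5                                          # the "where" condition
filtered = torch.where(cond, x, MagicNumber)            # Where.create: filtered-out entries become MagicNumber
mean = torch.mean(filtered[filtered != MagicNumber])    # Mean.create on the prover side
print(mean)  # tensor(3.) == statistics.mean([2.0, 3.0, 4.0])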