mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
generalize from mean
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -3,6 +3,7 @@ from typing import Callable, Type, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import json
|
||||
|
||||
from .ops import (
|
||||
Operation,
|
||||
@@ -43,20 +44,23 @@ class State:
|
||||
self.error: float = error
|
||||
# Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3.
|
||||
self.current_op_index: Optional[int] = None
|
||||
self.witness_array: Optional[list[torch.Tensor]] = None
|
||||
self.precal_witness_path: str = None
|
||||
self.precal_witness:dict = {}
|
||||
self.isProver:bool = None
|
||||
|
||||
def set_ready_for_exporting_onnx(self) -> None:
|
||||
self.current_op_index = 0
|
||||
def set_witness(self,witness_array) -> None:
|
||||
self.witness_array = witness_array
|
||||
# def set_witness(self,witness_array) -> None:
|
||||
# self.witness_array = witness_array
|
||||
# def set_aggregate_witness_path(self,aggregate_witness_path) -> None:
|
||||
# self.aggregate_witness_path = aggregate_witness_path
|
||||
# def get_aggregate_witness(self) -> list[torch.Tensor]:
|
||||
# return self.aggregate_witness
|
||||
def mean(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the mean of the input tensor. The behavior should conform to
|
||||
[statistics.mean](https://docs.python.org/3/library/statistics.html#statistics.mean) in Python standard library.
|
||||
"""
|
||||
# if self.witness_array is not None:
|
||||
# print('self.wtiness ', self.witness_array)
|
||||
# return self._call_op([x], Mean, self.witness_array)
|
||||
# else:
|
||||
return self._call_op([x], Mean)
|
||||
|
||||
def median(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -150,10 +154,23 @@ class State:
|
||||
|
||||
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
|
||||
if self.current_op_index is None:
|
||||
if self.witness_array is not None:
|
||||
op = op_type.create(x, self.error, self.witness_array)
|
||||
else:
|
||||
# for prover
|
||||
if self.isProver:
|
||||
print('Prover side')
|
||||
op = op_type.create(x, self.error)
|
||||
# print('oppy : ', op)
|
||||
# print('is check pri 1: ', isinstance(op,Mean))
|
||||
if isinstance(op,Mean):
|
||||
self.precal_witness['Mean'] = [op.result.data.item()]
|
||||
# for verifier
|
||||
else:
|
||||
print('Verifier side')
|
||||
# if isinstance(op,Mean):
|
||||
precal_witness = json.loads(open(self.precal_witness_path, "r").read())
|
||||
# tensor_arr = []
|
||||
# for ele in data:
|
||||
# tensor_arr.append(torch.tensor(ele))
|
||||
op = op_type.create(x, self.error, precal_witness)
|
||||
self.ops.append(op)
|
||||
return op.result
|
||||
else:
|
||||
@@ -199,8 +216,8 @@ class State:
|
||||
# return as where result
|
||||
return is_precise_aggregated, op.result+x[1]-x[1]
|
||||
else:
|
||||
# return as a single number
|
||||
# return is_precise_aggregated, torch.tensor(40.0)+(x[0]-x[0])[0][0][0]
|
||||
if self.isProver:
|
||||
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
|
||||
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
|
||||
|
||||
elif current_op_index > len_ops - 1:
|
||||
@@ -213,6 +230,9 @@ class State:
|
||||
else:
|
||||
# return single float number
|
||||
# return torch.where(x[0], x[1], 9999999)
|
||||
# print('oppy else: ', op)
|
||||
# print('is check else: ', isinstance(op,Mean))
|
||||
# self.aggregate_witness.append(op.result)
|
||||
return op.result+(x[0]-x[0])[0][0][0]
|
||||
|
||||
|
||||
@@ -234,7 +254,7 @@ class IModel(nn.Module):
|
||||
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
|
||||
|
||||
|
||||
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR, witness_array: Optional[list[torch.Tensor]] = None ) -> tuple[State, Type[IModel]]:
|
||||
def computation_to_model(computation: TComputation, precal_witness_path:str, isProver:bool ,error: float = DEFAULT_ERROR ) -> tuple[State, Type[IModel]]:
|
||||
"""
|
||||
Create a torch model from a `computation` function defined by user
|
||||
:param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
|
||||
@@ -244,16 +264,18 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
|
||||
"""
|
||||
state = State(error)
|
||||
# if it's verifier
|
||||
if witness_array is not None:
|
||||
state.set_witness(witness_array)
|
||||
|
||||
state.precal_witness_path= precal_witness_path
|
||||
state.isProver = isProver
|
||||
|
||||
class Model(IModel):
|
||||
def preprocess(self, x: list[torch.Tensor]) -> None:
|
||||
computation(state, x)
|
||||
state.set_ready_for_exporting_onnx()
|
||||
|
||||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
print('x sy: ')
|
||||
# print('x sy: ')
|
||||
return computation(state, x)
|
||||
# print('state:: ', state.aggregate_witness_path)
|
||||
return state, Model
|
||||
|
||||
|
||||
@@ -54,24 +54,24 @@ def create_dummy(data_path: str, dummy_data_path: str) -> None:
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
def prover_gen_witness_array(
|
||||
data_path:str,
|
||||
selected_columns:list[str],
|
||||
sel_data_path:list[str],
|
||||
prover_model: Type[IModel],
|
||||
witness_array_path:str
|
||||
):
|
||||
data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)
|
||||
# def prover_gen_witness_array(
|
||||
# data_path:str,
|
||||
# selected_columns:list[str],
|
||||
# sel_data_path:list[str],
|
||||
# prover_model: Type[IModel],
|
||||
# witness_array_path:str
|
||||
# ):
|
||||
# data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)
|
||||
|
||||
circuit = prover_model()
|
||||
# cloned_circuit = circuit.clone()
|
||||
circuit.eval()
|
||||
# be careful of tuple here --> array --> tuple need something like in export_onnx
|
||||
one_witness = circuit.forward(data_tensor_array[0]).data.item()
|
||||
print('one witness: ', one_witness)
|
||||
# circuit = prover_model()
|
||||
# # cloned_circuit = circuit.clone()
|
||||
# circuit.eval()
|
||||
# # be careful of tuple here --> array --> tuple need something like in export_onnx
|
||||
# one_witness = circuit.forward(data_tensor_array[0]).data.item()
|
||||
# print('one witness: ', one_witness)
|
||||
|
||||
data ={'value':[one_witness]}
|
||||
json.dump(data, open(witness_array_path, 'w'))
|
||||
# data ={'value':[one_witness]}
|
||||
# json.dump(data, open(witness_array_path, 'w'))
|
||||
|
||||
|
||||
def prover_gen_settings(
|
||||
|
||||
@@ -36,16 +36,20 @@ class Where(Operation):
|
||||
|
||||
class Mean(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, witness_array:Optional[list[torch.Tensor]] = None ) -> 'Mean':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None ) -> 'Mean':
|
||||
# support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
|
||||
if witness_array is None:
|
||||
if precal_witness is None:
|
||||
# this is prover
|
||||
print('provvv')
|
||||
# print('provvv')
|
||||
return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
|
||||
else:
|
||||
# this is verifier
|
||||
print('verrrr')
|
||||
return cls(witness_array[0], error)
|
||||
# print('verrrr')
|
||||
tensor_arr = []
|
||||
for ele in precal_witness['Mean']:
|
||||
tensor_arr.append(torch.tensor(ele))
|
||||
print("mean tensor arr: ", tensor_arr)
|
||||
return cls(tensor_arr[0], error)
|
||||
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
|
||||
Reference in New Issue
Block a user