generalize from mean

This commit is contained in:
JernKunpittaya
2024-05-04 10:55:16 +07:00
parent 57ea3ed3c6
commit 6f7e38405b
4 changed files with 126 additions and 99 deletions

File diff suppressed because one or more lines are too long

View File

@@ -3,6 +3,7 @@ from typing import Callable, Type, Optional, Union
import torch
from torch import nn
import json
from .ops import (
Operation,
@@ -43,20 +44,23 @@ class State:
self.error: float = error
# Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3.
self.current_op_index: Optional[int] = None
self.witness_array: Optional[list[torch.Tensor]] = None
self.precal_witness_path: str = None
self.precal_witness:dict = {}
self.isProver:bool = None
def set_ready_for_exporting_onnx(self) -> None:
self.current_op_index = 0
def set_witness(self,witness_array) -> None:
self.witness_array = witness_array
# def set_witness(self,witness_array) -> None:
# self.witness_array = witness_array
# def set_aggregate_witness_path(self,aggregate_witness_path) -> None:
# self.aggregate_witness_path = aggregate_witness_path
# def get_aggregate_witness(self) -> list[torch.Tensor]:
# return self.aggregate_witness
def mean(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the mean of the input tensor. The behavior should conform to
[statistics.mean](https://docs.python.org/3/library/statistics.html#statistics.mean) in Python standard library.
"""
# if self.witness_array is not None:
# print('self.wtiness ', self.witness_array)
# return self._call_op([x], Mean, self.witness_array)
# else:
return self._call_op([x], Mean)
def median(self, x: torch.Tensor) -> torch.Tensor:
@@ -150,10 +154,23 @@ class State:
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
if self.current_op_index is None:
if self.witness_array is not None:
op = op_type.create(x, self.error, self.witness_array)
else:
# for prover
if self.isProver:
print('Prover side')
op = op_type.create(x, self.error)
# print('oppy : ', op)
# print('is check pri 1: ', isinstance(op,Mean))
if isinstance(op,Mean):
self.precal_witness['Mean'] = [op.result.data.item()]
# for verifier
else:
print('Verifier side')
# if isinstance(op,Mean):
precal_witness = json.loads(open(self.precal_witness_path, "r").read())
# tensor_arr = []
# for ele in data:
# tensor_arr.append(torch.tensor(ele))
op = op_type.create(x, self.error, precal_witness)
self.ops.append(op)
return op.result
else:
@@ -199,8 +216,8 @@ class State:
# return as where result
return is_precise_aggregated, op.result+x[1]-x[1]
else:
# return as a single number
# return is_precise_aggregated, torch.tensor(40.0)+(x[0]-x[0])[0][0][0]
if self.isProver:
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
elif current_op_index > len_ops - 1:
@@ -213,6 +230,9 @@ class State:
else:
# return single float number
# return torch.where(x[0], x[1], 9999999)
# print('oppy else: ', op)
# print('is check else: ', isinstance(op,Mean))
# self.aggregate_witness.append(op.result)
return op.result+(x[0]-x[0])[0][0][0]
@@ -234,7 +254,7 @@ class IModel(nn.Module):
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR, witness_array: Optional[list[torch.Tensor]] = None ) -> tuple[State, Type[IModel]]:
def computation_to_model(computation: TComputation, precal_witness_path:str, isProver:bool ,error: float = DEFAULT_ERROR ) -> tuple[State, Type[IModel]]:
"""
Create a torch model from a `computation` function defined by user
:param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
@@ -244,16 +264,18 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
"""
state = State(error)
# if it's verifier
if witness_array is not None:
state.set_witness(witness_array)
state.precal_witness_path= precal_witness_path
state.isProver = isProver
class Model(IModel):
def preprocess(self, x: list[torch.Tensor]) -> None:
computation(state, x)
state.set_ready_for_exporting_onnx()
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
print('x sy: ')
# print('x sy: ')
return computation(state, x)
# print('state:: ', state.aggregate_witness_path)
return state, Model

View File

@@ -54,24 +54,24 @@ def create_dummy(data_path: str, dummy_data_path: str) -> None:
# ===================================================================================================
# ===================================================================================================
def prover_gen_witness_array(
data_path:str,
selected_columns:list[str],
sel_data_path:list[str],
prover_model: Type[IModel],
witness_array_path:str
):
data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)
# def prover_gen_witness_array(
# data_path:str,
# selected_columns:list[str],
# sel_data_path:list[str],
# prover_model: Type[IModel],
# witness_array_path:str
# ):
# data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)
circuit = prover_model()
# cloned_circuit = circuit.clone()
circuit.eval()
# be careful of tuple here --> array --> tuple need something like in export_onnx
one_witness = circuit.forward(data_tensor_array[0]).data.item()
print('one witness: ', one_witness)
# circuit = prover_model()
# # cloned_circuit = circuit.clone()
# circuit.eval()
# # be careful of tuple here --> array --> tuple need something like in export_onnx
# one_witness = circuit.forward(data_tensor_array[0]).data.item()
# print('one witness: ', one_witness)
data ={'value':[one_witness]}
json.dump(data, open(witness_array_path, 'w'))
# data ={'value':[one_witness]}
# json.dump(data, open(witness_array_path, 'w'))
def prover_gen_settings(

View File

@@ -36,16 +36,20 @@ class Where(Operation):
class Mean(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, witness_array:Optional[list[torch.Tensor]] = None ) -> 'Mean':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None ) -> 'Mean':
# support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
if witness_array is None:
if precal_witness is None:
# this is prover
print('provvv')
# print('provvv')
return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
else:
# this is verifier
print('verrrr')
return cls(witness_array[0], error)
# print('verrrr')
tensor_arr = []
for ele in precal_witness['Mean']:
tensor_arr.append(torch.tensor(ele))
print("mean tensor arr: ", tensor_arr)
return cls(tensor_arr[0], error)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise: