mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
median
This commit is contained in:
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -50,12 +50,7 @@ class State:
|
||||
|
||||
def set_ready_for_exporting_onnx(self) -> None:
|
||||
self.current_op_index = 0
|
||||
# def set_witness(self,witness_array) -> None:
|
||||
# self.witness_array = witness_array
|
||||
# def set_aggregate_witness_path(self,aggregate_witness_path) -> None:
|
||||
# self.aggregate_witness_path = aggregate_witness_path
|
||||
# def get_aggregate_witness(self) -> list[torch.Tensor]:
|
||||
# return self.aggregate_witness
|
||||
|
||||
def mean(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the mean of the input tensor. The behavior should conform to
|
||||
@@ -158,19 +153,16 @@ class State:
|
||||
if self.isProver:
|
||||
print('Prover side')
|
||||
op = op_type.create(x, self.error)
|
||||
# print('oppy : ', op)
|
||||
# print('is check pri 1: ', isinstance(op,Mean))
|
||||
if isinstance(op,Mean):
|
||||
self.precal_witness['Mean'] = [op.result.data.item()]
|
||||
elif isinstance(op, Median):
|
||||
self.precal_witness['Median'] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
|
||||
# for verifier
|
||||
else:
|
||||
print('Verifier side')
|
||||
# if isinstance(op,Mean):
|
||||
precal_witness = json.loads(open(self.precal_witness_path, "r").read())
|
||||
# tensor_arr = []
|
||||
# for ele in data:
|
||||
# tensor_arr.append(torch.tensor(ele))
|
||||
op = op_type.create(x, self.error, precal_witness)
|
||||
op = op_type.create(x, self.error, precal_witness)
|
||||
print('finish create')
|
||||
self.ops.append(op)
|
||||
return op.result
|
||||
else:
|
||||
@@ -193,12 +185,10 @@ class State:
|
||||
def is_precise() -> IsResultPrecise:
|
||||
return op.ezkl(x)
|
||||
self.bools.append(is_precise)
|
||||
|
||||
# self.x.append(x)
|
||||
|
||||
# If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result)
|
||||
# else, return only result
|
||||
# print('all ops:', self.ops)
|
||||
|
||||
if current_op_index == len_ops - 1:
|
||||
print('final op: ', op)
|
||||
# Sanity check for length of self.ops and self.bools
|
||||
|
||||
@@ -71,22 +71,37 @@ def to_1d(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
class Median(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None ):
|
||||
if precal_witness is None:
|
||||
# NOTE: To ensure `lower` and `upper` are a scalar, `x` must be a 1d array.
|
||||
# Otherwise, if `x` is a 3d array, `lower` and `upper` will be 2d array, which are not what
|
||||
# we want in our context. However, we tend to have x as a `[1, len(x), 1]`. In this case,
|
||||
# we need to flatten `x` to 1d array to get the correct `lower` and `upper`.
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
super().__init__(torch.tensor(np.median(x_1d)), error)
|
||||
sorted_x = np.sort(x_1d)
|
||||
len_x = len(x_1d)
|
||||
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
|
||||
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
super().__init__(torch.tensor(np.median(x_1d)), error)
|
||||
sorted_x = np.sort(x_1d)
|
||||
len_x = len(x_1d)
|
||||
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
|
||||
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
|
||||
else:
|
||||
tensor_arr = []
|
||||
for ele in precal_witness['Median']:
|
||||
tensor_arr.append(torch.tensor(ele))
|
||||
super().__init__(tensor_arr[0], error)
|
||||
self.lower = torch.nn.Parameter(data = tensor_arr[1], requires_grad=False)
|
||||
self.upper = torch.nn.Parameter(data = tensor_arr[2], requires_grad=False)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'Median':
|
||||
return cls(x[0], error)
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None ) -> 'Median':
|
||||
if precal_witness is None:
|
||||
return cls(x[0], error)
|
||||
else:
|
||||
tensor_arr = []
|
||||
for ele in precal_witness['Median']:
|
||||
tensor_arr.append(torch.tensor(ele))
|
||||
return cls(tensor_arr[0],error, precal_witness)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
|
||||
Reference in New Issue
Block a user