This commit is contained in:
JernKunpittaya
2024-05-04 13:20:25 +07:00
parent 6f7e38405b
commit 82f3d8d99f
4 changed files with 93 additions and 72 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -50,12 +50,7 @@ class State:
def set_ready_for_exporting_onnx(self) -> None:
self.current_op_index = 0
# def set_witness(self,witness_array) -> None:
# self.witness_array = witness_array
# def set_aggregate_witness_path(self,aggregate_witness_path) -> None:
# self.aggregate_witness_path = aggregate_witness_path
# def get_aggregate_witness(self) -> list[torch.Tensor]:
# return self.aggregate_witness
def mean(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the mean of the input tensor. The behavior should conform to
@@ -158,19 +153,16 @@ class State:
if self.isProver:
print('Prover side')
op = op_type.create(x, self.error)
# print('oppy : ', op)
# print('is check pri 1: ', isinstance(op,Mean))
if isinstance(op,Mean):
self.precal_witness['Mean'] = [op.result.data.item()]
elif isinstance(op, Median):
self.precal_witness['Median'] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
# for verifier
else:
print('Verifier side')
# if isinstance(op,Mean):
precal_witness = json.loads(open(self.precal_witness_path, "r").read())
# tensor_arr = []
# for ele in data:
# tensor_arr.append(torch.tensor(ele))
op = op_type.create(x, self.error, precal_witness)
op = op_type.create(x, self.error, precal_witness)
print('finish create')
self.ops.append(op)
return op.result
else:
@@ -193,12 +185,10 @@ class State:
def is_precise() -> IsResultPrecise:
return op.ezkl(x)
self.bools.append(is_precise)
# self.x.append(x)
# If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result)
# else, return only result
# print('all ops:', self.ops)
if current_op_index == len_ops - 1:
print('final op: ', op)
# Sanity check for length of self.ops and self.bools

View File

@@ -71,22 +71,37 @@ def to_1d(x: torch.Tensor) -> torch.Tensor:
class Median(Operation):
def __init__(self, x: torch.Tensor, error: float):
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None ):
if precal_witness is None:
# NOTE: To ensure `lower` and `upper` are a scalar, `x` must be a 1d array.
# Otherwise, if `x` is a 3d array, `lower` and `upper` will be 2d array, which are not what
# we want in our context. However, we tend to have x as a `[1, len(x), 1]`. In this case,
# we need to flatten `x` to 1d array to get the correct `lower` and `upper`.
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
super().__init__(torch.tensor(np.median(x_1d)), error)
sorted_x = np.sort(x_1d)
len_x = len(x_1d)
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
super().__init__(torch.tensor(np.median(x_1d)), error)
sorted_x = np.sort(x_1d)
len_x = len(x_1d)
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
else:
tensor_arr = []
for ele in precal_witness['Median']:
tensor_arr.append(torch.tensor(ele))
super().__init__(tensor_arr[0], error)
self.lower = torch.nn.Parameter(data = tensor_arr[1], requires_grad=False)
self.upper = torch.nn.Parameter(data = tensor_arr[2], requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Median':
return cls(x[0], error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None ) -> 'Median':
if precal_witness is None:
return cls(x[0], error)
else:
tensor_arr = []
for ele in precal_witness['Median']:
tensor_arr.append(torch.tensor(ele))
return cls(tensor_arr[0],error, precal_witness)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]