None as default, where, support torch with state

This commit is contained in:
JernKunpittaya
2024-05-14 14:39:11 +07:00
parent 939f91fa5b
commit c849e60251
9 changed files with 963 additions and 351 deletions

View File

@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6, 7],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -19,7 +19,6 @@ from .ops import (
Covariance,
Correlation,
Regression,
Where,
IsResultPrecise,
)
@@ -139,15 +138,15 @@ class State:
return self._call_op([x, y], Regression)
# WHERE operation
def where(self, filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the where operation of x. The behavior should conform to `torch.where` in PyTorch.
:param filter: A boolean tensor serves as a filter
:param _filter: A boolean tensor serves as a filter
:param x: A tensor to be filtered
:return: filtered tensor
"""
return self._call_op([filter, x], Where)
return torch.where(_filter, x, x-x+MagicNumber)
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
if self.current_op_index is None:
@@ -210,16 +209,12 @@ class State:
# print('Verifier side create')
precal_witness = json.loads(open(self.precal_witness_path, "r").read())
op = op_type.create(x, self.error, precal_witness, self.op_dict)
# dont need to include Where
if not isinstance(op, Where):
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
if op_class_str not in self.op_dict:
self.op_dict[op_class_str] = 1
else:
self.op_dict[op_class_str]+=1
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
if op_class_str not in self.op_dict:
self.op_dict[op_class_str] = 1
else:
self.op_dict[op_class_str]+=1
self.ops.append(op)
if isinstance(op, Where):
return torch.where(x[0], x[1], MagicNumber)
return op.result
else:
# Copy the current op index to a local variable since self.current_op_index will be incremented.
@@ -255,24 +250,15 @@ class State:
for i in range(len_bools):
res = self.bools[i]()
is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
if isinstance(op, Where):
# print('Only where')
return is_precise_aggregated, torch.where(x[0], x[1], x[1]-x[1]+MagicNumber)
else:
if self.isProver:
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
if self.isProver:
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
elif current_op_index > len_ops - 1:
# Sanity check that current op index does not exceed the length of ops
raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
else:
# for where
if isinstance(op, Where):
# print('many ops incl where')
return torch.where(x[0], x[1], x[1]-x[1]+MagicNumber)
else:
return op.result+(x[0]-x[0])[0][0][0]
return op.result+(x[0]-x[0])[0][0][0]
class IModel(nn.Module):
@@ -314,7 +300,11 @@ def computation_to_model(computation: TComputation, precal_witness_path:str, isP
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
# print('x sy: ')
return computation(state, x)
result = computation(state, x)
if len(result) ==1:
return x[0][0][0][0]-x[0][0][0][0]+torch.tensor(1.0), result
else:
return result
# print('state:: ', state.aggregate_witness_path)
return state, Model

View File

@@ -24,20 +24,10 @@ class Operation(ABC):
...
class Where(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Where':
# here error is trivial, but here to conform to other functions
# just dummy result, since not using it anyway because we dont want to expose direct result from where
return cls(torch.tensor(1),error)
def ezkl(self, x:list[torch.Tensor]) -> IsResultPrecise:
return torch.tensor(True)
class Mean(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {} ) -> 'Mean':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ) -> 'Mean':
# support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
if precal_witness is None:
# this is prover
@@ -46,7 +36,9 @@ class Mean(Operation):
else:
# this is verifier
# print('verrrr')
if 'Mean' not in op_dict:
if op_dict is None:
return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
elif 'Mean' not in op_dict:
return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
else:
return cls(torch.tensor(precal_witness['Mean_'+str(op_dict['Mean'])][0]), error)
@@ -71,7 +63,7 @@ def to_1d(x: torch.Tensor) -> torch.Tensor:
class Median(Operation):
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict= {} ):
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ):
if precal_witness is None:
# NOTE: To ensure `lower` and `upper` are a scalar, `x` must be a 1d array.
# Otherwise, if `x` is a 3d array, `lower` and `upper` will be 2d array, which are not what
@@ -85,7 +77,11 @@ class Median(Operation):
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
else:
if 'Median' not in op_dict:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
elif 'Median' not in op_dict:
super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
@@ -96,7 +92,7 @@ class Median(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict= {} ) -> 'Median':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ) -> 'Median':
return cls(x[0],error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
@@ -130,14 +126,16 @@ class Median(Operation):
class GeometricMean(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'GeometricMean':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'GeometricMean':
if precal_witness is None:
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
result = torch.exp(torch.mean(torch.log(x_1d)))
return cls(result, error)
else:
if 'GeometricMean' not in op_dict:
if op_dict is None:
return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
elif 'GeometricMean' not in op_dict:
return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
else:
return cls(torch.tensor(precal_witness['GeometricMean_'+str(op_dict['GeometricMean'])][0]), error)
@@ -152,14 +150,16 @@ class GeometricMean(Operation):
class HarmonicMean(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict = {}) -> 'HarmonicMean':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'HarmonicMean':
if precal_witness is None:
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
result = torch.div(1.0,torch.mean(torch.div(1.0, x_1d)))
return cls(result, error)
else:
if 'HarmonicMean' not in op_dict:
if op_dict is None:
return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
elif 'HarmonicMean' not in op_dict:
return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
else:
return cls(torch.tensor(precal_witness['HarmonicMean_'+str(op_dict['HarmonicMean'])][0]), error)
@@ -220,7 +220,7 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
class Mode(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Mode':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Mode':
if precal_witness is None:
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
@@ -228,7 +228,9 @@ class Mode(Operation):
result = torch.tensor(mode_within(x_1d, 0))
return cls(result, error)
else:
if 'Mode' not in op_dict:
if op_dict is None:
return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
elif 'Mode' not in op_dict:
return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
else:
return cls(torch.tensor(precal_witness['Mode_'+str(op_dict['Mode'])][0]), error)
@@ -251,7 +253,7 @@ class Mode(Operation):
class PStdev(Operation):
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
@@ -259,7 +261,10 @@ class PStdev(Operation):
result = torch.sqrt(torch.var(x_1d, correction = 0))
super().__init__(result, error)
else:
if 'PStdev' not in op_dict:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
elif 'PStdev' not in op_dict:
super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
else:
@@ -268,7 +273,7 @@ class PStdev(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'PStdev':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'PStdev':
return cls(x[0], error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
@@ -283,7 +288,7 @@ class PStdev(Operation):
class PVariance(Operation):
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
@@ -291,7 +296,10 @@ class PVariance(Operation):
result = torch.var(x_1d, correction = 0)
super().__init__(result, error)
else:
if 'PVariance' not in op_dict:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
elif 'PVariance' not in op_dict:
super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
else:
@@ -299,7 +307,7 @@ class PVariance(Operation):
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_'+str(op_dict['PVariance'])][1]), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'PVariance':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'PVariance':
return cls(x[0], error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
@@ -315,7 +323,7 @@ class PVariance(Operation):
class Stdev(Operation):
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
@@ -323,7 +331,10 @@ class Stdev(Operation):
result = torch.sqrt(torch.var(x_1d, correction = 1))
super().__init__(result, error)
else:
if 'Stdev' not in op_dict:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
elif 'Stdev' not in op_dict:
super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
else:
@@ -332,7 +343,7 @@ class Stdev(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Stdev':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Stdev':
return cls(x[0], error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
@@ -347,7 +358,7 @@ class Stdev(Operation):
class Variance(Operation):
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
@@ -355,7 +366,10 @@ class Variance(Operation):
result = torch.var(x_1d, correction = 1)
super().__init__(result, error)
else:
if 'Variance' not in op_dict:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
elif 'Variance' not in op_dict:
super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
else:
@@ -364,7 +378,7 @@ class Variance(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Variance':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Variance':
return cls(x[0], error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
@@ -381,7 +395,7 @@ class Variance(Operation):
class Covariance(Operation):
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
@@ -396,7 +410,11 @@ class Covariance(Operation):
super().__init__(result, error)
else:
if 'Covariance' not in op_dict:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
elif 'Covariance' not in op_dict:
super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
@@ -406,7 +424,7 @@ class Covariance(Operation):
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][2]), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Covariance':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Covariance':
return cls(x[0], x[1], error, precal_witness, op_dict)
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
@@ -438,7 +456,7 @@ def covariance_for_corr(x_adj_mean: torch.Tensor,y_adj_mean: torch.Tensor,size_x
class Correlation(Operation):
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness, op_dict:dict = {}):
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
@@ -455,7 +473,14 @@ class Correlation(Operation):
super().__init__(result, error)
else:
if 'Correlation' not in op_dict:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][3]), requires_grad=False)
self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][4]), requires_grad=False)
self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][5]), requires_grad=False)
elif 'Correlation' not in op_dict:
super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
@@ -472,7 +497,7 @@ class Correlation(Operation):
@classmethod
def create(cls, args: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Correlation':
def create(cls, args: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Correlation':
return cls(args[0], args[1], error, precal_witness, op_dict)
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
@@ -500,7 +525,7 @@ def stacked_x(args: list[float]):
class Regression(Operation):
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float, precal_witness:dict=None, op_dict:dict = {}):
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1ds = [to_1d(i) for i in xs]
fil_x_1ds=[]
@@ -517,7 +542,9 @@ class Regression(Operation):
# print('result: ', result)
super().__init__(result, error)
else:
if 'Regression' not in op_dict:
if op_dict is None:
result = torch.tensor(precal_witness['Regression_0']).reshape(1,-1,1)
elif 'Regression' not in op_dict:
result = torch.tensor(precal_witness['Regression_0']).reshape(1,-1,1)
else:
result = torch.tensor(precal_witness['Regression_'+str(op_dict['Regression'])]).reshape(1,-1,1)
@@ -529,7 +556,7 @@ class Regression(Operation):
@classmethod
def create(cls, args: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Regression':
def create(cls, args: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Regression':
xs = args[:-1]
y = args[-1]
return cls(xs, y, error, precal_witness, op_dict)