mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
None as default, where, support torch with state
This commit is contained in:
4
examples/1.only_torch/data.json
Normal file
4
examples/1.only_torch/data.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"x": [0.5, 1, 2, 3, 4, 5, 6, 7],
|
||||
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
|
||||
}
|
||||
281
examples/1.only_torch/only_torch.ipynb
Normal file
281
examples/1.only_torch/only_torch.ipynb
Normal file
File diff suppressed because one or more lines are too long
294
examples/2.torch+state/torch+state.ipynb
Normal file
294
examples/2.torch+state/torch+state.ipynb
Normal file
File diff suppressed because one or more lines are too long
4
examples/3.state/data.json
Normal file
4
examples/3.state/data.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"x": [0.5, 1, 2, 3, 4, 5, 6],
|
||||
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
|
||||
}
|
||||
294
examples/3.state/state.ipynb
Normal file
294
examples/3.state/state.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -19,7 +19,6 @@ from .ops import (
|
||||
Covariance,
|
||||
Correlation,
|
||||
Regression,
|
||||
Where,
|
||||
IsResultPrecise,
|
||||
)
|
||||
|
||||
@@ -139,15 +138,15 @@ class State:
|
||||
return self._call_op([x, y], Regression)
|
||||
|
||||
# WHERE operation
|
||||
def where(self, filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the where operation of x. The behavior should conform to `torch.where` in PyTorch.
|
||||
|
||||
:param filter: A boolean tensor serves as a filter
|
||||
:param _filter: A boolean tensor serves as a filter
|
||||
:param x: A tensor to be filtered
|
||||
:return: filtered tensor
|
||||
"""
|
||||
return self._call_op([filter, x], Where)
|
||||
return torch.where(_filter, x, x-x+MagicNumber)
|
||||
|
||||
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
|
||||
if self.current_op_index is None:
|
||||
@@ -210,16 +209,12 @@ class State:
|
||||
# print('Verifier side create')
|
||||
precal_witness = json.loads(open(self.precal_witness_path, "r").read())
|
||||
op = op_type.create(x, self.error, precal_witness, self.op_dict)
|
||||
# dont need to include Where
|
||||
if not isinstance(op, Where):
|
||||
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
|
||||
if op_class_str not in self.op_dict:
|
||||
self.op_dict[op_class_str] = 1
|
||||
else:
|
||||
self.op_dict[op_class_str]+=1
|
||||
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
|
||||
if op_class_str not in self.op_dict:
|
||||
self.op_dict[op_class_str] = 1
|
||||
else:
|
||||
self.op_dict[op_class_str]+=1
|
||||
self.ops.append(op)
|
||||
if isinstance(op, Where):
|
||||
return torch.where(x[0], x[1], MagicNumber)
|
||||
return op.result
|
||||
else:
|
||||
# Copy the current op index to a local variable since self.current_op_index will be incremented.
|
||||
@@ -255,24 +250,15 @@ class State:
|
||||
for i in range(len_bools):
|
||||
res = self.bools[i]()
|
||||
is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
|
||||
if isinstance(op, Where):
|
||||
# print('Only where')
|
||||
return is_precise_aggregated, torch.where(x[0], x[1], x[1]-x[1]+MagicNumber)
|
||||
else:
|
||||
if self.isProver:
|
||||
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
|
||||
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
|
||||
if self.isProver:
|
||||
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
|
||||
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
|
||||
|
||||
elif current_op_index > len_ops - 1:
|
||||
# Sanity check that current op index does not exceed the length of ops
|
||||
raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
|
||||
else:
|
||||
# for where
|
||||
if isinstance(op, Where):
|
||||
# print('many ops incl where')
|
||||
return torch.where(x[0], x[1], x[1]-x[1]+MagicNumber)
|
||||
else:
|
||||
return op.result+(x[0]-x[0])[0][0][0]
|
||||
return op.result+(x[0]-x[0])[0][0][0]
|
||||
|
||||
|
||||
class IModel(nn.Module):
|
||||
@@ -314,7 +300,11 @@ def computation_to_model(computation: TComputation, precal_witness_path:str, isP
|
||||
|
||||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
# print('x sy: ')
|
||||
return computation(state, x)
|
||||
result = computation(state, x)
|
||||
if len(result) ==1:
|
||||
return x[0][0][0][0]-x[0][0][0][0]+torch.tensor(1.0), result
|
||||
else:
|
||||
return result
|
||||
# print('state:: ', state.aggregate_witness_path)
|
||||
return state, Model
|
||||
|
||||
|
||||
111
zkstats/ops.py
111
zkstats/ops.py
@@ -24,20 +24,10 @@ class Operation(ABC):
|
||||
...
|
||||
|
||||
|
||||
class Where(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Where':
|
||||
# here error is trivial, but here to conform to other functions
|
||||
# just dummy result, since not using it anyway because we dont want to expose direct result from where
|
||||
return cls(torch.tensor(1),error)
|
||||
def ezkl(self, x:list[torch.Tensor]) -> IsResultPrecise:
|
||||
return torch.tensor(True)
|
||||
|
||||
|
||||
|
||||
class Mean(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {} ) -> 'Mean':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ) -> 'Mean':
|
||||
# support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
|
||||
if precal_witness is None:
|
||||
# this is prover
|
||||
@@ -46,7 +36,9 @@ class Mean(Operation):
|
||||
else:
|
||||
# this is verifier
|
||||
# print('verrrr')
|
||||
if 'Mean' not in op_dict:
|
||||
if op_dict is None:
|
||||
return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
|
||||
elif 'Mean' not in op_dict:
|
||||
return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
|
||||
else:
|
||||
return cls(torch.tensor(precal_witness['Mean_'+str(op_dict['Mean'])][0]), error)
|
||||
@@ -71,7 +63,7 @@ def to_1d(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
class Median(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict= {} ):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ):
|
||||
if precal_witness is None:
|
||||
# NOTE: To ensure `lower` and `upper` are a scalar, `x` must be a 1d array.
|
||||
# Otherwise, if `x` is a 3d array, `lower` and `upper` will be 2d array, which are not what
|
||||
@@ -85,7 +77,11 @@ class Median(Operation):
|
||||
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
|
||||
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
|
||||
else:
|
||||
if 'Median' not in op_dict:
|
||||
if op_dict is None:
|
||||
super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
|
||||
self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
|
||||
self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
|
||||
elif 'Median' not in op_dict:
|
||||
super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
|
||||
self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
|
||||
self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
|
||||
@@ -96,7 +92,7 @@ class Median(Operation):
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict= {} ) -> 'Median':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ) -> 'Median':
|
||||
return cls(x[0],error, precal_witness, op_dict)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
@@ -130,14 +126,16 @@ class Median(Operation):
|
||||
|
||||
class GeometricMean(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'GeometricMean':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'GeometricMean':
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x[0])
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
result = torch.exp(torch.mean(torch.log(x_1d)))
|
||||
return cls(result, error)
|
||||
else:
|
||||
if 'GeometricMean' not in op_dict:
|
||||
if op_dict is None:
|
||||
return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
|
||||
elif 'GeometricMean' not in op_dict:
|
||||
return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
|
||||
else:
|
||||
return cls(torch.tensor(precal_witness['GeometricMean_'+str(op_dict['GeometricMean'])][0]), error)
|
||||
@@ -152,14 +150,16 @@ class GeometricMean(Operation):
|
||||
|
||||
class HarmonicMean(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict = {}) -> 'HarmonicMean':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'HarmonicMean':
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x[0])
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
result = torch.div(1.0,torch.mean(torch.div(1.0, x_1d)))
|
||||
return cls(result, error)
|
||||
else:
|
||||
if 'HarmonicMean' not in op_dict:
|
||||
if op_dict is None:
|
||||
return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
|
||||
elif 'HarmonicMean' not in op_dict:
|
||||
return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
|
||||
else:
|
||||
return cls(torch.tensor(precal_witness['HarmonicMean_'+str(op_dict['HarmonicMean'])][0]), error)
|
||||
@@ -220,7 +220,7 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
|
||||
|
||||
class Mode(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Mode':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Mode':
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x[0])
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
@@ -228,7 +228,9 @@ class Mode(Operation):
|
||||
result = torch.tensor(mode_within(x_1d, 0))
|
||||
return cls(result, error)
|
||||
else:
|
||||
if 'Mode' not in op_dict:
|
||||
if op_dict is None:
|
||||
return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
|
||||
elif 'Mode' not in op_dict:
|
||||
return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
|
||||
else:
|
||||
return cls(torch.tensor(precal_witness['Mode_'+str(op_dict['Mode'])][0]), error)
|
||||
@@ -251,7 +253,7 @@ class Mode(Operation):
|
||||
|
||||
|
||||
class PStdev(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
@@ -259,7 +261,10 @@ class PStdev(Operation):
|
||||
result = torch.sqrt(torch.var(x_1d, correction = 0))
|
||||
super().__init__(result, error)
|
||||
else:
|
||||
if 'PStdev' not in op_dict:
|
||||
if op_dict is None:
|
||||
super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
|
||||
elif 'PStdev' not in op_dict:
|
||||
super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
|
||||
else:
|
||||
@@ -268,7 +273,7 @@ class PStdev(Operation):
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'PStdev':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'PStdev':
|
||||
return cls(x[0], error, precal_witness, op_dict)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
@@ -283,7 +288,7 @@ class PStdev(Operation):
|
||||
|
||||
|
||||
class PVariance(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
@@ -291,7 +296,10 @@ class PVariance(Operation):
|
||||
result = torch.var(x_1d, correction = 0)
|
||||
super().__init__(result, error)
|
||||
else:
|
||||
if 'PVariance' not in op_dict:
|
||||
if op_dict is None:
|
||||
super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
|
||||
elif 'PVariance' not in op_dict:
|
||||
super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
|
||||
else:
|
||||
@@ -299,7 +307,7 @@ class PVariance(Operation):
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_'+str(op_dict['PVariance'])][1]), requires_grad=False)
|
||||
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'PVariance':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'PVariance':
|
||||
return cls(x[0], error, precal_witness, op_dict)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
@@ -315,7 +323,7 @@ class PVariance(Operation):
|
||||
|
||||
|
||||
class Stdev(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
@@ -323,7 +331,10 @@ class Stdev(Operation):
|
||||
result = torch.sqrt(torch.var(x_1d, correction = 1))
|
||||
super().__init__(result, error)
|
||||
else:
|
||||
if 'Stdev' not in op_dict:
|
||||
if op_dict is None:
|
||||
super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
|
||||
elif 'Stdev' not in op_dict:
|
||||
super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
|
||||
else:
|
||||
@@ -332,7 +343,7 @@ class Stdev(Operation):
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Stdev':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Stdev':
|
||||
return cls(x[0], error, precal_witness, op_dict)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
@@ -347,7 +358,7 @@ class Stdev(Operation):
|
||||
|
||||
|
||||
class Variance(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
|
||||
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
@@ -355,7 +366,10 @@ class Variance(Operation):
|
||||
result = torch.var(x_1d, correction = 1)
|
||||
super().__init__(result, error)
|
||||
else:
|
||||
if 'Variance' not in op_dict:
|
||||
if op_dict is None:
|
||||
super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
|
||||
elif 'Variance' not in op_dict:
|
||||
super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
|
||||
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
|
||||
else:
|
||||
@@ -364,7 +378,7 @@ class Variance(Operation):
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Variance':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Variance':
|
||||
return cls(x[0], error, precal_witness, op_dict)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
@@ -381,7 +395,7 @@ class Variance(Operation):
|
||||
|
||||
|
||||
class Covariance(Operation):
|
||||
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
|
||||
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
@@ -396,7 +410,11 @@ class Covariance(Operation):
|
||||
|
||||
super().__init__(result, error)
|
||||
else:
|
||||
if 'Covariance' not in op_dict:
|
||||
if op_dict is None:
|
||||
super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
|
||||
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
|
||||
elif 'Covariance' not in op_dict:
|
||||
super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
|
||||
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
|
||||
@@ -406,7 +424,7 @@ class Covariance(Operation):
|
||||
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][2]), requires_grad=False)
|
||||
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Covariance':
|
||||
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Covariance':
|
||||
return cls(x[0], x[1], error, precal_witness, op_dict)
|
||||
|
||||
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
|
||||
@@ -438,7 +456,7 @@ def covariance_for_corr(x_adj_mean: torch.Tensor,y_adj_mean: torch.Tensor,size_x
|
||||
|
||||
|
||||
class Correlation(Operation):
|
||||
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness, op_dict:dict = {}):
|
||||
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
@@ -455,7 +473,14 @@ class Correlation(Operation):
|
||||
|
||||
super().__init__(result, error)
|
||||
else:
|
||||
if 'Correlation' not in op_dict:
|
||||
if op_dict is None:
|
||||
super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
|
||||
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
|
||||
self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][3]), requires_grad=False)
|
||||
self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][4]), requires_grad=False)
|
||||
self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][5]), requires_grad=False)
|
||||
elif 'Correlation' not in op_dict:
|
||||
super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
|
||||
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
|
||||
@@ -472,7 +497,7 @@ class Correlation(Operation):
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, args: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Correlation':
|
||||
def create(cls, args: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Correlation':
|
||||
return cls(args[0], args[1], error, precal_witness, op_dict)
|
||||
|
||||
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
|
||||
@@ -500,7 +525,7 @@ def stacked_x(args: list[float]):
|
||||
|
||||
|
||||
class Regression(Operation):
|
||||
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float, precal_witness:dict=None, op_dict:dict = {}):
|
||||
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
|
||||
if precal_witness is None:
|
||||
x_1ds = [to_1d(i) for i in xs]
|
||||
fil_x_1ds=[]
|
||||
@@ -517,7 +542,9 @@ class Regression(Operation):
|
||||
# print('result: ', result)
|
||||
super().__init__(result, error)
|
||||
else:
|
||||
if 'Regression' not in op_dict:
|
||||
if op_dict is None:
|
||||
result = torch.tensor(precal_witness['Regression_0']).reshape(1,-1,1)
|
||||
elif 'Regression' not in op_dict:
|
||||
result = torch.tensor(precal_witness['Regression_0']).reshape(1,-1,1)
|
||||
else:
|
||||
result = torch.tensor(precal_witness['Regression_'+str(op_dict['Regression'])]).reshape(1,-1,1)
|
||||
@@ -529,7 +556,7 @@ class Regression(Operation):
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, args: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Regression':
|
||||
def create(cls, args: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Regression':
|
||||
xs = args[:-1]
|
||||
y = args[-1]
|
||||
return cls(xs, y, error, precal_witness, op_dict)
|
||||
|
||||
Reference in New Issue
Block a user