mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-10 05:57:55 -05:00
Merge branch 'feat/support-rest-operations' into tmp/support-rest-operations
This commit is contained in:
@@ -47,39 +47,87 @@ class State:
|
||||
self.current_op_index = 0
|
||||
|
||||
def mean(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the mean of the input tensor. The behavior should conform to
|
||||
[statistics.mean](https://docs.python.org/3/library/statistics.html#statistics.mean) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], Mean)
|
||||
|
||||
def median(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the median of the input tensor. The behavior should conform to
|
||||
[statistics.median](https://docs.python.org/3/library/statistics.html#statistics.median) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], Median)
|
||||
|
||||
def geometric_mean(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the geometric mean of the input tensor. The behavior should conform to
|
||||
[statistics.geometric_mean](https://docs.python.org/3/library/statistics.html#statistics.geometric_mean) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], GeometricMean)
|
||||
|
||||
def harmonic_mean(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the harmonic mean of the input tensor. The behavior should conform to
|
||||
[statistics.harmonic_mean](https://docs.python.org/3/library/statistics.html#statistics.harmonic_mean) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], HarmonicMean)
|
||||
|
||||
def mode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the mode of the input tensor. The behavior should conform to
|
||||
[statistics.mode](https://docs.python.org/3/library/statistics.html#statistics.mode) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], Mode)
|
||||
|
||||
def pstdev(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the population standard deviation of the input tensor. The behavior should conform to
|
||||
[statistics.pstdev](https://docs.python.org/3/library/statistics.html#statistics.pstdev) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], PStdev)
|
||||
|
||||
def pvariance(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the population variance of the input tensor. The behavior should conform to
|
||||
[statistics.pvariance](https://docs.python.org/3/library/statistics.html#statistics.pvariance) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], PVariance)
|
||||
|
||||
def stdev(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the sample standard deviation of the input tensor. The behavior should conform to
|
||||
[statistics.stdev](https://docs.python.org/3/library/statistics.html#statistics.stdev) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], Stdev)
|
||||
|
||||
def variance(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the sample variance of the input tensor. The behavior should conform to
|
||||
[statistics.variance](https://docs.python.org/3/library/statistics.html#statistics.variance) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x], Variance)
|
||||
|
||||
def covariance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the covariance of x and y. The behavior should conform to
|
||||
[statistics.covariance](https://docs.python.org/3/library/statistics.html#statistics.covariance) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x, y], Covariance)
|
||||
|
||||
def correlation(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the correlation of x and y. The behavior should conform to
|
||||
[statistics.correlation](https://docs.python.org/3/library/statistics.html#statistics.correlation) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x, y], Correlation)
|
||||
|
||||
def linear_regression(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the linear regression of x and y. The behavior should conform to
|
||||
[statistics.linear_regression](https://docs.python.org/3/library/statistics.html#statistics.linear_regression) in Python standard library.
|
||||
"""
|
||||
return self._call_op([x, y], Regression)
|
||||
|
||||
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
|
||||
@@ -138,7 +186,6 @@ class IModel(nn.Module):
|
||||
...
|
||||
|
||||
|
||||
|
||||
# An computation function. Example:
|
||||
# def computation(state: State, x: list[torch.Tensor]):
|
||||
# out_0 = state.median(x[0])
|
||||
@@ -150,6 +197,10 @@ TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
|
||||
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
|
||||
"""
|
||||
Create a torch model from a `computation` function defined by user
|
||||
:param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
|
||||
:param error: The error tolerance for the computation.
|
||||
:return: A tuple of State and Model. The Model is a torch model that can be used for exporting to onnx.
|
||||
State is a container for intermediate results of computation, which can be useful when debugging.
|
||||
"""
|
||||
state = State(error)
|
||||
|
||||
@@ -160,5 +211,4 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
|
||||
|
||||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
return computation(state, x)
|
||||
|
||||
return state, Model
|
||||
|
||||
@@ -10,7 +10,7 @@ IsResultPrecise = torch.Tensor
|
||||
|
||||
class Operation(ABC):
|
||||
def __init__(self, result: torch.Tensor, error: float):
|
||||
self.result = result
|
||||
self.result = torch.nn.Parameter(data=result, requires_grad=False)
|
||||
self.error = error
|
||||
|
||||
@abstractclassmethod
|
||||
@@ -57,6 +57,7 @@ class Median(Operation):
|
||||
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
|
||||
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'Median':
|
||||
return cls(x[0], error)
|
||||
@@ -241,9 +242,11 @@ class Covariance(Operation):
|
||||
y_1d = to_1d(y)
|
||||
x_1d_list = x_1d.tolist()
|
||||
y_1d_list = y_1d.tolist()
|
||||
|
||||
self.x_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(x_1d_list), dtype = torch.float32), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(y_1d_list), dtype = torch.float32), requires_grad=False)
|
||||
result = torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32)
|
||||
|
||||
super().__init__(result, error)
|
||||
|
||||
@classmethod
|
||||
@@ -288,6 +291,7 @@ class Correlation(Operation):
|
||||
self.y_std = torch.nn.Parameter(data=torch.sqrt(torch.var(y_1d, correction = 1)), requires_grad=False)
|
||||
self.cov = torch.nn.Parameter(data=torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32), requires_grad=False)
|
||||
result = torch.tensor(statistics.correlation(x_1d_list, y_1d_list), dtype = torch.float32)
|
||||
|
||||
super().__init__(result, error)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user