Merge branch 'feat/support-rest-operations' into tmp/support-rest-operations

This commit is contained in:
JernKunpittaya
2024-02-16 13:23:25 +07:00
committed by GitHub
2 changed files with 57 additions and 3 deletions

View File

@@ -47,39 +47,87 @@ class State:
self.current_op_index = 0
def mean(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the mean of the input tensor. The behavior should conform to
[statistics.mean](https://docs.python.org/3/library/statistics.html#statistics.mean) in Python standard library.
"""
return self._call_op([x], Mean)
def median(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the median of the input tensor. The behavior should conform to
[statistics.median](https://docs.python.org/3/library/statistics.html#statistics.median) in Python standard library.
"""
return self._call_op([x], Median)
def geometric_mean(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the geometric mean of the input tensor. The behavior should conform to
[statistics.geometric_mean](https://docs.python.org/3/library/statistics.html#statistics.geometric_mean) in Python standard library.
"""
return self._call_op([x], GeometricMean)
def harmonic_mean(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the harmonic mean of the input tensor. The behavior should conform to
[statistics.harmonic_mean](https://docs.python.org/3/library/statistics.html#statistics.harmonic_mean) in Python standard library.
"""
return self._call_op([x], HarmonicMean)
def mode(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the mode of the input tensor. The behavior should conform to
[statistics.mode](https://docs.python.org/3/library/statistics.html#statistics.mode) in Python standard library.
"""
return self._call_op([x], Mode)
def pstdev(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the population standard deviation of the input tensor. The behavior should conform to
[statistics.pstdev](https://docs.python.org/3/library/statistics.html#statistics.pstdev) in Python standard library.
"""
return self._call_op([x], PStdev)
def pvariance(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the population variance of the input tensor. The behavior should conform to
[statistics.pvariance](https://docs.python.org/3/library/statistics.html#statistics.pvariance) in Python standard library.
"""
return self._call_op([x], PVariance)
def stdev(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the sample standard deviation of the input tensor. The behavior should conform to
[statistics.stdev](https://docs.python.org/3/library/statistics.html#statistics.stdev) in Python standard library.
"""
return self._call_op([x], Stdev)
def variance(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the sample variance of the input tensor. The behavior should conform to
[statistics.variance](https://docs.python.org/3/library/statistics.html#statistics.variance) in Python standard library.
"""
return self._call_op([x], Variance)
def covariance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Calculate the covariance of x and y. The behavior should conform to
[statistics.covariance](https://docs.python.org/3/library/statistics.html#statistics.covariance) in Python standard library.
"""
return self._call_op([x, y], Covariance)
def correlation(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Calculate the correlation of x and y. The behavior should conform to
[statistics.correlation](https://docs.python.org/3/library/statistics.html#statistics.correlation) in Python standard library.
"""
return self._call_op([x, y], Correlation)
def linear_regression(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Calculate the linear regression of x and y. The behavior should conform to
[statistics.linear_regression](https://docs.python.org/3/library/statistics.html#statistics.linear_regression) in Python standard library.
"""
return self._call_op([x, y], Regression)
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
@@ -138,7 +186,6 @@ class IModel(nn.Module):
...
# An computation function. Example:
# def computation(state: State, x: list[torch.Tensor]):
# out_0 = state.median(x[0])
@@ -150,6 +197,10 @@ TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
"""
Create a torch model from a `computation` function defined by user
:param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
:param error: The error tolerance for the computation.
:return: A tuple of State and Model. The Model is a torch model that can be used for exporting to onnx.
State is a container for intermediate results of computation, which can be useful when debugging.
"""
state = State(error)
@@ -160,5 +211,4 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return computation(state, x)
return state, Model

View File

@@ -10,7 +10,7 @@ IsResultPrecise = torch.Tensor
class Operation(ABC):
def __init__(self, result: torch.Tensor, error: float):
self.result = result
self.result = torch.nn.Parameter(data=result, requires_grad=False)
self.error = error
@abstractclassmethod
@@ -57,6 +57,7 @@ class Median(Operation):
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Median':
return cls(x[0], error)
@@ -241,9 +242,11 @@ class Covariance(Operation):
y_1d = to_1d(y)
x_1d_list = x_1d.tolist()
y_1d_list = y_1d.tolist()
self.x_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(x_1d_list), dtype = torch.float32), requires_grad=False)
self.y_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(y_1d_list), dtype = torch.float32), requires_grad=False)
result = torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32)
super().__init__(result, error)
@classmethod
@@ -288,6 +291,7 @@ class Correlation(Operation):
self.y_std = torch.nn.Parameter(data=torch.sqrt(torch.var(y_1d, correction = 1)), requires_grad=False)
self.cov = torch.nn.Parameter(data=torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32), requires_grad=False)
result = torch.tensor(statistics.correlation(x_1d_list, y_1d_list), dtype = torch.float32)
super().__init__(result, error)
@classmethod