mirror of https://github.com/MPCStats/zk-stats-lib.git
test_ops
@@ -10,35 +10,35 @@ from zkstats.computation import IModel, IsResultPrecise, State, computation_to_m
 from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED


-# @pytest.mark.parametrize(
-#     "op_type, expected_func, error",
-#     [
-#         (Mean, statistics.mean, ERROR_CIRCUIT_DEFAULT),
-#         (Median, statistics.median, ERROR_CIRCUIT_DEFAULT),
-#         (GeometricMean, statistics.geometric_mean, ERROR_CIRCUIT_DEFAULT),
-#         # Be more tolerant for HarmonicMean
-#         (HarmonicMean, statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
-#         # Be less tolerant for Mode
-#         (Mode, statistics.mode, ERROR_CIRCUIT_STRICT),
-#         (PStdev, statistics.pstdev, ERROR_CIRCUIT_DEFAULT),
-#         (PVariance, statistics.pvariance, ERROR_CIRCUIT_DEFAULT),
-#         (Stdev, statistics.stdev, ERROR_CIRCUIT_DEFAULT),
-#         (Variance, statistics.variance, ERROR_CIRCUIT_DEFAULT),
-#     ]
-# )
-# def test_ops_1_parameter(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
-#     run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0])
+@pytest.mark.parametrize(
+    "op_type, expected_func, error",
+    [
+        (Mean, statistics.mean, ERROR_CIRCUIT_DEFAULT),
+        (Median, statistics.median, ERROR_CIRCUIT_DEFAULT),
+        (GeometricMean, statistics.geometric_mean, ERROR_CIRCUIT_DEFAULT),
+        # Be more tolerant for HarmonicMean
+        (HarmonicMean, statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
+        # Be less tolerant for Mode
+        (Mode, statistics.mode, ERROR_CIRCUIT_STRICT),
+        (PStdev, statistics.pstdev, ERROR_CIRCUIT_DEFAULT),
+        (PVariance, statistics.pvariance, ERROR_CIRCUIT_DEFAULT),
+        (Stdev, statistics.stdev, ERROR_CIRCUIT_DEFAULT),
+        (Variance, statistics.variance, ERROR_CIRCUIT_DEFAULT),
+    ]
+)
+def test_ops_1_parameter(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
+    run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0])


-# @pytest.mark.parametrize(
-#     "op_type, expected_func, error",
-#     [
-#         (Covariance, statistics.covariance, ERROR_CIRCUIT_RELAXED),
-#         (Correlation, statistics.correlation, ERROR_CIRCUIT_RELAXED),
-#     ]
-# )
-# def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
-#     run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0, column_1])
+@pytest.mark.parametrize(
+    "op_type, expected_func, error",
+    [
+        (Covariance, statistics.covariance, ERROR_CIRCUIT_RELAXED),
+        (Correlation, statistics.correlation, ERROR_CIRCUIT_RELAXED),
+    ]
+)
+def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
+    run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0, column_1])


 @pytest.mark.parametrize(
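The re-enabled tests compare each circuit operation against its counterpart in Python's statistics module, with a per-operation tolerance (relaxed for HarmonicMean, strict for Mode). Below is a minimal sketch of that tolerance check, assuming a relative-error comparison; the constants, values, and helper names are illustrative stand-ins for tests/helpers.py and run_test_ops, not the repository's code.

# Illustrative sketch only. run_test_ops, assert_result, and the ERROR_CIRCUIT_*
# constants live in tests/helpers.py and are not shown in this diff; the tolerance
# values and helper below are assumptions that mirror the comments above
# ("Be more tolerant for HarmonicMean", "Be less tolerant for Mode").
import statistics

import pytest

ERROR_CIRCUIT_STRICT = 0.001   # assumed: tightest tolerance (Mode)
ERROR_CIRCUIT_DEFAULT = 0.01   # assumed: default tolerance
ERROR_CIRCUIT_RELAXED = 0.1    # assumed: loosest tolerance (HarmonicMean)


def assert_within_error(expected: float, actual: float, error: float) -> None:
    # Relative-error check in the spirit of helpers.assert_result.
    assert abs(actual - expected) <= abs(expected) * error


@pytest.mark.parametrize(
    "expected_func, error",
    [
        (statistics.mean, ERROR_CIRCUIT_DEFAULT),
        (statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
        (statistics.mode, ERROR_CIRCUIT_STRICT),
    ],
)
def test_reference_tolerance_sketch(expected_func, error):
    column = [1.0, 2.0, 2.0, 4.0, 8.0]
    expected = expected_func(column)
    # Stand-in for the value recovered from the proved circuit: nudge it by half
    # the allowed tolerance to simulate fixed-point/quantization drift.
    circuit_result = expected * (1 + error / 2)
    assert_within_error(expected, circuit_result, error)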
@@ -55,8 +55,6 @@ def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Ten
     actual_res = regression.result
     assert_result(expected_res.slope, actual_res[0][0][0])
     assert_result(expected_res.intercept, actual_res[0][1][0])
-    print("slope: ", actual_res[0][0][0])
-    print('intercept: ',actual_res[0][1][0] )
     class Model(IModel):
         def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
             return regression.ezkl(x), regression.result
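The second hunk drops the two debug prints from test_linear_regression. The surviving context shows the pattern used throughout these tests: a Model class nested in the test whose forward returns an "is the result precise" flag together with the result tensor, so the generated circuit both constrains the computation and exposes its output. Below is a minimal sketch of that shape; IsResultPrecise, the Protocol, and make_model are assumed stand-ins for the real zkstats.computation types and the regression operation object, which are not defined in this diff.

# Sketch only: stand-ins for zkstats.computation.IModel / IsResultPrecise and for
# the regression operation object; the real types and constructors are not shown
# in this diff.
from typing import Protocol

import torch

IsResultPrecise = torch.Tensor  # assumed: a boolean-like tensor flag


class RegressionLike(Protocol):
    # Assumed interface: the diff only shows that `regression` exposes .result
    # and .ezkl(x).
    result: torch.Tensor
    def ezkl(self, x: tuple[torch.Tensor, ...]) -> IsResultPrecise: ...


def make_model(regression: RegressionLike) -> type[torch.nn.Module]:
    # Mirrors the nested Model(IModel) in test_linear_regression: forward returns
    # (is_result_precise, result) so the circuit checks the computation and also
    # reveals its output.
    class Model(torch.nn.Module):  # stand-in base class; the test subclasses IModel
        def forward(self, *x: torch.Tensor) -> tuple[IsResultPrecise, torch.Tensor]:
            return regression.ezkl(x), regression.result

    return Model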