This commit is contained in:
JernKunpittaya
2024-05-10 17:49:00 +07:00
parent cabf6f8c47
commit 58fbf30d19

View File

@@ -10,35 +10,35 @@ from zkstats.computation import IModel, IsResultPrecise, State, computation_to_m
from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
# @pytest.mark.parametrize(
# "op_type, expected_func, error",
# [
# (Mean, statistics.mean, ERROR_CIRCUIT_DEFAULT),
# (Median, statistics.median, ERROR_CIRCUIT_DEFAULT),
# (GeometricMean, statistics.geometric_mean, ERROR_CIRCUIT_DEFAULT),
# # Be more tolerant for HarmonicMean
# (HarmonicMean, statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
# # Be less tolerant for Mode
# (Mode, statistics.mode, ERROR_CIRCUIT_STRICT),
# (PStdev, statistics.pstdev, ERROR_CIRCUIT_DEFAULT),
# (PVariance, statistics.pvariance, ERROR_CIRCUIT_DEFAULT),
# (Stdev, statistics.stdev, ERROR_CIRCUIT_DEFAULT),
# (Variance, statistics.variance, ERROR_CIRCUIT_DEFAULT),
# ]
# )
# def test_ops_1_parameter(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
# run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0])
@pytest.mark.parametrize(
"op_type, expected_func, error",
[
(Mean, statistics.mean, ERROR_CIRCUIT_DEFAULT),
(Median, statistics.median, ERROR_CIRCUIT_DEFAULT),
(GeometricMean, statistics.geometric_mean, ERROR_CIRCUIT_DEFAULT),
# Be more tolerant for HarmonicMean
(HarmonicMean, statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
# Be less tolerant for Mode
(Mode, statistics.mode, ERROR_CIRCUIT_STRICT),
(PStdev, statistics.pstdev, ERROR_CIRCUIT_DEFAULT),
(PVariance, statistics.pvariance, ERROR_CIRCUIT_DEFAULT),
(Stdev, statistics.stdev, ERROR_CIRCUIT_DEFAULT),
(Variance, statistics.variance, ERROR_CIRCUIT_DEFAULT),
]
)
def test_ops_1_parameter(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0])
# @pytest.mark.parametrize(
# "op_type, expected_func, error",
# [
# (Covariance, statistics.covariance, ERROR_CIRCUIT_RELAXED),
# (Correlation, statistics.correlation, ERROR_CIRCUIT_RELAXED),
# ]
# )
# def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
# run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0, column_1])
@pytest.mark.parametrize(
"op_type, expected_func, error",
[
(Covariance, statistics.covariance, ERROR_CIRCUIT_RELAXED),
(Correlation, statistics.correlation, ERROR_CIRCUIT_RELAXED),
]
)
def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0, column_1])
@pytest.mark.parametrize(
@@ -55,8 +55,6 @@ def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Ten
actual_res = regression.result
assert_result(expected_res.slope, actual_res[0][0][0])
assert_result(expected_res.intercept, actual_res[0][1][0])
print("slope: ", actual_res[0][0][0])
print('intercept: ',actual_res[0][1][0] )
class Model(IModel):
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return regression.ezkl(x), regression.result