mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
revamp zkstats logic/ flow
File diff suppressed because one or more lines are too long
@@ -1,8 +1,4 @@
 {
-  "x": [
-    0, 1, 2, 3, 4
-  ],
-  "y": [
-    2.0, 5.2, 47.4, 23.6, 24.8
-  ]
+  "x": [0.5, 1, 2, 3, 4, 5, 6],
+  "y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
 }
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,38 +0,0 @@
-{
-  "col_1": [
-    23.2, 92.8, 91.0, 37.2, 82.0, 15.5, 79.3, 46.6, 98.1, 75.5, 78.9, 77.6,
-    33.8, 75.7, 96.8, 12.3, 18.4, 13.4, 6.0, 8.2, 25.8, 41.3, 68.5, 15.2, 74.7,
-    72.7, 18.0, 42.2, 36.1, 76.7, 1.2, 96.4, 4.9, 92.0, 12.8, 28.2, 61.8, 56.9,
-    44.3, 50.4, 81.6, 72.5, 12.9, 40.3, 12.8, 28.8, 36.3, 16.1, 68.4, 35.3,
-    79.2, 48.4, 97.1, 93.7, 77.0, 48.7, 93.7, 54.1, 65.4, 30.8, 34.4, 31.4,
-    78.7, 12.7, 90.7, 39.4, 86.0, 55.9, 6.8, 22.2, 65.3, 18.8, 7.1, 55.9, 38.6,
-    15.6, 59.2, 77.3, 76.9, 11.9, 19.9, 19.4, 54.3, 39.4, 4.0, 61.1, 16.8, 81.9,
-    49.3, 76.9, 19.2, 68.2, 54.4, 70.2, 89.8, 23.4, 67.5, 18.7, 10.8, 80.7,
-    80.3, 96.2, 62.3, 17.2, 23.0, 98.0, 19.1, 8.1, 36.2, 7.5, 55.9, 1.2, 56.8,
-    85.1, 18.9, 23.0, 13.5, 64.3, 9.1, 14.1, 14.1, 23.1, 73.2, 86.6, 39.1, 45.5,
-    85.0, 79.0, 15.8, 5.2, 81.5, 34.3, 24.3, 14.2, 84.6, 33.7, 86.3, 83.3, 62.8,
-    72.7, 14.7, 36.8, 92.5, 4.7, 30.0, 59.4, 57.6, 37.4, 22.0, 20.9, 61.6, 26.8,
-    47.1, 63.6, 6.0, 96.6, 61.2, 80.2, 59.3, 23.1, 29.3, 46.3, 89.2, 77.6, 83.2,
-    87.2, 63.2, 81.8, 55.0, 59.7, 57.8, 43.4, 92.4, 66.9, 82.1, 51.0, 22.1,
-    29.9, 41.0, 85.2, 61.5, 14.6, 48.0, 52.7, 31.4, 83.9, 35.5, 77.3, 35.8,
-    32.6, 22.2, 19.3, 49.1, 70.9, 43.9, 88.8, 56.3, 41.8, 90.3, 20.4, 80.4,
-    36.4, 91.5, 69.6, 75.3, 92.4, 84.8, 17.7, 2.3, 41.3, 91.3, 68.6, 73.3, 62.5,
-    60.5, 73.5, 70.7, 77.5, 76.8, 98.1, 40.9, 66.3, 8.6, 48.9, 75.4, 14.7, 35.9,
-    89.6, 15.1, 45.0, 77.6, 30.5, 76.1, 46.9, 34.3, 65.1, 43.9, 91.6, 88.8, 8.9,
-    42.9, 11.8, 32.1, 20.1, 48.9, 79.7, 15.3, 45.4, 80.1, 73.1, 76.5, 52.4, 9.6,
-    41.9, 52.7, 55.1, 30.9, 83.7, 46.7, 39.3, 40.5, 52.4, 19.2, 25.8, 52.7,
-    81.0, 38.0, 54.5, 15.3, 64.3, 88.3, 49.8, 90.5, 90.4, 79.7, 87.3, 32.3,
-    11.9, 5.7, 33.6, 75.1, 65.9, 29.1, 39.4, 87.5, 3.3, 66.3, 79.0, 97.9, 69.6,
-    22.0, 62.8, 97.1, 90.4, 39.5, 11.7, 30.3, 18.9, 34.6, 6.6
-  ],
-  "col_2": [
-    19.2, 54.1, 16.5, 24.8, 42.7, 18.9, 78.8, 54.4, 27.4, 76.2, 43.4, 20.9, 2.9,
-    30.4, 21.4, 2.0, 5.6, 33.5, 4.8, 4.7, 57.5, 23.5, 40.1, 83.1, 78.9, 95.1,
-    41.1, 59.0, 59.2, 91.1, 20.9, 67.6, 44.1, 91.3, 89.9, 85.7, 92.6, 67.1,
-    90.0, 29.5, 40.9, 96.8, 2.3, 57.9, 93.2, 83.9, 10.4, 75.1, 24.2, 22.9, 21.2,
-    26.9, 96.8, 89.0, 68.0, 16.1, 90.1, 1.7, 79.6, 98.5, 21.3, 79.5, 9.2, 97.9,
-    21.6, 4.2, 66.1, 53.8, 79.5, 60.6, 66.9, 39.5, 50.1, 66.1, 96.4, 80.5, 61.9,
-    44.4, 84.8, 64.8, 23.2, 7.1, 21.1, 90.5, 29.2, 1.4, 54.8, 9.8, 41.1, 45.2,
-    56.6, 48.2, 61.3, 62.9, 2.7, 33.2, 62.5, 40.9, 33.6, 50.1
-  ]
-}
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,28 +1,8 @@
 {
   "col_name": [
-    23.2, 92.8, 91.0, 37.2, 82.0, 15.5, 79.3, 46.6, 98.1, 75.5, 78.9, 77.6,
-    33.8, 75.7, 96.8, 12.3, 18.4, 13.4, 6.0, 8.2, 25.8, 41.3, 68.5, 15.2, 74.7,
-    72.7, 18.0, 42.2, 36.1, 76.7, 1.2, 96.4, 4.9, 92.0, 12.8, 28.2, 61.8, 56.9,
-    44.3, 50.4, 81.6, 72.5, 12.9, 40.3, 12.8, 28.8, 36.3, 16.1, 68.4, 35.3,
-    79.2, 48.4, 97.1, 93.7, 77.0, 48.7, 93.7, 54.1, 65.4, 30.8, 34.4, 31.4,
-    78.7, 12.7, 90.7, 39.4, 86.0, 55.9, 6.8, 22.2, 65.3, 18.8, 7.1, 55.9, 38.6,
-    15.6, 59.2, 77.3, 76.9, 11.9, 19.9, 19.4, 54.3, 39.4, 4.0, 61.1, 16.8, 81.9,
-    49.3, 76.9, 19.2, 68.2, 54.4, 70.2, 89.8, 23.4, 67.5, 18.7, 10.8, 80.7,
-    80.3, 96.2, 62.3, 17.2, 23.0, 98.0, 19.1, 8.1, 36.2, 7.5, 55.9, 1.2, 56.8,
-    85.1, 18.9, 23.0, 13.5, 64.3, 9.1, 14.1, 14.1, 23.1, 73.2, 86.6, 39.1, 45.5,
-    85.0, 79.0, 15.8, 5.2, 81.5, 34.3, 24.3, 14.2, 84.6, 33.7, 86.3, 83.3, 62.8,
-    72.7, 14.7, 36.8, 92.5, 4.7, 30.0, 59.4, 57.6, 37.4, 22.0, 20.9, 61.6, 26.8,
-    47.1, 63.6, 6.0, 96.6, 61.2, 80.2, 59.3, 23.1, 29.3, 46.3, 89.2, 77.6, 83.2,
-    87.2, 63.2, 81.8, 55.0, 59.7, 57.8, 43.4, 92.4, 66.9, 82.1, 51.0, 22.1,
-    29.9, 41.0, 85.2, 61.5, 14.6, 48.0, 52.7, 31.4, 83.9, 35.5, 77.3, 35.8,
-    32.6, 22.2, 19.3, 49.1, 70.9, 43.9, 88.8, 56.3, 41.8, 90.3, 20.4, 80.4,
-    36.4, 91.5, 69.6, 75.3, 92.4, 84.8, 17.7, 2.3, 41.3, 91.3, 68.6, 73.3, 62.5,
-    60.5, 73.5, 70.7, 77.5, 76.8, 98.1, 40.9, 66.3, 8.6, 48.9, 75.4, 14.7, 35.9,
-    89.6, 15.1, 45.0, 77.6, 30.5, 76.1, 46.9, 34.3, 65.1, 43.9, 91.6, 88.8, 8.9,
-    42.9, 11.8, 32.1, 20.1, 48.9, 79.7, 15.3, 45.4, 80.1, 73.1, 76.5, 52.4, 9.6,
-    41.9, 52.7, 55.1, 30.9, 83.7, 46.7, 39.3, 40.5, 52.4, 19.2, 25.8, 52.7,
-    81.0, 38.0, 54.5, 15.3, 64.3, 88.3, 49.8, 90.5, 90.4, 79.7, 87.3, 32.3,
-    11.9, 5.7, 33.6, 75.1, 65.9, 29.1, 39.4, 87.5, 3.3, 66.3, 79.0, 97.9, 69.6,
-    22.0, 62.8, 97.1, 90.4, 39.5, 11.7, 30.3, 18.9, 34.6, 6.6
+    46.2, 40.4, 44.8, 48.1, 51.2, 91.9, 38.2, 36.3, 22.2, 11.5, 17.9, 20.2,
+    99.9, 75.2, 29.8, 19.4, 46.1, 94.8, 6.6, 94.5, 99.7, 1.6, 4.0, 86.7, 28.7,
+    63.0, 66.7, 2.5, 41.4, 35.6, 45.0, 44.8, 9.6, 16.6, 9.8, 20.3, 25.9, 71.9,
+    27.5, 30.9, 62.9, 44.8, 45.7, 2.4, 91.4, 16.2, 61.5, 41.4, 77.1, 44.8
   ]
 }
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
250
examples/where/where+harmomean/where+harmomean.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,29 +1,8 @@
 {
   "col_name": [
-    15.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0, 41.0,
-    47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0, 51.0,
-    48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0, 51.0,
-    56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0, 62.0,
-    39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0, 52.0,
-    45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0, 45.0,
-    45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0, 34.0,
-    66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0, 30.0,
-    33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0, 62.0,
-    38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0, 52.0,
-    41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0, 52.0,
-    58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0, 64.0,
-    46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0, 69.0,
-    27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0, 55.0,
-    58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0, 45.0,
-    40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0, 39.0,
-    58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0, 34.0,
-    50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0, 22.0,
-    50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0, 44.0,
-    55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0, 43.0,
-    50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0, 41.0,
-    41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0, 64.0,
-    76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0, 50.0,
-    31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0, 40.0,
-    25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
+    46.2, 40.4, 44.8, 48.1, 51.2, 91.9, 38.2, 36.3, 22.2, 11.5, 17.9, 20.2,
+    99.9, 75.2, 29.8, 19.4, 46.1, 94.8, 6.6, 94.5, 99.7, 1.6, 4.0, 86.7, 28.7,
+    63.0, 66.7, 2.5, 41.4, 35.6, 45.0, 44.8, 9.6, 16.6, 9.8, 20.3, 25.9, 71.9,
+    27.5, 30.9, 62.9, 44.8, 45.7, 2.4, 91.4, 16.2, 61.5, 41.4, 77.1, 44.8
   ]
 }
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -24,4 +24,4 @@ def column_2():

 @pytest.fixture
 def scales():
-    return [6]
+    return [7]
@@ -1,11 +1,11 @@
 import json
-from typing import Type, Sequence, Optional
+from typing import Type, Sequence, Optional, Callable
 from pathlib import Path

 import torch

-from zkstats.core import prover_gen_settings, setup, prover_gen_proof, verifier_verify, generate_data_commitment
-from zkstats.computation import IModel
+from zkstats.core import create_dummy,prover_gen_settings, setup, prover_gen_proof, verifier_verify, generate_data_commitment, verifier_define_calculation
+from zkstats.computation import IModel, State, computation_to_model


 DEFAULT_POSSIBLE_SCALES = list(range(20))
@@ -22,17 +22,20 @@ def data_to_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, list]:
         column: d.tolist()
         for column, d in zip(column_names, data)
     }
+    print('columnnnn: ', column_to_data)
     with open(data_path, "w") as f:
         json.dump(column_to_data, f)
     return column_to_data


+TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
 def compute(
     basepath: Path,
     data: list[torch.Tensor],
     model: Type[IModel],
+    # computation: TComputation,
     scales_params: Optional[Sequence[int]] = None,
     selected_columns_params: Optional[list[str]] = None,
+    # error:float = 1.0
 ) -> None:
     sel_data_path = basepath / "comb_data.json"
     model_path = basepath / "model.onnx"
@@ -45,6 +48,14 @@ def compute(
     data_path = basepath / "data.json"
     data_commitment_path = basepath / "commitments.json"

+    # verifier_model_path = basepath / "verifier_model.onnx"
+    # verifier_compiled_model_path = basepath / "verifier_model.compiled"
+    # prover_model_path = basepath / "prover_model.onnx"
+    # prover_compiled_model_path = basepath / "prover_model.compiled"
+    # precal_witness_path = basepath / "precal_witness_arr.json"
+    # dummy_data_path = basepath / "dummy_data.json"
+    # sel_dummy_data_path = basepath / "sel_dummy_data_path.json"
+
+    column_to_data = data_to_file(data_path, data)
     # If selected_columns_params is None, select all columns
     if selected_columns_params is None:
@@ -60,43 +71,21 @@ def compute(
     else:
         scales = scales_params
         scales_for_commitments = scales_params
-    generate_data_commitment(data_path, scales_for_commitments, data_commitment_path)
-    prover_gen_settings(
-        data_path=data_path,
-        selected_columns=selected_columns,
-        sel_data_path=str(sel_data_path),
-        prover_model=model,
-        prover_model_path=str(model_path),
-        scale=scales,
-        mode="resources",
-        settings_path=str(settings_path),
-    )
-    setup(
-        str(model_path),
-        str(compiled_model_path),
-        str(settings_path),
-        str(vk_path),
-        str(pk_path),
-    )
-    prover_gen_proof(
-        str(model_path),
-        str(sel_data_path),
-        str(witness_path),
-        str(compiled_model_path),
-        str(settings_path),
-        str(proof_path),
-        str(pk_path),
-    )
-    verifier_verify(
-        str(proof_path),
-        str(settings_path),
-        str(vk_path),
-        selected_columns,
-        data_commitment_path,
-    )
+    # create_dummy((data_path), (dummy_data_path))
+    generate_data_commitment((data_path), scales_for_commitments, (data_commitment_path))
+    # _, prover_model = computation_to_model(computation, (precal_witness_path), True, error)
+    prover_gen_settings((data_path), selected_columns, (sel_data_path), model, (model_path), scales, "resources", (settings_path))
+    # No need, since verifier & prover share the same onnx
+    # _, verifier_model = computation_to_model(computation, (precal_witness_path), False,error)
+    # verifier_define_calculation((dummy_data_path), selected_columns, (sel_dummy_data_path),verifier_model, (verifier_model_path))
+    setup((model_path), (compiled_model_path), (settings_path),(vk_path), (pk_path ))
+    prover_gen_proof((model_path), (sel_data_path), (witness_path), (compiled_model_path), (settings_path), (proof_path), (pk_path))
+    # print('slett col: ', selected_columns)
+    verifier_verify((proof_path), (settings_path), (vk_path), selected_columns, (data_commitment_path))


 # Error tolerance between zkstats python implementation and python statistics module
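For orientation, the end-to-end flow this hunk converges on, stitched together from the calls above into one runnable sketch. The file names and the one-op computation body are placeholder assumptions; the argument order follows the prover_gen_settings / setup / prover_gen_proof / verifier_verify calls shown in this commit:

import torch
from zkstats.computation import State, computation_to_model
from zkstats.core import (
    generate_data_commitment, prover_gen_settings, setup,
    prover_gen_proof, verifier_verify,
)

def computation(state: State, args: list[torch.Tensor]) -> torch.Tensor:
    return state.mean(args[0])

# Prover side: True tells computation_to_model to record precalculated witnesses to the JSON file.
_, Model = computation_to_model(computation, "precal_witness_arr.json", True, 0.01)
generate_data_commitment("data.json", [7], "commitments.json")
prover_gen_settings("data.json", ["col_name"], "comb_data.json", Model, "model.onnx", [7], "resources", "settings.json")
setup("model.onnx", "model.compiled", "settings.json", "vk.key", "pk.key")
prover_gen_proof("model.onnx", "comb_data.json", "witness.json", "model.compiled", "settings.json", "proof.json", "pk.key")
# Verifier side: prover and verifier now share the same onnx, so only verification remains.
verifier_verify("proof.json", "settings.json", "vk.key", ["col_name"], "commitments.json")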
@@ -40,20 +40,20 @@ def nested_computation(state: State, args: list[torch.Tensor]):
     out_9 = state.correlation(y, z)
     out_10 = state.linear_regression(x, y)
     slope, intercept = out_10[0][0][0], out_10[0][1][0]
-    reshaped = torch.tensor([
-        out_0,
-        out_1,
-        out_2,
-        out_3,
-        out_4,
-        out_5,
-        out_6,
-        out_7,
-        out_8,
-        out_9,
-        slope,
-        intercept,
-    ]).reshape(1,-1,1)
+    reshaped = torch.cat((
+        out_0.unsqueeze(0),
+        out_1.unsqueeze(0),
+        out_2.unsqueeze(0),
+        out_3.unsqueeze(0),
+        out_4.unsqueeze(0),
+        out_5.unsqueeze(0),
+        out_6.unsqueeze(0),
+        out_7.unsqueeze(0),
+        out_8.unsqueeze(0),
+        out_9.unsqueeze(0),
+        slope.unsqueeze(0),
+        intercept.unsqueeze(0),
+    )).reshape(1,-1,1)
     out_10 = state.mean(reshaped)
     return out_10
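A plausible reading of the torch.tensor to torch.cat swap (my inference, not stated in the commit): torch.tensor(...) copies values out of the traced graph, severing the op results from the ONNX export, while torch.cat keeps them as graph ops. A standalone sketch:

import torch

a = torch.ones(1, requires_grad=True).sum()        # scalar still attached to a graph
b = (2 * torch.ones(1, requires_grad=True)).sum()

via_cat = torch.cat((a.unsqueeze(0), b.unsqueeze(0))).reshape(1, -1, 1)
print(via_cat.requires_grad)   # True: cat is a traced op, so the trace flows through

via_ctor = torch.tensor([a.item(), b.item()]).reshape(1, -1, 1)
print(via_ctor.requires_grad)  # False: the constructor copies raw values, breaking the trace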
@@ -63,10 +63,12 @@ def nested_computation(state: State, args: list[torch.Tensor]):
     [0.1],
 )
 def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, column_2: torch.Tensor, error, scales):
-    state, model = computation_to_model(nested_computation, error)
+    precal_witness_path = tmp_path / "precal_witness_path.json"
+    state, model = computation_to_model(nested_computation, precal_witness_path,True, error)
     x, y, z = column_0, column_1, column_2
     compute(tmp_path, [x, y, z], model, scales)
     # There are 11 ops in the computation
+
     assert state.current_op_index == 12

     ops = state.ops
@@ -157,8 +159,8 @@ def test_computation_with_where_1d(tmp_path, error, column_0, op_type: Callable[
     def where_and_op(state: State, args: list[torch.Tensor]):
         x = args[0]
         return op_type(state, state.where(condition(x), x))

-    state, model = computation_to_model(where_and_op, error)
+    precal_witness_path = tmp_path / "precal_witness_path.json"
+    state, model = computation_to_model(where_and_op, precal_witness_path,True, error)
     compute(tmp_path, [column], model, scales)

     res_op = state.ops[-1]
@@ -185,8 +187,8 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
         filtered_x = state.where(condition_x, x)
         filtered_y = state.where(condition_x, y)
         return op_type(state, filtered_x, filtered_y)

-    state, model = computation_to_model(where_and_op, error)
+    precal_witness_path = tmp_path / "precal_witness_path.json"
+    state, model = computation_to_model(where_and_op, precal_witness_path, True ,error)
     compute(tmp_path, [column_0, column_1], model, scales)

     res_op = state.ops[-1]
@@ -71,7 +71,7 @@ def test_integration_select_partial_columns(tmp_path, column_0, column_1, error,

     def simple_computation(state, x):
         return state.mean(x[0])

-    _, model = computation_to_model(simple_computation, error)
+    precal_witness_path = tmp_path / "precal_witness_path.json"
+    _, model = computation_to_model(simple_computation,precal_witness_path, True, error)
     # gen settings, setup, prove, verify
     compute(tmp_path, [column_0, column_1], model, scales, selected_columns)
@@ -5,46 +5,46 @@ import pytest

 import torch
 from zkstats.ops import Mean, Median, GeometricMean, HarmonicMean, Mode, PStdev, PVariance, Stdev, Variance, Covariance, Correlation, Operation, Regression
-from zkstats.computation import IModel, IsResultPrecise
+from zkstats.computation import IModel, IsResultPrecise, State, computation_to_model

 from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED


-@pytest.mark.parametrize(
-    "op_type, expected_func, error",
-    [
-        (Mean, statistics.mean, ERROR_CIRCUIT_DEFAULT),
-        (Median, statistics.median, ERROR_CIRCUIT_DEFAULT),
-        (GeometricMean, statistics.geometric_mean, ERROR_CIRCUIT_DEFAULT),
-        # Be more tolerant for HarmonicMean
-        (HarmonicMean, statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
-        # Be less tolerant for Mode
-        (Mode, statistics.mode, ERROR_CIRCUIT_STRICT),
-        (PStdev, statistics.pstdev, ERROR_CIRCUIT_DEFAULT),
-        (PVariance, statistics.pvariance, ERROR_CIRCUIT_DEFAULT),
-        (Stdev, statistics.stdev, ERROR_CIRCUIT_DEFAULT),
-        (Variance, statistics.variance, ERROR_CIRCUIT_DEFAULT),
-    ]
-)
-def test_ops_1_parameter(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
-    run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0])
+# @pytest.mark.parametrize(
+#     "op_type, expected_func, error",
+#     [
+#         (Mean, statistics.mean, ERROR_CIRCUIT_DEFAULT),
+#         (Median, statistics.median, ERROR_CIRCUIT_DEFAULT),
+#         (GeometricMean, statistics.geometric_mean, ERROR_CIRCUIT_DEFAULT),
+#         # Be more tolerant for HarmonicMean
+#         (HarmonicMean, statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
+#         # Be less tolerant for Mode
+#         (Mode, statistics.mode, ERROR_CIRCUIT_STRICT),
+#         (PStdev, statistics.pstdev, ERROR_CIRCUIT_DEFAULT),
+#         (PVariance, statistics.pvariance, ERROR_CIRCUIT_DEFAULT),
+#         (Stdev, statistics.stdev, ERROR_CIRCUIT_DEFAULT),
+#         (Variance, statistics.variance, ERROR_CIRCUIT_DEFAULT),
+#     ]
+# )
+# def test_ops_1_parameter(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
+#     run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0])


-@pytest.mark.parametrize(
-    "op_type, expected_func, error",
-    [
-        (Covariance, statistics.covariance, ERROR_CIRCUIT_RELAXED),
-        (Correlation, statistics.correlation, ERROR_CIRCUIT_RELAXED),
-    ]
-)
-def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
-    run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0, column_1])
+# @pytest.mark.parametrize(
+#     "op_type, expected_func, error",
+#     [
+#         (Covariance, statistics.covariance, ERROR_CIRCUIT_RELAXED),
+#         (Correlation, statistics.correlation, ERROR_CIRCUIT_RELAXED),
+#     ]
+# )
+# def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float], scales: list[float]):
+#     run_test_ops(tmp_path, op_type, expected_func, error, scales, [column_0, column_1])


 @pytest.mark.parametrize(
     "error",
     [
-        ERROR_CIRCUIT_RELAXED
+        1.0
     ]
 )
 def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, scales: list[float]):
@@ -55,12 +55,15 @@ def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Ten
     actual_res = regression.result
     assert_result(expected_res.slope, actual_res[0][0][0])
     assert_result(expected_res.intercept, actual_res[0][1][0])
+    print("slope: ", actual_res[0][0][0])
+    print('intercept: ',actual_res[0][1][0] )
     class Model(IModel):
         def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
             return regression.ezkl(x), regression.result
     compute(tmp_path, columns, Model, scales)


 def run_test_ops(tmp_path, op_type: Type[Operation], expected_func: Callable[[list[float]], float], error: float, scales: list[float], columns: list[torch.Tensor]):
     op = op_type.create(columns, error)
     expected_res = expected_func(*[column.tolist() for column in columns])
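For orientation (my gloss, inferred from the indexing above): the regression result is a [1, 2, 1]-shaped tensor, with [0][0][0] the slope and [0][1][0] the intercept, compared against Python's statistics module:

import statistics
import torch

xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]
expected = statistics.linear_regression(xs, ys)   # slope=2.0, intercept=1.0
result = torch.tensor([[[2.0], [1.0]]])           # [1, 2, 1]: [[slope], [intercept]]
assert abs(expected.slope - result[0][0][0].item()) < 1e-9
assert abs(expected.intercept - result[0][1][0].item()) < 1e-9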
@@ -25,6 +25,7 @@ from .ops import (


 DEFAULT_ERROR = 0.01
+MagicNumber = 99.999


 class State:
@@ -47,6 +48,7 @@ class State:
         self.precal_witness_path: str = None
         self.precal_witness:dict = {}
         self.isProver:bool = None
+        self.op_dict:dict={}

     def set_ready_for_exporting_onnx(self) -> None:
         self.current_op_index = 0
@@ -151,19 +153,73 @@ class State:
         if self.current_op_index is None:
             # for prover
             if self.isProver:
-                print('Prover side')
+                # print('Prover side create')
                 op = op_type.create(x, self.error)
-                if isinstance(op,Mean):
-                    self.precal_witness['Mean'] = [op.result.data.item()]
-                elif isinstance(op, Median):
-                    self.precal_witness['Median'] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
+                # Single witness aka result
+                if isinstance(op,Mean) or isinstance(op,GeometricMean) or isinstance(op, HarmonicMean) or isinstance(op, Mode):
+                    op_class_str =str(type(op)).split('.')[-1].split("'")[0]
+                    if op_class_str not in self.op_dict:
+                        self.precal_witness[op_class_str+"_0"] = [op.result.data.item()]
+                        self.op_dict[op_class_str] = 1
+                    else:
+                        self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item()]
+                        self.op_dict[op_class_str]+=1
+                elif isinstance(op, Median):
+                    if 'Median' not in self.op_dict:
+                        self.precal_witness['Median_0'] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
+                        self.op_dict['Median']=1
+                    else:
+                        self.precal_witness['Median_'+str(self.op_dict['Median'])] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
+                        self.op_dict['Median']+=1
+                # std + variance stuffs
+                elif isinstance(op, PStdev) or isinstance(op, PVariance) or isinstance(op, Stdev) or isinstance(op, Variance):
+                    op_class_str =str(type(op)).split('.')[-1].split("'")[0]
+                    if op_class_str not in self.op_dict:
+                        self.precal_witness[op_class_str+"_0"] = [op.result.data.item(), op.data_mean.data.item()]
+                        self.op_dict[op_class_str] = 1
+                    else:
+                        self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item(), op.data_mean.data.item()]
+                        self.op_dict[op_class_str]+=1
+                elif isinstance(op, Covariance):
+                    if 'Covariance' not in self.op_dict:
+                        self.precal_witness['Covariance_0'] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item()]
+                        self.op_dict['Covariance']=1
+                    else:
+                        self.precal_witness['Covariance_'+str(self.op_dict['Covariance'])] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item()]
+                        self.op_dict['Covariance']+=1
+                elif isinstance(op, Correlation):
+                    if 'Correlation' not in self.op_dict:
+                        self.precal_witness['Correlation_0'] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item(), op.x_std.data.item(), op.y_std.data.item(), op.cov.data.item()]
+                        self.op_dict['Correlation']=1
+                    else:
+                        self.precal_witness['Correlation_'+str(self.op_dict['Correlation'])] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item(), op.x_std.data.item(), op.y_std.data.item(), op.cov.data.item()]
+                        self.op_dict['Correlation']+=1
+                elif isinstance(op, Regression):
+                    result_array = []
+                    for ele in op.result.data[0]:
+                        result_array.append(ele[0].item())
+                    if 'Regression' not in self.op_dict:
+                        self.precal_witness['Regression_0'] = [result_array]
+                        self.op_dict['Regression']=1
+                    else:
+                        self.precal_witness['Regression_'+str(self.op_dict['Regression'])] = [result_array]
+                        self.op_dict['Regression']+=1
             # for verifier
             else:
-                print('Verifier side')
+                # print('Verifier side create')
                 precal_witness = json.loads(open(self.precal_witness_path, "r").read())
-                op = op_type.create(x, self.error, precal_witness)
-                print('finish create')
+                op = op_type.create(x, self.error, precal_witness, self.op_dict)
+                # dont need to include Where
+                if not isinstance(op, Where):
+                    op_class_str =str(type(op)).split('.')[-1].split("'")[0]
+                    if op_class_str not in self.op_dict:
+                        self.op_dict[op_class_str] = 1
+                    else:
+                        self.op_dict[op_class_str]+=1
             self.ops.append(op)
             if isinstance(op, Where):
                 return torch.where(x[0], x[1], MagicNumber)
             return op.result
         else:
             # Copy the current op index to a local variable since self.current_op_index will be incremented.
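To make the new bookkeeping concrete, here is the rough shape of the witness dict the prover dumps and the verifier reads back (values invented for illustration): each op instance gets a "<OpClass>_<occurrence>" key, and the list layout matches what the verifier-side constructors unpack:

precal_witness = {
    "Mean_0": [45.5],                    # [result]
    "Median_0": [45.0, 44.3, 45.7],      # [result, lower, upper]
    "Stdev_0": [26.1, 45.5],             # [result, data_mean]
    "Covariance_0": [120.3, 45.5, 47.2], # [result, x_mean, y_mean]
    "Correlation_0": [0.31, 45.5, 47.2, 26.1, 27.9, 120.3],  # [result, x_mean, y_mean, x_std, y_std, cov]
    "Regression_0": [[0.8, 12.5]],       # [[slope, intercept]]
}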
@@ -190,7 +246,7 @@ class State:
             # else, return only result

             if current_op_index == len_ops - 1:
-                print('final op: ', op)
+                # print('final op: ', op)
                 # Sanity check for length of self.ops and self.bools
                 len_bools = len(self.bools)
                 if len_ops != len_bools:
@@ -198,13 +254,10 @@ class State:
                 is_precise_aggregated = torch.tensor(1.0)
                 for i in range(len_bools):
                     res = self.bools[i]()
-                    # print("hey computation: ", i)
-                    # print('self.ops: ', self.ops[i])
-                    # print('res: ', res)
                     is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
                 if isinstance(op, Where):
-                    # return as where result
-                    return is_precise_aggregated, op.result+x[1]-x[1]
+                    # print('Only where')
+                    return is_precise_aggregated, torch.where(x[0], x[1], x[1]-x[1]+MagicNumber)
                 else:
                     if self.isProver:
                         json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
@@ -216,13 +269,9 @@ class State:
             else:
                 # for where
                 if isinstance(op, Where):
-                    return (op.result+x[1]-x[1])
+                    # print('many ops incl where')
+                    return torch.where(x[0], x[1], x[1]-x[1]+MagicNumber)
                 else:
                     # return single float number
-                    # return torch.where(x[0], x[1], 9999999)
-                    # print('oppy else: ', op)
-                    # print('is check else: ', isinstance(op,Mean))
-                    # self.aggregate_witness.append(op.result)
                     return op.result+(x[0]-x[0])[0][0][0]
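Throughout these branches, the MagicNumber sentinel stands in for "filtered out" (per the Mean comment, NaN would be preferable once onnx.isnan() is supported), and terms like x[1]-x[1] or (x[0]-x[0])[0][0][0] are arithmetic no-ops that keep the input tensors wired into the exported graph. A standalone illustration of the sentinel convention, not repo code:

import torch

MagicNumber = 99.999
x = torch.tensor([[[2.0], [5.2], [47.4]]])     # shape [1, n, 1], as the ops assume
cond = x < 10.0
filtered = torch.where(cond, x, MagicNumber)   # masked slots carry the sentinel
size = torch.sum(torch.where(filtered != MagicNumber, 1.0, 0.0))
print(size)  # tensor(2.): only the two entries < 10 count as "present"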
427
zkstats/ops.py
@@ -7,7 +7,7 @@ import torch

 # boolean: either 1.0 or 0.0
 IsResultPrecise = torch.Tensor
-MagicNumber = 9999999.0
+MagicNumber = 99.999


 class Operation(ABC):
@@ -26,17 +26,18 @@ class Operation(ABC):

 class Where(Operation):
     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'Where':
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Where':
         # here error is trivial, but here to conform to other functions
-        return cls(torch.where(x[0],x[1], MagicNumber ),error)
+        # just dummy result, since not using it anyway because we dont want to expose direct result from where
+        return cls(torch.tensor(1),error)

     def ezkl(self, x:list[torch.Tensor]) -> IsResultPrecise:
-        bool_array = torch.logical_or(torch.logical_and(x[0], x[1]==self.result), torch.logical_and(torch.logical_not(x[0]), self.result==MagicNumber))
-        return torch.sum(bool_array.float())==x[1].size()[1]
+        return torch.tensor(True)


 class Mean(Operation):
     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None ) -> 'Mean':
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {} ) -> 'Mean':
         # support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
         if precal_witness is None:
             # this is prover
@@ -45,16 +46,15 @@ class Mean(Operation):
         else:
             # this is verifier
-            # print('verrrr')
-            tensor_arr = []
-            for ele in precal_witness['Mean']:
-                tensor_arr.append(torch.tensor(ele))
-            print("mean tensor arr: ", tensor_arr)
-            return cls(tensor_arr[0], error)
+            if 'Mean' not in op_dict:
+                return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
+            else:
+                return cls(torch.tensor(precal_witness['Mean_'+str(op_dict['Mean'])][0]), error)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         x = x[0]
-        size = torch.sum((x!=MagicNumber).float())
+        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
         x = torch.where(x==MagicNumber, 0.0, x)
         return torch.abs(torch.sum(x)-size*self.result)<=torch.abs(self.error*self.result*size)
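The Mean check above avoids an in-circuit division: instead of recomputing sum/size, it verifies |sum(x) - size*result| <= |error*result*size| over the sentinel-masked data. A standalone sketch with invented numbers:

import torch

MagicNumber = 99.999
error = 0.01
x = torch.tensor([2.0, 5.2, 47.4, MagicNumber])               # one filtered-out slot
size = torch.sum(torch.where(x != MagicNumber, 1.0, 0.0))     # 3 live entries
x0 = torch.where(x == MagicNumber, 0.0, x)                    # zero out the sentinel
claimed_mean = torch.tensor((2.0 + 5.2 + 47.4) / 3)
ok = torch.abs(torch.sum(x0) - size * claimed_mean) <= torch.abs(error * claimed_mean * size)
print(ok)  # tensor(True)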
@@ -71,7 +71,7 @@ def to_1d(x: torch.Tensor) -> torch.Tensor:


 class Median(Operation):
-    def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None ):
+    def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict= {} ):
         if precal_witness is None:
             # NOTE: To ensure `lower` and `upper` are a scalar, `x` must be a 1d array.
             # Otherwise, if `x` is a 3d array, `lower` and `upper` will be 2d array, which are not what
@@ -85,84 +85,91 @@ class Median(Operation):
             self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
             self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
         else:
-            tensor_arr = []
-            for ele in precal_witness['Median']:
-                tensor_arr.append(torch.tensor(ele))
-            super().__init__(tensor_arr[0], error)
-            self.lower = torch.nn.Parameter(data = tensor_arr[1], requires_grad=False)
-            self.upper = torch.nn.Parameter(data = tensor_arr[2], requires_grad=False)
+            if 'Median' not in op_dict:
+                super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
+                self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
+                self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
+            else:
+                super().__init__(torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][0]), error)
+                self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][1]), requires_grad=False)
+                self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][2]), requires_grad=False)

     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None ) -> 'Median':
-        if precal_witness is None:
-            return cls(x[0], error)
-        else:
-            tensor_arr = []
-            for ele in precal_witness['Median']:
-                tensor_arr.append(torch.tensor(ele))
-            return cls(tensor_arr[0],error, precal_witness)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict= {} ) -> 'Median':
+        return cls(x[0],error, precal_witness, op_dict)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         x = x[0]
         old_size = x.size()[1]
-        size = torch.sum((x!=MagicNumber).float())
+        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
         min_x = torch.min(x)
         x = torch.where(x==MagicNumber,min_x-1, x)

         # since within 1%, we regard as same value
-        count_less = torch.sum((x < self.result).float())-(old_size-size)
-        count_equal = torch.sum((x==self.result).float())
+        count_less = torch.sum(torch.where(x < self.result, 1.0, 0.0))-(old_size-size)
+        count_equal = torch.sum(torch.where(x==self.result, 1.0, 0.0))
         half_size = torch.floor(torch.div(size, 2))

-        # print('hhhh: ', half_size)
         less_cons = count_less<half_size+size%2
         more_cons = count_less+count_equal>half_size

         # For count_equal == 0
-        lower_exist = torch.sum((x==self.lower).float())>0
-        lower_cons = torch.sum((x>self.lower).float())==half_size
-        upper_exist = torch.sum((x==self.upper).float())>0
-        upper_cons = torch.sum((x<self.upper).float())==half_size
+        lower_exist = torch.sum(torch.where(x==self.lower, 1.0, 0.0))>0
+        lower_cons = torch.sum(torch.where(x>self.lower, 1.0, 0.0))==half_size
+        upper_exist = torch.sum(torch.where(x==self.upper, 1.0, 0.0))>0
+        upper_cons = torch.sum(torch.where(x<self.upper, 1.0, 0.0))==half_size
         bound = count_less== half_size
         # 0.02 since 2*0.01
         bound_avg = (torch.abs(self.lower+self.upper-2*self.result)<=torch.abs(2*self.error*self.result))

         median_in_cons = torch.logical_and(less_cons, more_cons)
         median_out_cons = torch.logical_and(torch.logical_and(bound, bound_avg), torch.logical_and(torch.logical_and(lower_cons, upper_cons), torch.logical_and(lower_exist, upper_exist)))
-        return torch.where(count_equal==0, median_out_cons, median_in_cons)
+        return torch.where(count_equal==0.0, median_out_cons, median_in_cons)


 class GeometricMean(Operation):
     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'GeometricMean':
-        x_1d = to_1d(x[0])
-        x_1d = x_1d[x_1d!=MagicNumber]
-        result = torch.exp(torch.mean(torch.log(x_1d)))
-        return cls(result, error)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'GeometricMean':
+        if precal_witness is None:
+            x_1d = to_1d(x[0])
+            x_1d = x_1d[x_1d!=MagicNumber]
+            result = torch.exp(torch.mean(torch.log(x_1d)))
+            return cls(result, error)
+        else:
+            if 'GeometricMean' not in op_dict:
+                return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
+            else:
+                return cls(torch.tensor(precal_witness['GeometricMean_'+str(op_dict['GeometricMean'])][0]), error)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         # Assume x is [1, n, 1]
         x = x[0]
-        size = torch.sum((x!=MagicNumber).float())
+        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
         x = torch.where(x==MagicNumber, 1.0, x)
         return torch.abs((torch.log(self.result)*size)-torch.sum(torch.log(x)))<=size*torch.log(torch.tensor(1+self.error))


 class HarmonicMean(Operation):
     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'HarmonicMean':
-        x_1d = to_1d(x[0])
-        x_1d = x_1d[x_1d!=MagicNumber]
-        result = torch.div(1.0,torch.mean(torch.div(1.0, x_1d)))
-        return cls(result, error)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict = {}) -> 'HarmonicMean':
+        if precal_witness is None:
+            x_1d = to_1d(x[0])
+            x_1d = x_1d[x_1d!=MagicNumber]
+            result = torch.div(1.0,torch.mean(torch.div(1.0, x_1d)))
+            return cls(result, error)
+        else:
+            if 'HarmonicMean' not in op_dict:
+                return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
+            else:
+                return cls(torch.tensor(precal_witness['HarmonicMean_'+str(op_dict['HarmonicMean'])][0]), error)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         # Assume x is [1, n, 1]
         x = x[0]
-        size = torch.sum((x!=MagicNumber).float())
-        # just make it really big so that 1/x goes to zero for element that gets filtered out
-        x = torch.where(x==MagicNumber, x*x, x)
-        return torch.abs((self.result*torch.sum(torch.div(1.0, x))) - size)<=torch.abs(self.error*size)
+        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
+        return torch.abs((self.result*torch.sum(torch.where(x==MagicNumber, 0.0, torch.div(1.0, x)))) - size)<=torch.abs(self.error*size)


 def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
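The geometric-mean constraint above works in the log domain: |size*log(result) - sum(log x)| <= size*log(1+error), with filtered slots set to 1.0 so their log contributes 0. A standalone check with invented numbers:

import torch

error = 0.01
x = torch.tensor([2.0, 8.0])            # geometric mean is exactly 4.0
size = torch.tensor(2.0)
claimed = torch.tensor(4.0)
lhs = torch.abs(torch.log(claimed) * size - torch.sum(torch.log(x)))
rhs = size * torch.log(torch.tensor(1 + error))
print(bool(lhs <= rhs))  # True: 2*log(4) == log(2) + log(8)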
@@ -213,12 +220,19 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:

 class Mode(Operation):
     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'Mode':
-        x_1d = to_1d(x[0])
-        x_1d = x_1d[x_1d!=MagicNumber]
-        # Here is traditional definition of Mode, can just put this num_error to be 0
-        result = torch.tensor(mode_within(x_1d, 0))
-        return cls(result, error)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Mode':
+        if precal_witness is None:
+            x_1d = to_1d(x[0])
+            x_1d = x_1d[x_1d!=MagicNumber]
+            # Here is traditional definition of Mode, can just put this num_error to be 0
+            result = torch.tensor(mode_within(x_1d, 0))
+            return cls(result, error)
+        else:
+            if 'Mode' not in op_dict:
+                return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
+            else:
+                return cls(torch.tensor(precal_witness['Mode_'+str(op_dict['Mode'])][0]), error)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         # Assume x is [1, n, 1]
@@ -226,194 +240,258 @@ class Mode(Operation):
         min_x = torch.min(x)
         old_size = x.size()[1]
         x = torch.where(x==MagicNumber, min_x-1, x)
-        count_equal = torch.sum((x==self.result).float())
+        count_equal = torch.sum(torch.where(x==self.result, 1.0, 0.0))

         count_check = 0
         for ele in x[0]:
-            bool1 = torch.sum((x==ele[0]).float())<=count_equal
+            bool1 = torch.sum(torch.where(x==ele[0], 1.0, 0.0))<=count_equal
             bool2 = ele[0]==min_x-1
             count_check += torch.logical_or(bool1, bool2)
         return count_check ==old_size

 class PStdev(Operation):
-    def __init__(self, x: torch.Tensor, error: float):
-        x_1d = to_1d(x)
-        x_1d = x_1d[x_1d!=MagicNumber]
-        self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
-        result = torch.sqrt(torch.var(x_1d, correction = 0))
-        super().__init__(result, error)
+    def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
+        if precal_witness is None:
+            x_1d = to_1d(x)
+            x_1d = x_1d[x_1d!=MagicNumber]
+            self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
+            result = torch.sqrt(torch.var(x_1d, correction = 0))
+            super().__init__(result, error)
+        else:
+            if 'PStdev' not in op_dict:
+                super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
+                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
+            else:
+                super().__init__(torch.tensor(precal_witness['PStdev_'+str(op_dict['PStdev'])][0]), error)
+                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_'+str(op_dict['PStdev'])][1]), requires_grad=False)

     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'PStdev':
-        return cls(x[0], error)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'PStdev':
+        return cls(x[0], error, precal_witness, op_dict)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         x = x[0]
         x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
-        size = torch.sum((x!=MagicNumber).float())
+        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
         x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
-        x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
+        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
         return torch.logical_and(
-            torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*self.result*size)<=torch.abs(2*self.error*self.result*self.result*size),x_mean_cons
+            torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*self.result*size)<=torch.abs(2*self.error*self.result*self.result*size),x_mean_cons
         )

 class PVariance(Operation):
-    def __init__(self, x: torch.Tensor, error: float):
-        x_1d = to_1d(x)
-        x_1d = x_1d[x_1d!=MagicNumber]
-        self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
-        result = torch.var(x_1d, correction = 0)
-        super().__init__(result, error)
+    def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
+        if precal_witness is None:
+            x_1d = to_1d(x)
+            x_1d = x_1d[x_1d!=MagicNumber]
+            self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
+            result = torch.var(x_1d, correction = 0)
+            super().__init__(result, error)
+        else:
+            if 'PVariance' not in op_dict:
+                super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
+                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
+            else:
+                super().__init__(torch.tensor(precal_witness['PVariance_'+str(op_dict['PVariance'])][0]), error)
+                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_'+str(op_dict['PVariance'])][1]), requires_grad=False)

     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'PVariance':
-        return cls(x[0], error)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'PVariance':
+        return cls(x[0], error, precal_witness, op_dict)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         x = x[0]
         x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
-        size = torch.sum((x!=MagicNumber).float())
+        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
         x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
-        x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
+        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
         return torch.logical_and(
-            torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*size)<=torch.abs(self.error*self.result*size), x_mean_cons
+            torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*size)<=torch.abs(self.error*self.result*size), x_mean_cons
         )

 class Stdev(Operation):
-    def __init__(self, x: torch.Tensor, error: float):
-        x_1d = to_1d(x)
-        x_1d = x_1d[x_1d!=MagicNumber]
-        self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
-        result = torch.sqrt(torch.var(x_1d, correction = 1))
-        super().__init__(result, error)
+    def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
+        if precal_witness is None:
+            x_1d = to_1d(x)
+            x_1d = x_1d[x_1d!=MagicNumber]
+            self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
+            result = torch.sqrt(torch.var(x_1d, correction = 1))
+            super().__init__(result, error)
+        else:
+            if 'Stdev' not in op_dict:
+                super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
+                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
+            else:
+                super().__init__(torch.tensor(precal_witness['Stdev_'+str(op_dict['Stdev'])][0]), error)
+                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_'+str(op_dict['Stdev'])][1]), requires_grad=False)

     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'Stdev':
-        return cls(x[0], error)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Stdev':
+        return cls(x[0], error, precal_witness, op_dict)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         x = x[0]
         x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
-        size = torch.sum((x!=MagicNumber).float())
+        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
         x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
-        x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
+        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
        return torch.logical_and(
-            torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
+            torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
         )

 class Variance(Operation):
-    def __init__(self, x: torch.Tensor, error: float):
-        x_1d = to_1d(x)
-        x_1d = x_1d[x_1d!=MagicNumber]
-        self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
-        result = torch.var(x_1d, correction = 1)
-        super().__init__(result, error)
+    def __init__(self, x: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
+        if precal_witness is None:
+            x_1d = to_1d(x)
+            x_1d = x_1d[x_1d!=MagicNumber]
+            self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
+            result = torch.var(x_1d, correction = 1)
+            super().__init__(result, error)
+        else:
+            if 'Variance' not in op_dict:
+                super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
+                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
+            else:
+                super().__init__(torch.tensor(precal_witness['Variance_'+str(op_dict['Variance'])][0]), error)
+                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_'+str(op_dict['Variance'])][1]), requires_grad=False)

     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'Variance':
-        return cls(x[0], error)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Variance':
+        return cls(x[0], error, precal_witness, op_dict)

     def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
         x = x[0]
         x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
-        size = torch.sum((x!=MagicNumber).float())
+        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
         x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
-        x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
+        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
         return torch.logical_and(
-            torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
+            torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
         )

 class Covariance(Operation):
-    def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
-        x_1d = to_1d(x)
-        x_1d = x_1d[x_1d!=MagicNumber]
-        y_1d = to_1d(y)
-        y_1d = y_1d[y_1d!=MagicNumber]
-        x_1d_list = x_1d.tolist()
-        y_1d_list = y_1d.tolist()
-        self.x_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(x_1d_list), dtype = torch.float32), requires_grad=False)
-        self.y_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(y_1d_list), dtype = torch.float32), requires_grad=False)
-        result = torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32)
-        super().__init__(result, error)
+    def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:dict = None, op_dict:dict = {}):
+        if precal_witness is None:
+            x_1d = to_1d(x)
+            x_1d = x_1d[x_1d!=MagicNumber]
+            y_1d = to_1d(y)
+            y_1d = y_1d[y_1d!=MagicNumber]
+            x_1d_list = x_1d.tolist()
+            y_1d_list = y_1d.tolist()
+            self.x_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(x_1d_list), dtype = torch.float32), requires_grad=False)
+            self.y_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(y_1d_list), dtype = torch.float32), requires_grad=False)
+            result = torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32)
+            super().__init__(result, error)
+        else:
+            if 'Covariance' not in op_dict:
+                super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
+                self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
+                self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
+            else:
+                super().__init__(torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][0]), error)
+                self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][1]), requires_grad=False)
+                self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][2]), requires_grad=False)

     @classmethod
-    def create(cls, x: list[torch.Tensor], error: float) -> 'Covariance':
-        return cls(x[0], x[1], error)
+    def create(cls, x: list[torch.Tensor], error: float, precal_witness:dict = None, op_dict:dict = {}) -> 'Covariance':
+        return cls(x[0], x[1], error, precal_witness, op_dict)

     def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
         x, y = args[0], args[1]
         x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
         y_fil_0 = torch.where(y==MagicNumber, 0.0, y)
-        size_x = torch.sum((x!=MagicNumber).float())
-        size_y = torch.sum((y!=MagicNumber).float())
+        size_x = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
+        size_y = torch.sum(torch.where(y!=MagicNumber, 1.0, 0.0))
         x_mean_cons = torch.abs(torch.sum(x_fil_0)-size_x*(self.x_mean))<=torch.abs(self.error*self.x_mean*size_x)
         y_mean_cons = torch.abs(torch.sum(y_fil_0)-size_y*(self.y_mean))<=torch.abs(self.error*self.y_mean*size_y)
-        x_fil_mean = torch.where(x==MagicNumber, self.x_mean, x)
-        # only x_fil_mean is enough, no need for y_fil_mean since it will multiply 0 anyway
+        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.x_mean)
+        y_adj_mean = torch.where(y==MagicNumber, 0.0, y-self.y_mean)

         return torch.logical_and(
             torch.logical_and(size_x==size_y,torch.logical_and(x_mean_cons,y_mean_cons)),
-            torch.abs(torch.sum((x_fil_mean-self.x_mean)*(y-self.y_mean))-(size_x-1)*self.result)<self.error*self.result*(size_x-1)
+            torch.abs(torch.sum((x_adj_mean)*(y_adj_mean))-(size_x-1)*self.result)<=torch.abs(self.error*self.result*(size_x-1))
         )
 # refer other constraints to correlation function, not put here since will be repetitive
-def stdev_for_corr(x_fil_mean:torch.Tensor,size_x:torch.Tensor, x_std: torch.Tensor, x_mean: torch.Tensor, error: float) -> torch.Tensor:
+def stdev_for_corr(x_adj_mean:torch.Tensor, size_x:torch.Tensor, x_std: torch.Tensor, error: float) -> torch.Tensor:
     return (
-        torch.abs(torch.sum((x_fil_mean-x_mean)*(x_fil_mean-x_mean))-x_std*x_std*(size_x - 1))<=torch.abs(2*error*x_std*x_std*(size_x - 1))
+        torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-x_std*x_std*(size_x - 1))<=torch.abs(2*error*x_std*x_std*(size_x - 1))
     , x_std)

 # refer other constraints to correlation function, not put here since will be repetitive
-def covariance_for_corr(x_fil_mean: torch.Tensor,y_fil_mean: torch.Tensor,size_x:torch.Tensor, size_y:torch.Tensor, cov: torch.Tensor, x_mean: torch.Tensor, y_mean: torch.Tensor, error: float) -> torch.Tensor:
+def covariance_for_corr(x_adj_mean: torch.Tensor,y_adj_mean: torch.Tensor,size_x:torch.Tensor, cov: torch.Tensor, error: float) -> torch.Tensor:
     return (
-        torch.abs(torch.sum((x_fil_mean-x_mean)*(y_fil_mean-y_mean))-(size_x-1)*cov)<error*cov*(size_x-1)
+        torch.abs(torch.sum((x_adj_mean)*(y_adj_mean))-(size_x-1)*cov)<=torch.abs(error*cov*(size_x-1))
     , cov)
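The x_adj_mean rewrite used throughout these constraints zeroes filtered slots after subtracting the mean, so they drop out of the sums; the older x_fil_mean form substituted the mean first so the difference vanished. Both yield the same sums, as this standalone sketch (invented data) shows:

import torch

MagicNumber = 99.999
x = torch.tensor([1.0, 3.0, MagicNumber])
x_mean = torch.tensor(2.0)

x_adj_mean = torch.where(x == MagicNumber, 0.0, x - x_mean)   # new style
print(torch.sum(x_adj_mean * x_adj_mean))                     # tensor(2.) = (1-2)^2 + (3-2)^2

x_fil_mean = torch.where(x == MagicNumber, x_mean, x)         # old style
print(torch.sum((x_fil_mean - x_mean) ** 2))                  # tensor(2.), same value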
class Correlation(Operation):
|
||||
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
y_1d = to_1d(y)
|
||||
y_1d = y_1d[y_1d!=MagicNumber]
|
||||
x_1d_list = x_1d.tolist()
|
||||
y_1d_list = y_1d.tolist()
|
||||
self.x_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data=torch.mean(y_1d), requires_grad = False)
|
||||
self.x_std = torch.nn.Parameter(data=torch.sqrt(torch.var(x_1d, correction = 1)), requires_grad = False)
|
||||
self.y_std = torch.nn.Parameter(data=torch.sqrt(torch.var(y_1d, correction = 1)), requires_grad=False)
|
||||
self.cov = torch.nn.Parameter(data=torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32), requires_grad=False)
|
||||
result = torch.tensor(statistics.correlation(x_1d_list, y_1d_list), dtype = torch.float32)
|
||||
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness, op_dict:dict = {}):
|
||||
if precal_witness is None:
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
y_1d = to_1d(y)
|
||||
y_1d = y_1d[y_1d!=MagicNumber]
|
||||
x_1d_list = x_1d.tolist()
|
||||
y_1d_list = y_1d.tolist()
|
||||
self.x_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data=torch.mean(y_1d), requires_grad = False)
|
||||
self.x_std = torch.nn.Parameter(data=torch.sqrt(torch.var(x_1d, correction = 1)), requires_grad = False)
|
||||
self.y_std = torch.nn.Parameter(data=torch.sqrt(torch.var(y_1d, correction = 1)), requires_grad=False)
|
||||
self.cov = torch.nn.Parameter(data=torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32), requires_grad=False)
|
||||
result = torch.tensor(statistics.correlation(x_1d_list, y_1d_list), dtype = torch.float32)
|
||||
|
||||
super().__init__(result, error)
|
||||
else:
|
||||
if 'Correlation' not in op_dict:
|
||||
super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
|
||||
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
|
||||
self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][3]), requires_grad=False)
|
||||
self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][4]), requires_grad=False)
|
||||
self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][5]), requires_grad=False)
|
||||
else:
|
||||
super().__init__(torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][0]), error)
|
||||
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][1]), requires_grad=False)
|
||||
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][2]), requires_grad=False)
|
||||
self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][3]), requires_grad=False)
|
||||
self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][4]), requires_grad=False)
|
||||
self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][5]), requires_grad=False)
|
||||
|
||||
super().__init__(result, error)
|
||||
|
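The precal_witness branch rebuilds every committed statistic from a JSON witness instead of the raw data, so a verifier-side model can be constructed without access to the dataset. From the indices read above, each Correlation_<i> entry appears to hold [result, x_mean, y_mean, x_std, y_std, cov] in that order, and op_dict tracks how many operations of each type precede this one so the right slot is selected. A hypothetical witness illustrating that assumed layout (values invented, schema inferred from the diff, not documented):

# Hypothetical contents; the layout is inferred from the [0]..[5] reads in __init__.
precal_witness = {
    "Correlation_0": [0.93, 4.5, 10.2, 2.1, 4.0, 7.8],  # [result, x_mean, y_mean, x_std, y_std, cov]
}
op_dict = {}  # no earlier Correlation ops, so the "_0" slot is read
key = "Correlation_0" if "Correlation" not in op_dict else f"Correlation_{op_dict['Correlation']}"
result, x_mean, y_mean, x_std, y_std, cov = precal_witness[key]
print(result)  # 0.93 -> becomes self.result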
     @classmethod
-    def create(cls, args: list[torch.Tensor], error: float) -> 'Correlation':
-        return cls(args[0], args[1], error)
+    def create(cls, args: list[torch.Tensor], error: float, precal_witness: dict = None, op_dict: dict = {}) -> 'Correlation':
+        return cls(args[0], args[1], error, precal_witness, op_dict)

     def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
         x, y = args[0], args[1]
         x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
         y_fil_0 = torch.where(y==MagicNumber, 0.0, y)
-        size_x = torch.sum((x!=MagicNumber).float())
-        size_y = torch.sum((y!=MagicNumber).float())
+        size_x = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
+        size_y = torch.sum(torch.where(y!=MagicNumber, 1.0, 0.0))
         x_mean_cons = torch.abs(torch.sum(x_fil_0)-size_x*(self.x_mean))<=torch.abs(self.error*self.x_mean*size_x)
         y_mean_cons = torch.abs(torch.sum(y_fil_0)-size_y*(self.y_mean))<=torch.abs(self.error*self.y_mean*size_y)
-        x_fil_mean = torch.where(x==MagicNumber, self.x_mean, x)
-        y_fil_mean = torch.where(y==MagicNumber, self.y_mean, y)
+        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.x_mean)
+        y_adj_mean = torch.where(y==MagicNumber, 0.0, y-self.y_mean)

         miscel_cons = torch.logical_and(size_x==size_y, torch.logical_and(x_mean_cons, y_mean_cons))
-        bool1, cov = covariance_for_corr(x_fil_mean, y_fil_mean, size_x, size_y, self.cov, self.x_mean, self.y_mean, self.error)
-        bool2, x_std = stdev_for_corr(x_fil_mean, size_x, self.x_std, self.x_mean, self.error)
-        bool3, y_std = stdev_for_corr(y_fil_mean, size_y, self.y_std, self.y_mean, self.error)
-        bool4 = torch.abs(cov - self.result*x_std*y_std)<=self.error*cov
+        bool1, cov = covariance_for_corr(x_adj_mean, y_adj_mean, size_x, self.cov, self.error)
+        bool2, x_std = stdev_for_corr(x_adj_mean, size_x, self.x_std, self.error)
+        bool3, y_std = stdev_for_corr(y_adj_mean, size_y, self.y_std, self.error)
+        # the correlation constraint itself: cov must match result * x_std * y_std
+        bool4 = torch.abs(cov - self.result*x_std*y_std)<=torch.abs(self.error*cov)
         return torch.logical_and(torch.logical_and(torch.logical_and(bool1, bool2), torch.logical_and(bool3, bool4)), miscel_cons)

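Note that ezkl never recomputes the correlation inside the circuit; it only checks that the committed pieces fit together: the two columns have equal size, the claimed means are consistent with the masked sums, the helpers validate std and cov, and finally cov must equal result * x_std * y_std within tolerance. A plain-torch rehearsal of that final identity on unmasked data (illustrative only; in the circuit these quantities are committed witnesses, not recomputed):

import torch

# Pearson correlation satisfies cov = r * std_x * std_y; the circuit checks this
# identity within a relative tolerance instead of recomputing r.
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
y = torch.tensor([2.0, 4.1, 5.9, 8.2, 9.9])
n = x.numel()
cov = torch.sum((x - x.mean()) * (y - y.mean())) / (n - 1)
x_std = torch.sqrt(torch.var(x, correction=1))
y_std = torch.sqrt(torch.var(y, correction=1))
r = cov / (x_std * y_std)
error = 0.01
bool4 = torch.abs(cov - r * x_std * y_std) <= torch.abs(error * cov)
print(r.item(), bool4.item())  # r close to 1.0, constraint satisfied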
@@ -422,34 +500,51 @@ def stacked_x(args: list[float]):

 class Regression(Operation):
-    def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float):
-        x_1ds = [to_1d(i) for i in xs]
-        fil_x_1ds = []
-        for x_1 in x_1ds:
-            fil_x_1ds.append((x_1[x_1!=MagicNumber]).tolist())
-        x_1ds = fil_x_1ds
-
-        y_1d = to_1d(y)
-        y_1d = (y_1d[y_1d!=MagicNumber]).tolist()
-
-        x_one = stacked_x(x_1ds)
-        result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
-        result = torch.tensor(result_1d, dtype=torch.float32).reshape(1, -1, 1)
-        print('result: ', result)
-        super().__init__(result, error)
+    def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float, precal_witness: dict = None, op_dict: dict = {}):
+        if precal_witness is None:
+            x_1ds = [to_1d(i) for i in xs]
+            fil_x_1ds = []
+            for x_1 in x_1ds:
+                fil_x_1ds.append((x_1[x_1!=MagicNumber]).tolist())
+            x_1ds = fil_x_1ds
+
+            y_1d = to_1d(y)
+            y_1d = (y_1d[y_1d!=MagicNumber]).tolist()
+
+            # closed-form least squares: (X^T X)^-1 X^T y
+            x_one = stacked_x(x_1ds)
+            result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
+            result = torch.tensor(result_1d, dtype=torch.float32).reshape(1, -1, 1)
+            super().__init__(result, error)
+        else:
+            if 'Regression' not in op_dict:
+                result = torch.tensor(precal_witness['Regression_0']).reshape(1, -1, 1)
+            else:
+                result = torch.tensor(precal_witness['Regression_' + str(op_dict['Regression'])]).reshape(1, -1, 1)
+            super().__init__(result, error)

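When no witness is supplied, the prover solves ordinary least squares in closed form: with stacked_x presumably appending a column of ones for the intercept, the coefficient vector is beta = (X^T X)^-1 X^T y, then reshaped to (1, k+1, 1) so it flows through the ezkl check as a tensor. A small NumPy sketch of that computation (the ones-column layout is assumed here, not taken from the library):

import numpy as np

# Closed-form least squares matching the matmul chain above; the ones column
# plays the role assumed for stacked_x (intercept term).
x1 = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([3.1, 5.0, 6.9, 9.2])                # roughly y = 2x + 1
x_one = np.stack([x1, np.ones_like(x1)], axis=1)  # shape (4, 2)
beta = np.linalg.inv(x_one.T @ x_one) @ x_one.T @ y
print(beta)  # approx [2.02, 1.00]: slope, then intercept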
     @classmethod
-    def create(cls, args: list[torch.Tensor], error: float) -> 'Regression':
+    def create(cls, args: list[torch.Tensor], error: float, precal_witness: dict = None, op_dict: dict = {}) -> 'Regression':
         xs = args[:-1]
         y = args[-1]
-        return cls(xs, y, error)
+        return cls(xs, y, error, precal_witness, op_dict)

     def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
         # infer y from the last argument
         y = args[-1]
-        y = torch.where(y==MagicNumber, torch.tensor(0.0), y)
+        y = torch.where(y==MagicNumber, 0.0, y)
         x_one = torch.cat((*args[:-1], torch.ones_like(args[0])), dim=2)
         x_one = torch.where((x_one[:,:,0]==MagicNumber).unsqueeze(-1), torch.tensor([0.0]*x_one.size()[2]), x_one)
         x_t = torch.transpose(x_one, 1, 2)
-        return torch.sum(torch.abs(x_t @ x_one @ self.result - x_t @ y)) <= self.error * torch.sum(torch.abs(x_t @ y))
+        left = x_t @ x_one @ self.result - x_t @ y
+        right = self.error * x_t @ y
+        # element-wise absolute value via torch.where (avoids torch.abs)
+        abs_left = torch.where(left>=0, left, -left)
+        abs_right = torch.where(right>=0, right, -right)
+        return torch.where(torch.sum(torch.where(abs_left<=abs_right, 1.0, 0.0))==torch.tensor(2.0), 1.0, 0.0)

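The rewritten check replaces torch.abs and a direct tensor comparison with torch.where arithmetic, which presumably lowers more cleanly through ONNX export into an ezkl circuit; the where-based absolute value is exactly equivalent. Note the comparison against the hard-coded torch.tensor(2.0): it counts how many rows of the (k+1)x1 normal-equation residual satisfy the bound, which matches one regressor plus an intercept; with more regressors the target would presumably need to be k+1 rather than 2.0. A small equivalence check (illustrative, not part of the diff):

import torch

left = torch.tensor([[0.1], [-0.2]])
right = torch.tensor([[0.5], [-0.5]])

# where-based absolute value agrees with torch.abs
abs_left = torch.where(left >= 0, left, -left)
print(torch.equal(abs_left, torch.abs(left)))  # True

# count rows with |left| <= |right| and require all (here k+1 = 2) to pass
abs_right = torch.where(right >= 0, right, -right)
ok = torch.where(torch.sum(torch.where(abs_left <= abs_right, 1.0, 0.0)) == torch.tensor(2.0), 1.0, 0.0)
print(ok)  # tensor(1.)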