Merge pull request #34 from ZKStats/tmp/recheck_func

Tmp/recheck func
This commit is contained in:
JernKunpittaya
2024-05-14 14:42:23 +07:00
committed by GitHub
48 changed files with 2661 additions and 2196 deletions


@@ -56,10 +56,10 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
# Compute the median of the second column
median2 = s.median(data[1])
# Compute the mean of the medians
return s.mean(torch.Tensor([median1, median2]).reshape(1, -1, 1))
return s.mean(torch.cat((median1.unsqueeze(0), median2.unsqueeze(0))).reshape(1,-1,1))
```
> NOTE: `reshape` is required for now since input must be in shape `[1, data_size, 1]` for now. It should be addressed in the future
> NOTE: `reshape` is required for now, since the input must be in the shape `[1, data_size, 1]`. This should be addressed in the future; the same goes for `torch.cat()` and `unsqueeze()`, for which we will write wrappers in the future.
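For illustration, a minimal standalone sketch (the scalar values are made up, standing in for the `s.median(...)` results) of how two scalar outputs end up in the required `[1, data_size, 1]` shape:
```python
import torch

# two scalar results, as returned by s.median(...) (values made up)
median1, median2 = torch.tensor(3.0), torch.tensor(5.0)

# unsqueeze each scalar to shape [1], concatenate into shape [2],
# then reshape into the required [1, data_size, 1] layout
combined = torch.cat((median1.unsqueeze(0), median2.unsqueeze(0)))
assert combined.reshape(1, -1, 1).shape == (1, 2, 1)
```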
#### Torch Operations
@@ -88,7 +88,7 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
### Proof Generation and Verification
The flow between data providers and users is as follows:
![zkstats-lib-flow](./assets/zkstats-lib.png)
![zkstats-lib-flow](./assets/zkstats-flow.png)
#### Data Provider: generate data commitments
@@ -109,12 +109,16 @@ When generating a proof, since dataset might contain floating points, data provi
#### Both: derive PyTorch model from the computation
When a user wants to request a data provider to generate a proof for their defined computation, the user must send the data provider first. Then, both the data provider and the user transform the model to necessary settings, respectively.
When a user wants to request a data provider to generate a proof for their defined computation, the user must let the data provider know what the computation is. The data provider, who holds the real dataset, then generates the model from the computation using the `computation_to_model()` method. Since we use the witness approach (described further in the Note section below), the data provider must send the pre-calculated witness back to the verifier. The verifier then uses this pre-calculated witness to generate, from the same computation, exactly the same model as the prover's.
Note that we could also let the prover generate the model and send it to the verifier directly. However, to make sure that the prover's model actually comes from the verifier's computation, it is better to have the verifier generate the model itself from its own computation, just with the help of the pre-calculated witness.
```python
from zkstats.core import computation_to_model
_, model = computation_to_model(user_computation)
# For prover: generate prover_model and write the pre-calculated witness to the precal_witness file
_, prover_model = computation_to_model(user_computation, precal_witness_path, True, error)
# For verifier: generate verifier_model (identical to prover_model) by reading the precal_witness file
_, verifier_model = computation_to_model(user_computation, precal_witness_path, False, error)
```
#### Data Provider: generate settings
@@ -204,14 +208,14 @@ See our jupyter notebook for [examples](./examples/).
## Benchmarks
TOFIX: Update the benchmark method. See more in issues.
See our jupyter notebook for [benchmarks](./benchmark/).
TODO: clean benchmark
## Note
- We implement the witness approach instead of directly calculating the value in the circuit. This sometimes allows us to avoid in-circuit operations such as division or exponentiation, which would require a larger scale in the settings. (If we did not use a larger scale in those cases, the accuracy would be very poor.) See the first sketch below this list.
- Dummy data fed into the verifier ONNX file needs to have the same shape as the private dataset, but can be filled with any values (we just randomize it uniformly over 1-10 with 1 decimal place; see the second sketch below this list).
- For Mode function, if there are more than 1 value possible, we just output one of them (the one that first encountered), conforming to the spec of statistics.mode in python lib (https://docs.python.org/3.9/library/statistics.html#statistics.mode)
- For the Mode function, if more than one value is possible, we just output the one that is encountered first, conforming to the spec of statistics.mode in the Python standard library (https://docs.python.org/3.9/library/statistics.html#statistics.mode)
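To make the witness approach concrete, here is a minimal sketch (data values made up) of the kind of check the circuit performs for the mean, mirroring the `Mean.ezkl` constraint in `zkstats/ops.py`: the claimed result enters as a pre-calculated witness, and the circuit validates it against the sum using only multiplication and comparison, never division:
```python
import torch

error = 0.01
x = torch.tensor([2.0, 4.0, 9.0])      # private data (made-up values)
result = torch.mean(x)                 # pre-calculated witness: the claimed mean
size = torch.tensor(float(x.numel()))

# in-circuit check: no division required
is_precise = torch.abs(torch.sum(x) - size * result) <= torch.abs(error * result * size)
assert bool(is_precise)
```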
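And a minimal sketch of generating such dummy data (the shape is a made-up example; it must match the private dataset's shape):
```python
import torch

shape = (1, 300, 1)  # must match the private dataset's shape (made-up size)
# any values work; this mimics uniform 1-10 with 1 decimal place
dummy = torch.randint(10, 101, shape).float() / 10.0
```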
## Legacy

BIN assets/zkstats-flow.png (new binary file, not shown; after: 105 KiB)
BIN (binary file, not shown; before: 115 KiB)


@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6, 7],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}

File diff suppressed because one or more lines are too long


@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}

File diff suppressed because one or more lines are too long


@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}

Multiple file diffs suppressed because one or more lines are too long


@@ -1,8 +0,0 @@
{
"x": [
0, 1, 2, 3, 4
],
"y": [
2.0, 5.2, 47.4, 23.6, 24.8
]
}

Multiple file diffs suppressed because one or more lines are too long


@@ -1,38 +0,0 @@
{
"col_1": [
23.2, 92.8, 91.0, 37.2, 82.0, 15.5, 79.3, 46.6, 98.1, 75.5, 78.9, 77.6,
33.8, 75.7, 96.8, 12.3, 18.4, 13.4, 6.0, 8.2, 25.8, 41.3, 68.5, 15.2, 74.7,
72.7, 18.0, 42.2, 36.1, 76.7, 1.2, 96.4, 4.9, 92.0, 12.8, 28.2, 61.8, 56.9,
44.3, 50.4, 81.6, 72.5, 12.9, 40.3, 12.8, 28.8, 36.3, 16.1, 68.4, 35.3,
79.2, 48.4, 97.1, 93.7, 77.0, 48.7, 93.7, 54.1, 65.4, 30.8, 34.4, 31.4,
78.7, 12.7, 90.7, 39.4, 86.0, 55.9, 6.8, 22.2, 65.3, 18.8, 7.1, 55.9, 38.6,
15.6, 59.2, 77.3, 76.9, 11.9, 19.9, 19.4, 54.3, 39.4, 4.0, 61.1, 16.8, 81.9,
49.3, 76.9, 19.2, 68.2, 54.4, 70.2, 89.8, 23.4, 67.5, 18.7, 10.8, 80.7,
80.3, 96.2, 62.3, 17.2, 23.0, 98.0, 19.1, 8.1, 36.2, 7.5, 55.9, 1.2, 56.8,
85.1, 18.9, 23.0, 13.5, 64.3, 9.1, 14.1, 14.1, 23.1, 73.2, 86.6, 39.1, 45.5,
85.0, 79.0, 15.8, 5.2, 81.5, 34.3, 24.3, 14.2, 84.6, 33.7, 86.3, 83.3, 62.8,
72.7, 14.7, 36.8, 92.5, 4.7, 30.0, 59.4, 57.6, 37.4, 22.0, 20.9, 61.6, 26.8,
47.1, 63.6, 6.0, 96.6, 61.2, 80.2, 59.3, 23.1, 29.3, 46.3, 89.2, 77.6, 83.2,
87.2, 63.2, 81.8, 55.0, 59.7, 57.8, 43.4, 92.4, 66.9, 82.1, 51.0, 22.1,
29.9, 41.0, 85.2, 61.5, 14.6, 48.0, 52.7, 31.4, 83.9, 35.5, 77.3, 35.8,
32.6, 22.2, 19.3, 49.1, 70.9, 43.9, 88.8, 56.3, 41.8, 90.3, 20.4, 80.4,
36.4, 91.5, 69.6, 75.3, 92.4, 84.8, 17.7, 2.3, 41.3, 91.3, 68.6, 73.3, 62.5,
60.5, 73.5, 70.7, 77.5, 76.8, 98.1, 40.9, 66.3, 8.6, 48.9, 75.4, 14.7, 35.9,
89.6, 15.1, 45.0, 77.6, 30.5, 76.1, 46.9, 34.3, 65.1, 43.9, 91.6, 88.8, 8.9,
42.9, 11.8, 32.1, 20.1, 48.9, 79.7, 15.3, 45.4, 80.1, 73.1, 76.5, 52.4, 9.6,
41.9, 52.7, 55.1, 30.9, 83.7, 46.7, 39.3, 40.5, 52.4, 19.2, 25.8, 52.7,
81.0, 38.0, 54.5, 15.3, 64.3, 88.3, 49.8, 90.5, 90.4, 79.7, 87.3, 32.3,
11.9, 5.7, 33.6, 75.1, 65.9, 29.1, 39.4, 87.5, 3.3, 66.3, 79.0, 97.9, 69.6,
22.0, 62.8, 97.1, 90.4, 39.5, 11.7, 30.3, 18.9, 34.6, 6.6
],
"col_2": [
19.2, 54.1, 16.5, 24.8, 42.7, 18.9, 78.8, 54.4, 27.4, 76.2, 43.4, 20.9, 2.9,
30.4, 21.4, 2.0, 5.6, 33.5, 4.8, 4.7, 57.5, 23.5, 40.1, 83.1, 78.9, 95.1,
41.1, 59.0, 59.2, 91.1, 20.9, 67.6, 44.1, 91.3, 89.9, 85.7, 92.6, 67.1,
90.0, 29.5, 40.9, 96.8, 2.3, 57.9, 93.2, 83.9, 10.4, 75.1, 24.2, 22.9, 21.2,
26.9, 96.8, 89.0, 68.0, 16.1, 90.1, 1.7, 79.6, 98.5, 21.3, 79.5, 9.2, 97.9,
21.6, 4.2, 66.1, 53.8, 79.5, 60.6, 66.9, 39.5, 50.1, 66.1, 96.4, 80.5, 61.9,
44.4, 84.8, 64.8, 23.2, 7.1, 21.1, 90.5, 29.2, 1.4, 54.8, 9.8, 41.1, 45.2,
56.6, 48.2, 61.3, 62.9, 2.7, 33.2, 62.5, 40.9, 33.6, 50.1
]
}

Multiple file diffs suppressed because one or more lines are too long


@@ -1,28 +1,8 @@
{
"col_name": [
23.2, 92.8, 91.0, 37.2, 82.0, 15.5, 79.3, 46.6, 98.1, 75.5, 78.9, 77.6,
33.8, 75.7, 96.8, 12.3, 18.4, 13.4, 6.0, 8.2, 25.8, 41.3, 68.5, 15.2, 74.7,
72.7, 18.0, 42.2, 36.1, 76.7, 1.2, 96.4, 4.9, 92.0, 12.8, 28.2, 61.8, 56.9,
44.3, 50.4, 81.6, 72.5, 12.9, 40.3, 12.8, 28.8, 36.3, 16.1, 68.4, 35.3,
79.2, 48.4, 97.1, 93.7, 77.0, 48.7, 93.7, 54.1, 65.4, 30.8, 34.4, 31.4,
78.7, 12.7, 90.7, 39.4, 86.0, 55.9, 6.8, 22.2, 65.3, 18.8, 7.1, 55.9, 38.6,
15.6, 59.2, 77.3, 76.9, 11.9, 19.9, 19.4, 54.3, 39.4, 4.0, 61.1, 16.8, 81.9,
49.3, 76.9, 19.2, 68.2, 54.4, 70.2, 89.8, 23.4, 67.5, 18.7, 10.8, 80.7,
80.3, 96.2, 62.3, 17.2, 23.0, 98.0, 19.1, 8.1, 36.2, 7.5, 55.9, 1.2, 56.8,
85.1, 18.9, 23.0, 13.5, 64.3, 9.1, 14.1, 14.1, 23.1, 73.2, 86.6, 39.1, 45.5,
85.0, 79.0, 15.8, 5.2, 81.5, 34.3, 24.3, 14.2, 84.6, 33.7, 86.3, 83.3, 62.8,
72.7, 14.7, 36.8, 92.5, 4.7, 30.0, 59.4, 57.6, 37.4, 22.0, 20.9, 61.6, 26.8,
47.1, 63.6, 6.0, 96.6, 61.2, 80.2, 59.3, 23.1, 29.3, 46.3, 89.2, 77.6, 83.2,
87.2, 63.2, 81.8, 55.0, 59.7, 57.8, 43.4, 92.4, 66.9, 82.1, 51.0, 22.1,
29.9, 41.0, 85.2, 61.5, 14.6, 48.0, 52.7, 31.4, 83.9, 35.5, 77.3, 35.8,
32.6, 22.2, 19.3, 49.1, 70.9, 43.9, 88.8, 56.3, 41.8, 90.3, 20.4, 80.4,
36.4, 91.5, 69.6, 75.3, 92.4, 84.8, 17.7, 2.3, 41.3, 91.3, 68.6, 73.3, 62.5,
60.5, 73.5, 70.7, 77.5, 76.8, 98.1, 40.9, 66.3, 8.6, 48.9, 75.4, 14.7, 35.9,
89.6, 15.1, 45.0, 77.6, 30.5, 76.1, 46.9, 34.3, 65.1, 43.9, 91.6, 88.8, 8.9,
42.9, 11.8, 32.1, 20.1, 48.9, 79.7, 15.3, 45.4, 80.1, 73.1, 76.5, 52.4, 9.6,
41.9, 52.7, 55.1, 30.9, 83.7, 46.7, 39.3, 40.5, 52.4, 19.2, 25.8, 52.7,
81.0, 38.0, 54.5, 15.3, 64.3, 88.3, 49.8, 90.5, 90.4, 79.7, 87.3, 32.3,
11.9, 5.7, 33.6, 75.1, 65.9, 29.1, 39.4, 87.5, 3.3, 66.3, 79.0, 97.9, 69.6,
22.0, 62.8, 97.1, 90.4, 39.5, 11.7, 30.3, 18.9, 34.6, 6.6
46.2, 40.4, 44.8, 48.1, 51.2, 91.9, 38.2, 36.3, 22.2, 11.5, 17.9, 20.2,
99.9, 75.2, 29.8, 19.4, 46.1, 94.8, 6.6, 94.5, 99.7, 1.6, 4.0, 86.7, 28.7,
63.0, 66.7, 2.5, 41.4, 35.6, 45.0, 44.8, 9.6, 16.6, 9.8, 20.3, 25.9, 71.9,
27.5, 30.9, 62.9, 44.8, 45.7, 2.4, 91.4, 16.2, 61.5, 41.4, 77.1, 44.8
]
}

Multiple file diffs suppressed because one or more lines are too long


@@ -1,29 +1,8 @@
{
"col_name": [
15.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0, 41.0,
47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0, 51.0,
48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0, 51.0,
56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0, 62.0,
39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0, 52.0,
45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0, 45.0,
45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0, 34.0,
66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0, 30.0,
33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0, 62.0,
38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0, 52.0,
41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0, 52.0,
58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0, 64.0,
46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0, 69.0,
27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0, 55.0,
58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0, 45.0,
40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0, 39.0,
58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0, 34.0,
50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0, 22.0,
50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0, 44.0,
55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0, 43.0,
50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0, 41.0,
41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0, 64.0,
76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0, 50.0,
31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0, 40.0,
25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
46.2, 40.4, 44.8, 48.1, 51.2, 91.9, 38.2, 36.3, 22.2, 11.5, 17.9, 20.2,
99.9, 75.2, 29.8, 19.4, 46.1, 94.8, 6.6, 94.5, 99.7, 1.6, 4.0, 86.7, 28.7,
63.0, 66.7, 2.5, 41.4, 35.6, 45.0, 44.8, 9.6, 16.6, 9.8, 20.3, 25.9, 71.9,
27.5, 30.9, 62.9, 44.8, 45.7, 2.4, 91.4, 16.2, 61.5, 41.4, 77.1, 44.8
]
}

Multiple file diffs suppressed because one or more lines are too long


@@ -24,4 +24,4 @@ def column_2():
@pytest.fixture
def scales():
return [6]
return [7]


@@ -1,11 +1,11 @@
import json
from typing import Type, Sequence, Optional
from typing import Type, Sequence, Optional, Callable
from pathlib import Path
import torch
from zkstats.core import prover_gen_settings, setup, prover_gen_proof, verifier_verify, generate_data_commitment
from zkstats.computation import IModel
from zkstats.core import create_dummy,prover_gen_settings, setup, prover_gen_proof, verifier_verify, generate_data_commitment, verifier_define_calculation
from zkstats.computation import IModel, State, computation_to_model
DEFAULT_POSSIBLE_SCALES = list(range(20))
@@ -22,17 +22,20 @@ def data_to_json_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, li
column: d.tolist()
for column, d in zip(column_names, data)
}
print('columnnnn: ', column_to_data)
with open(data_path, "w") as f:
json.dump(column_to_data, f)
return column_to_data
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
def compute(
basepath: Path,
data: list[torch.Tensor],
model: Type[IModel],
# computation: TComputation,
scales_params: Optional[Sequence[int]] = None,
selected_columns_params: Optional[list[str]] = None,
# error:float = 1.0
) -> None:
sel_data_path = basepath / "comb_data.json"
model_path = basepath / "model.onnx"
@@ -60,43 +63,21 @@ def compute(
else:
scales = scales_params
scales_for_commitments = scales_params
# create_dummy((data_path), (dummy_data_path))
generate_data_commitment((data_path), scales_for_commitments, (data_commitment_path))
# _, prover_model = computation_to_model(computation, (precal_witness_path), True, error)
generate_data_commitment(data_path, scales_for_commitments, data_commitment_path)
prover_gen_settings((data_path), selected_columns, (sel_data_path), model, (model_path), scales, "resources", (settings_path))
prover_gen_settings(
data_path=data_path,
selected_columns=selected_columns,
sel_data_path=str(sel_data_path),
prover_model=model,
prover_model_path=str(model_path),
scale=scales,
mode="resources",
settings_path=str(settings_path),
)
# No need, since verifier & prover share the same onnx
# _, verifier_model = computation_to_model(computation, (precal_witness_path), False,error)
# verifier_define_calculation((dummy_data_path), selected_columns, (sel_dummy_data_path),verifier_model, (verifier_model_path))
setup(
str(model_path),
str(compiled_model_path),
str(settings_path),
str(vk_path),
str(pk_path),
)
prover_gen_proof(
str(model_path),
str(sel_data_path),
str(witness_path),
str(compiled_model_path),
str(settings_path),
str(proof_path),
str(pk_path),
)
verifier_verify(
str(proof_path),
str(settings_path),
str(vk_path),
selected_columns,
data_commitment_path,
)
setup((model_path), (compiled_model_path), (settings_path),(vk_path), (pk_path ))
prover_gen_proof((model_path), (sel_data_path), (witness_path), (compiled_model_path), (settings_path), (proof_path), (pk_path))
# print('slett col: ', selected_columns)
verifier_verify((proof_path), (settings_path), (vk_path), selected_columns, (data_commitment_path))
# Error tolerance between zkstats python implementation and python statistics module


@@ -40,33 +40,35 @@ def nested_computation(state: State, args: list[torch.Tensor]):
out_9 = state.correlation(y, z)
out_10 = state.linear_regression(x, y)
slope, intercept = out_10[0][0][0], out_10[0][1][0]
reshaped = torch.tensor([
out_0,
out_1,
out_2,
out_3,
out_4,
out_5,
out_6,
out_7,
out_8,
out_9,
slope,
intercept,
]).reshape(1,-1,1)
reshaped = torch.cat((
out_0.unsqueeze(0),
out_1.unsqueeze(0),
out_2.unsqueeze(0),
out_3.unsqueeze(0),
out_4.unsqueeze(0),
out_5.unsqueeze(0),
out_6.unsqueeze(0),
out_7.unsqueeze(0),
out_8.unsqueeze(0),
out_9.unsqueeze(0),
slope.unsqueeze(0),
intercept.unsqueeze(0),
)).reshape(1,-1,1)
out_10 = state.mean(reshaped)
return out_10
@pytest.mark.parametrize(
"error",
[0.1],
[ERROR_CIRCUIT_DEFAULT],
)
def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, column_2: torch.Tensor, error, scales):
state, model = computation_to_model(nested_computation, error)
precal_witness_path = tmp_path / "precal_witness_path.json"
state, model = computation_to_model(nested_computation, precal_witness_path,True, error)
x, y, z = column_0, column_1, column_2
compute(tmp_path, [x, y, z], model, scales)
# There are 11 ops in the computation
assert state.current_op_index == 12
ops = state.ops
@@ -157,8 +159,8 @@ def test_computation_with_where_1d(tmp_path, error, column_0, op_type: Callable[
def where_and_op(state: State, args: list[torch.Tensor]):
x = args[0]
return op_type(state, state.where(condition(x), x))
state, model = computation_to_model(where_and_op, error)
precal_witness_path = tmp_path / "precal_witness_path.json"
state, model = computation_to_model(where_and_op, precal_witness_path,True, error)
compute(tmp_path, [column], model, scales)
res_op = state.ops[-1]
@@ -185,8 +187,8 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
filtered_x = state.where(condition_x, x)
filtered_y = state.where(condition_x, y)
return op_type(state, filtered_x, filtered_y)
state, model = computation_to_model(where_and_op, error)
precal_witness_path = tmp_path / "precal_witness_path.json"
state, model = computation_to_model(where_and_op, precal_witness_path, True ,error)
compute(tmp_path, [column_0, column_1], model, scales)
res_op = state.ops[-1]


@@ -71,14 +71,14 @@ def test_integration_select_partial_columns(tmp_path, column_0, column_1, error,
def simple_computation(state, x):
return state.mean(x[0])
_, model = computation_to_model(simple_computation, error)
precal_witness_path = tmp_path / "precal_witness_path.json"
_, model = computation_to_model(simple_computation,precal_witness_path, True, error)
# gen settings, setup, prove, verify
compute(tmp_path, [column_0, column_1], model, scales, selected_columns)
def test_csv_data(tmp_path, column_0, column_1, error, scales):
data_json_path = tmp_path / "data.csv"
data_json_path = tmp_path / "data.json"
data_csv_path = tmp_path / "data.csv"
data_json = data_to_json_file(data_json_path, [column_0, column_1])
json_file_to_csv(data_json_path, data_csv_path)
@@ -92,12 +92,13 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
model_path = tmp_path / "model.onnx"
settings_path = tmp_path / "settings.json"
data_commitment_path = tmp_path / "commitments.json"
precal_witness_path = tmp_path / "precal_witness.json"
# Test: `generate_data_commitment` works with csv
generate_data_commitment(data_csv_path, scales, data_commitment_path)
# Test: `prover_gen_settings` works with csv
_, model_for_proving = computation_to_model(simple_computation, error)
_, model_for_proving = computation_to_model(simple_computation, precal_witness_path, True,error)
prover_gen_settings(
data_path=data_csv_path,
selected_columns=selected_columns,
@@ -111,10 +112,9 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
# Test: `prover_gen_settings` works with csv
# Instantiate the model for verification since the state of `model_for_proving` is changed after `prover_gen_settings`
_, model_for_verification = computation_to_model(simple_computation, error)
_, model_for_verification = computation_to_model(simple_computation, precal_witness_path, False,error)
verifier_define_calculation(data_csv_path, selected_columns, str(sel_data_path), model_for_verification, str(model_path))
def json_file_to_csv(data_json_path, data_csv_path):
with open(data_json_path, "r") as f:
data_from_json = json.load(f)
@@ -154,4 +154,4 @@ def test__preprocess_data_file_to_json(tmp_path, column_0, column_1):
_preprocess_data_file_to_json(data_json_path, new_data_json_path)
with open(new_data_json_path, "r") as f:
new_data_from_json = json.load(f)
assert new_data_from_json == data_from_json
assert new_data_from_json == data_from_json


@@ -5,7 +5,7 @@ import pytest
import torch
from zkstats.ops import Mean, Median, GeometricMean, HarmonicMean, Mode, PStdev, PVariance, Stdev, Variance, Covariance, Correlation, Operation, Regression
from zkstats.computation import IModel, IsResultPrecise
from zkstats.computation import IModel, IsResultPrecise, State, computation_to_model
from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
@@ -44,7 +44,7 @@ def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tens
@pytest.mark.parametrize(
"error",
[
ERROR_CIRCUIT_RELAXED
ERROR_CIRCUIT_DEFAULT
]
)
def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, scales: list[float]):
@@ -61,6 +61,7 @@ def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Ten
compute(tmp_path, columns, Model, scales)
def run_test_ops(tmp_path, op_type: Type[Operation], expected_func: Callable[[list[float]], float], error: float, scales: list[float], columns: list[torch.Tensor]):
op = op_type.create(columns, error)
expected_res = expected_func(*[column.tolist() for column in columns])


@@ -3,6 +3,7 @@ from typing import Callable, Type, Optional, Union
import torch
from torch import nn
import json
from .ops import (
Operation,
@@ -18,12 +19,12 @@ from .ops import (
Covariance,
Correlation,
Regression,
Where,
IsResultPrecise,
)
DEFAULT_ERROR = 0.01
MagicNumber = 99.999
class State:
@@ -43,6 +44,10 @@ class State:
self.error: float = error
# Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3.
self.current_op_index: Optional[int] = None
self.precal_witness_path: str = None  # path of the pre-calculated witness JSON file
self.precal_witness: dict = {}  # op key (e.g. "Mean_0") -> witness values, written by the prover
self.isProver: bool = None  # True on the prover side, False on the verifier side
self.op_dict: dict = {}  # op class name -> count of ops of that type seen so far
def set_ready_for_exporting_onnx(self) -> None:
self.current_op_index = 0
@@ -133,19 +138,82 @@ class State:
return self._call_op([x, y], Regression)
# WHERE operation
def where(self, filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the where operation of x. The behavior should conform to `torch.where` in PyTorch.
:param filter: A boolean tensor serves as a filter
:param _filter: A boolean tensor serves as a filter
:param x: A tensor to be filtered
:return: filtered tensor
"""
return self._call_op([filter, x], Where)
return torch.where(_filter, x, x-x+MagicNumber)  # filtered-out entries become MagicNumber; x-x keeps the branch a tensor shaped like x
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
if self.current_op_index is None:
op = op_type.create(x, self.error)
# for prover
if self.isProver:
# print('Prover side create')
op = op_type.create(x, self.error)
# single-witness ops: the only witness value is the result itself
if isinstance(op,Mean) or isinstance(op,GeometricMean) or isinstance(op, HarmonicMean) or isinstance(op, Mode):
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
if op_class_str not in self.op_dict:
self.precal_witness[op_class_str+"_0"] = [op.result.data.item()]
self.op_dict[op_class_str] = 1
else:
self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item()]
self.op_dict[op_class_str]+=1
elif isinstance(op, Median):
if 'Median' not in self.op_dict:
self.precal_witness['Median_0'] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
self.op_dict['Median']=1
else:
self.precal_witness['Median_'+str(self.op_dict['Median'])] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
self.op_dict['Median']+=1
# standard deviation and variance ops
elif isinstance(op, PStdev) or isinstance(op, PVariance) or isinstance(op, Stdev) or isinstance(op, Variance):
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
if op_class_str not in self.op_dict:
self.precal_witness[op_class_str+"_0"] = [op.result.data.item(), op.data_mean.data.item()]
self.op_dict[op_class_str] = 1
else:
self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item(), op.data_mean.data.item()]
self.op_dict[op_class_str]+=1
elif isinstance(op, Covariance):
if 'Covariance' not in self.op_dict:
self.precal_witness['Covariance_0'] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item()]
self.op_dict['Covariance']=1
else:
self.precal_witness['Covariance_'+str(self.op_dict['Covariance'])] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item()]
self.op_dict['Covariance']+=1
elif isinstance(op, Correlation):
if 'Correlation' not in self.op_dict:
self.precal_witness['Correlation_0'] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item(), op.x_std.data.item(), op.y_std.data.item(), op.cov.data.item()]
self.op_dict['Correlation']=1
else:
self.precal_witness['Correlation_'+str(self.op_dict['Correlation'])] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item(), op.x_std.data.item(), op.y_std.data.item(), op.cov.data.item()]
self.op_dict['Correlation']+=1
elif isinstance(op, Regression):
result_array = []
for ele in op.result.data[0]:
result_array.append(ele[0].item())
if 'Regression' not in self.op_dict:
self.precal_witness['Regression_0'] = [result_array]
self.op_dict['Regression']=1
else:
self.precal_witness['Regression_'+str(self.op_dict['Regression'])] = [result_array]
self.op_dict['Regression']+=1
# for verifier
else:
# print('Verifier side create')
precal_witness = json.loads(open(self.precal_witness_path, "r").read())
op = op_type.create(x, self.error, precal_witness, self.op_dict)
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
if op_class_str not in self.op_dict:
self.op_dict[op_class_str] = 1
else:
self.op_dict[op_class_str]+=1
self.ops.append(op)
return op.result
else:
@@ -171,7 +239,9 @@ class State:
# If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result)
# else, return only result
if current_op_index == len_ops - 1:
# print('final op: ', op)
# Sanity check for length of self.ops and self.bools
len_bools = len(self.bools)
if len_ops != len_bools:
@@ -180,13 +250,15 @@ class State:
for i in range(len_bools):
res = self.bools[i]()
is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
return is_precise_aggregated, op.result
if self.isProver:
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]  # the (x[0]-x[0]) term ties the output to the model input
elif current_op_index > len_ops - 1:
# Sanity check that current op index does not exceed the length of ops
raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
else:
# It's not the last operation, just return the result
return op.result
return op.result+(x[0]-x[0])[0][0][0]
class IModel(nn.Module):
@@ -207,7 +279,7 @@ class IModel(nn.Module):
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
def computation_to_model(computation: TComputation, precal_witness_path:str, isProver:bool ,error: float = DEFAULT_ERROR ) -> tuple[State, Type[IModel]]:
"""
Create a torch model from a `computation` function defined by user
:param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
@@ -216,6 +288,10 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
State is a container for intermediate results of computation, which can be useful when debugging.
"""
state = State(error)
# if it's verifier
state.precal_witness_path= precal_witness_path
state.isProver = isProver
class Model(IModel):
def preprocess(self, x: list[torch.Tensor]) -> None:
@@ -223,6 +299,12 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
state.set_ready_for_exporting_onnx()
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return computation(state, x)
# print('x sy: ')
result = computation(state, x)
if len(result) == 1:
# the computation returned only a result tensor: prepend a trivially-true
# precision flag, derived from the input so that it survives ONNX export
return x[0][0][0][0]-x[0][0][0][0]+torch.tensor(1.0), result
else:
return result
# print('state:: ', state.aggregate_witness_path)
return state, Model
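For reference, a hypothetical example (all values made up) of what the prover writes to `precal_witness_path`, following the key/value layout built up in `_call_op` above; the numeric suffix counts repeated ops of the same type:
```python
precal_witness = {
    "Mean_0": [42.5],                                      # [result]
    "Median_0": [41.0, 40.0, 42.0],                        # [result, lower, upper]
    "Covariance_0": [82.4, 42.5, 39.1],                    # [result, x_mean, y_mean]
    "Correlation_0": [0.83, 42.5, 39.1, 10.2, 9.7, 82.4],  # [result, x_mean, y_mean, x_std, y_std, cov]
    "Regression_0": [[1.7, 12.3]],                         # [[slope(s)..., intercept]]
}
```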


@@ -451,4 +451,4 @@ def _get_commitment_for_column(column: list[float], scale: int) -> str:
res_poseidon_hash = ezkl.poseidon_hash(serialized_data)[0]
# res_hex = ezkl.vecu64_to_felt(res_poseidon_hash[0])
return res_poseidon_hash
return res_poseidon_hash


@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod, abstractclassmethod
import statistics
from typing import Optional
import numpy as np
import torch
# boolean: either 1.0 or 0.0
IsResultPrecise = torch.Tensor
MagicNumber = 9999999
MagicNumber = 99.999  # sentinel value marking entries filtered out by where()
class Operation(ABC):
@@ -23,26 +24,29 @@ class Operation(ABC):
...
class Where(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Where':
# the error argument is unused here; it is kept to conform to the other ops' signatures
return cls(torch.where(x[0],x[1], MagicNumber ),error)
def ezkl(self, x:list[torch.Tensor]) -> IsResultPrecise:
bool_array = torch.logical_or(x[1]==self.result, torch.logical_and(torch.logical_not(x[0]), self.result==MagicNumber))
# print('sellll: ', self.result)
return torch.sum(bool_array.float())==x[1].size()[1]
class Mean(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Mean':
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ) -> 'Mean':
# support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
if precal_witness is None:
# this is prover
# print('provvv')
return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
else:
# this is verifier
# print('verrrr')
if op_dict is None:
return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
elif 'Mean' not in op_dict:
return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
else:
return cls(torch.tensor(precal_witness['Mean_'+str(op_dict['Mean'])][0]), error)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]
size = torch.sum((x!=MagicNumber).float())
size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
x = torch.where(x==MagicNumber, 0.0, x)
return torch.abs(torch.sum(x)-size*self.result)<=torch.abs(self.error*self.result*size)
@@ -59,83 +63,113 @@ def to_1d(x: torch.Tensor) -> torch.Tensor:
class Median(Operation):
def __init__(self, x: torch.Tensor, error: float):
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ):
if precal_witness is None:
# NOTE: To ensure `lower` and `upper` are a scalar, `x` must be a 1d array.
# Otherwise, if `x` is a 3d array, `lower` and `upper` will be 2d array, which are not what
# we want in our context. However, we tend to have x as a `[1, len(x), 1]`. In this case,
# we need to flatten `x` to 1d array to get the correct `lower` and `upper`.
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
super().__init__(torch.tensor(np.median(x_1d)), error)
sorted_x = np.sort(x_1d)
len_x = len(x_1d)
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
super().__init__(torch.tensor(np.median(x_1d)), error)
sorted_x = np.sort(x_1d)
len_x = len(x_1d)
self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
else:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
elif 'Median' not in op_dict:
super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
else:
super().__init__(torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][0]), error)
self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][1]), requires_grad=False)
self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][2]), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Median':
return cls(x[0], error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None ) -> 'Median':
return cls(x[0],error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]
old_size = x.size()[1]
size = torch.sum((x!=MagicNumber).float())
size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
min_x = torch.min(x)
x = torch.where(x==MagicNumber,min_x-1, x)
# values within 1% (the error tolerance) are regarded as the same value
count_less = torch.sum((x < self.result).float())-(old_size-size)
count_equal = torch.sum((x==self.result).float())
count_less = torch.sum(torch.where(x < self.result, 1.0, 0.0))-(old_size-size)
count_equal = torch.sum(torch.where(x==self.result, 1.0, 0.0))
half_size = torch.floor(torch.div(size, 2))
# print('hhhh: ', half_size)
less_cons = count_less<half_size+size%2
more_cons = count_less+count_equal>half_size
# For count_equal == 0
lower_exist = torch.sum((x==self.lower).float())>0
lower_cons = torch.sum((x>self.lower).float())==half_size
upper_exist = torch.sum((x==self.upper).float())>0
upper_cons = torch.sum((x<self.upper).float())==half_size
lower_exist = torch.sum(torch.where(x==self.lower, 1.0, 0.0))>0
lower_cons = torch.sum(torch.where(x>self.lower, 1.0, 0.0))==half_size
upper_exist = torch.sum(torch.where(x==self.upper, 1.0, 0.0))>0
upper_cons = torch.sum(torch.where(x<self.upper, 1.0, 0.0))==half_size
bound = count_less== half_size
# 0.02 since 2*0.01
bound_avg = (torch.abs(self.lower+self.upper-2*self.result)<=torch.abs(2*self.error*self.result))
median_in_cons = torch.logical_and(less_cons, more_cons)
median_out_cons = torch.logical_and(torch.logical_and(bound, bound_avg), torch.logical_and(torch.logical_and(lower_cons, upper_cons), torch.logical_and(lower_exist, upper_exist)))
return torch.where(count_equal==0, median_out_cons, median_in_cons)
return torch.where(count_equal==0.0, median_out_cons, median_in_cons)
class GeometricMean(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'GeometricMean':
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
result = torch.exp(torch.mean(torch.log(x_1d)))
return cls(result, error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'GeometricMean':
if precal_witness is None:
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
result = torch.exp(torch.mean(torch.log(x_1d)))
return cls(result, error)
else:
if op_dict is None:
return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
elif 'GeometricMean' not in op_dict:
return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
else:
return cls(torch.tensor(precal_witness['GeometricMean_'+str(op_dict['GeometricMean'])][0]), error)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
# Assume x is [1, n, 1]
x = x[0]
size = torch.sum((x!=MagicNumber).float())
size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
x = torch.where(x==MagicNumber, 1.0, x)
return torch.abs((torch.log(self.result)*size)-torch.sum(torch.log(x)))<=size*torch.log(torch.tensor(1+self.error))
class HarmonicMean(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'HarmonicMean':
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
result = torch.div(1.0,torch.mean(torch.div(1.0, x_1d)))
return cls(result, error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'HarmonicMean':
if precal_witness is None:
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
result = torch.div(1.0,torch.mean(torch.div(1.0, x_1d)))
return cls(result, error)
else:
if op_dict is None:
return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
elif 'HarmonicMean' not in op_dict:
return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
else:
return cls(torch.tensor(precal_witness['HarmonicMean_'+str(op_dict['HarmonicMean'])][0]), error)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
# Assume x is [1, n, 1]
x = x[0]
size = torch.sum((x!=MagicNumber).float())
# just make it really big so that 1/x goes to zero for element that gets filtered out
x = torch.where(x==MagicNumber, x*x, x)
return torch.abs((self.result*torch.sum(torch.div(1.0, x))) - size)<=torch.abs(self.error*size)
size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
return torch.abs((self.result*torch.sum(torch.where(x==MagicNumber, 0.0, torch.div(1.0, x)))) - size)<=torch.abs(self.error*size)
def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
@@ -186,12 +220,21 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
class Mode(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Mode':
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
# Here is traditional definition of Mode, can just put this num_error to be 0
result = torch.tensor(mode_within(x_1d, 0))
return cls(result, error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Mode':
if precal_witness is None:
x_1d = to_1d(x[0])
x_1d = x_1d[x_1d!=MagicNumber]
# This is the traditional definition of Mode, so num_error can just be set to 0
result = torch.tensor(mode_within(x_1d, 0))
return cls(result, error)
else:
if op_dict is None:
return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
elif 'Mode' not in op_dict:
return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
else:
return cls(torch.tensor(precal_witness['Mode_'+str(op_dict['Mode'])][0]), error)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
# Assume x is [1, n, 1]
@@ -199,189 +242,281 @@ class Mode(Operation):
min_x = torch.min(x)
old_size = x.size()[1]
x = torch.where(x==MagicNumber, min_x-1, x)
count_equal = torch.sum((x==self.result).float())
result = torch.tensor([torch.logical_or(torch.sum((x==ele[0]).float())<=count_equal, min_x-1 ==ele[0]) for ele in x[0]])
return torch.sum(result) == old_size
count_equal = torch.sum(torch.where(x==self.result, 1.0, 0.0))
count_check = 0
for ele in x[0]:
bool1 = torch.sum(torch.where(x==ele[0], 1.0, 0.0))<=count_equal
bool2 = ele[0]==min_x-1
count_check += torch.logical_or(bool1, bool2)
return count_check ==old_size
class PStdev(Operation):
def __init__(self, x: torch.Tensor, error: float):
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
result = torch.sqrt(torch.var(x_1d, correction = 0))
super().__init__(result, error)
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
result = torch.sqrt(torch.var(x_1d, correction = 0))
super().__init__(result, error)
else:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
elif 'PStdev' not in op_dict:
super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
else:
super().__init__(torch.tensor(precal_witness['PStdev_'+str(op_dict['PStdev'])][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_'+str(op_dict['PStdev'])][1]), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'PStdev':
return cls(x[0], error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'PStdev':
return cls(x[0], error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
size = torch.sum((x!=MagicNumber).float())
size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
return torch.logical_and(
torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*self.result*size)<=torch.abs(2*self.error*self.result*self.result*size),x_mean_cons
torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*self.result*size)<=torch.abs(2*self.error*self.result*self.result*size),x_mean_cons
)
class PVariance(Operation):
def __init__(self, x: torch.Tensor, error: float):
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
result = torch.var(x_1d, correction = 0)
super().__init__(result, error)
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
result = torch.var(x_1d, correction = 0)
super().__init__(result, error)
else:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
elif 'PVariance' not in op_dict:
super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
else:
super().__init__(torch.tensor(precal_witness['PVariance_'+str(op_dict['PVariance'])][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_'+str(op_dict['PVariance'])][1]), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'PVariance':
return cls(x[0], error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'PVariance':
return cls(x[0], error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
size = torch.sum((x!=MagicNumber).float())
size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
return torch.logical_and(
torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*size)<=torch.abs(self.error*self.result*size), x_mean_cons
torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*size)<=torch.abs(self.error*self.result*size), x_mean_cons
)
class Stdev(Operation):
def __init__(self, x: torch.Tensor, error: float):
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
result = torch.sqrt(torch.var(x_1d, correction = 1))
super().__init__(result, error)
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
result = torch.sqrt(torch.var(x_1d, correction = 1))
super().__init__(result, error)
else:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
elif 'Stdev' not in op_dict:
super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
else:
super().__init__(torch.tensor(precal_witness['Stdev_'+str(op_dict['Stdev'])][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_'+str(op_dict['Stdev'])][1]), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Stdev':
return cls(x[0], error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Stdev':
return cls(x[0], error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
size = torch.sum((x!=MagicNumber).float())
size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
return torch.logical_and(
torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
)
class Variance(Operation):
def __init__(self, x: torch.Tensor, error: float):
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
result = torch.var(x_1d, correction = 1)
super().__init__(result, error)
def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
result = torch.var(x_1d, correction = 1)
super().__init__(result, error)
else:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
elif 'Variance' not in op_dict:
super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
else:
super().__init__(torch.tensor(precal_witness['Variance_'+str(op_dict['Variance'])][0]), error)
self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_'+str(op_dict['Variance'])][1]), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Variance':
return cls(x[0], error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Variance':
return cls(x[0], error, precal_witness, op_dict)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
size = torch.sum((x!=MagicNumber).float())
size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
return torch.logical_and(
torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
)
class Covariance(Operation):
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
y_1d = to_1d(y)
y_1d = y_1d[y_1d!=MagicNumber]
x_1d_list = x_1d.tolist()
y_1d_list = y_1d.tolist()
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
y_1d = to_1d(y)
y_1d = y_1d[y_1d!=MagicNumber]
x_1d_list = x_1d.tolist()
y_1d_list = y_1d.tolist()
self.x_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(x_1d_list), dtype = torch.float32), requires_grad=False)
self.y_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(y_1d_list), dtype = torch.float32), requires_grad=False)
result = torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32)
self.x_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(x_1d_list), dtype = torch.float32), requires_grad=False)
self.y_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(y_1d_list), dtype = torch.float32), requires_grad=False)
result = torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32)
super().__init__(result, error)
super().__init__(result, error)
else:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
elif 'Covariance' not in op_dict:
super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
else:
super().__init__(torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][2]), requires_grad=False)
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Covariance':
return cls(x[0], x[1], error)
def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Covariance':
return cls(x[0], x[1], error, precal_witness, op_dict)
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
x, y = args[0], args[1]
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
y_fil_0 = torch.where(y==MagicNumber, 0.0, y)
size_x = torch.sum((x!=MagicNumber).float())
size_y = torch.sum((y!=MagicNumber).float())
size_x = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
size_y = torch.sum(torch.where(y!=MagicNumber, 1.0, 0.0))
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size_x*(self.x_mean))<=torch.abs(self.error*self.x_mean*size_x)
y_mean_cons = torch.abs(torch.sum(y_fil_0)-size_y*(self.y_mean))<=torch.abs(self.error*self.y_mean*size_y)
x_fil_mean = torch.where(x==MagicNumber, self.x_mean, x)
# only x_fil_mean is enough, no need for y_fil_mean since it will multiply 0 anyway
x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.x_mean)
y_adj_mean = torch.where(y==MagicNumber, 0.0, y-self.y_mean)
return torch.logical_and(
torch.logical_and(size_x==size_y,torch.logical_and(x_mean_cons,y_mean_cons)),
torch.abs(torch.sum((x_fil_mean-self.x_mean)*(y-self.y_mean))-(size_x-1)*self.result)<self.error*self.result*(size_x-1)
torch.abs(torch.sum((x_adj_mean)*(y_adj_mean))-(size_x-1)*self.result)<=torch.abs(self.error*self.result*(size_x-1))
)
# other constraints are checked in the correlation function, not here, to avoid repetition
def stdev_for_corr(x_fil_mean:torch.Tensor,size_x:torch.Tensor, x_std: torch.Tensor, x_mean: torch.Tensor, error: float) -> torch.Tensor:
def stdev_for_corr(x_adj_mean:torch.Tensor, size_x:torch.Tensor, x_std: torch.Tensor, error: float) -> torch.Tensor:
return (
torch.abs(torch.sum((x_fil_mean-x_mean)*(x_fil_mean-x_mean))-x_std*x_std*(size_x - 1))<=torch.abs(2*error*x_std*x_std*(size_x - 1))
torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-x_std*x_std*(size_x - 1))<=torch.abs(2*error*x_std*x_std*(size_x - 1))
, x_std)
# other constraints are checked in the correlation function, not here, to avoid repetition
def covariance_for_corr(x_fil_mean: torch.Tensor,y_fil_mean: torch.Tensor,size_x:torch.Tensor, size_y:torch.Tensor, cov: torch.Tensor, x_mean: torch.Tensor, y_mean: torch.Tensor, error: float) -> torch.Tensor:
def covariance_for_corr(x_adj_mean: torch.Tensor,y_adj_mean: torch.Tensor,size_x:torch.Tensor, cov: torch.Tensor, error: float) -> torch.Tensor:
return (
torch.abs(torch.sum((x_fil_mean-x_mean)*(y_fil_mean-y_mean))-(size_x-1)*cov)<error*cov*(size_x-1)
torch.abs(torch.sum((x_adj_mean)*(y_adj_mean))-(size_x-1)*cov)<=torch.abs(error*cov*(size_x-1))
, cov)
class Correlation(Operation):
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
y_1d = to_1d(y)
y_1d = y_1d[y_1d!=MagicNumber]
x_1d_list = x_1d.tolist()
y_1d_list = y_1d.tolist()
self.x_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
self.y_mean = torch.nn.Parameter(data=torch.mean(y_1d), requires_grad = False)
self.x_std = torch.nn.Parameter(data=torch.sqrt(torch.var(x_1d, correction = 1)), requires_grad = False)
self.y_std = torch.nn.Parameter(data=torch.sqrt(torch.var(y_1d, correction = 1)), requires_grad=False)
self.cov = torch.nn.Parameter(data=torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32), requires_grad=False)
result = torch.tensor(statistics.correlation(x_1d_list, y_1d_list), dtype = torch.float32)
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
y_1d = to_1d(y)
y_1d = y_1d[y_1d!=MagicNumber]
x_1d_list = x_1d.tolist()
y_1d_list = y_1d.tolist()
self.x_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
self.y_mean = torch.nn.Parameter(data=torch.mean(y_1d), requires_grad = False)
self.x_std = torch.nn.Parameter(data=torch.sqrt(torch.var(x_1d, correction = 1)), requires_grad = False)
self.y_std = torch.nn.Parameter(data=torch.sqrt(torch.var(y_1d, correction = 1)), requires_grad=False)
self.cov = torch.nn.Parameter(data=torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32), requires_grad=False)
result = torch.tensor(statistics.correlation(x_1d_list, y_1d_list), dtype = torch.float32)
super().__init__(result, error)
else:
if op_dict is None:
super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][3]), requires_grad=False)
self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][4]), requires_grad=False)
self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][5]), requires_grad=False)
elif 'Correlation' not in op_dict:
super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][3]), requires_grad=False)
self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][4]), requires_grad=False)
self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][5]), requires_grad=False)
else:
super().__init__(torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][0]), error)
self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][1]), requires_grad=False)
self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][2]), requires_grad=False)
self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][3]), requires_grad=False)
self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][4]), requires_grad=False)
self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][5]), requires_grad=False)
super().__init__(result, error)
@classmethod
def create(cls, args: list[torch.Tensor], error: float) -> 'Correlation':
return cls(args[0], args[1], error)
def create(cls, args: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Correlation':
return cls(args[0], args[1], error, precal_witness, op_dict)
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
x, y = args[0], args[1]
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
y_fil_0 = torch.where(y==MagicNumber, 0.0, y)
size_x = torch.sum((x!=MagicNumber).float())
size_y = torch.sum((y!=MagicNumber).float())
size_x = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
size_y = torch.sum(torch.where(y!=MagicNumber, 1.0, 0.0))
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size_x*(self.x_mean))<=torch.abs(self.error*self.x_mean*size_x)
y_mean_cons = torch.abs(torch.sum(y_fil_0)-size_y*(self.y_mean))<=torch.abs(self.error*self.y_mean*size_y)
x_fil_mean = torch.where(x==MagicNumber, self.x_mean, x)
y_fil_mean = torch.where(y==MagicNumber, self.y_mean, y)
x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.x_mean)
y_adj_mean = torch.where(y==MagicNumber, 0.0, y-self.y_mean)
miscel_cons = torch.logical_and(size_x==size_y, torch.logical_and(x_mean_cons, y_mean_cons))
bool1, cov = covariance_for_corr(x_fil_mean,y_fil_mean,size_x, size_y, self.cov, self.x_mean, self.y_mean, self.error)
bool2, x_std = stdev_for_corr( x_fil_mean, size_x, self.x_std, self.x_mean, self.error)
bool3, y_std = stdev_for_corr( y_fil_mean, size_y, self.y_std, self.y_mean, self.error)
bool4 = torch.abs(cov - self.result*x_std*y_std)<=self.error*cov
bool1, cov = covariance_for_corr(x_adj_mean,y_adj_mean,size_x, self.cov, self.error)
bool2, x_std = stdev_for_corr( x_adj_mean, size_x, self.x_std, self.error)
bool3, y_std = stdev_for_corr( y_adj_mean, size_y, self.y_std, self.error)
# this is correlation constraint
bool4 = torch.abs(cov - self.result*x_std*y_std)<=torch.abs(self.error*cov)
return torch.logical_and(torch.logical_and(torch.logical_and(bool1, bool2),torch.logical_and(bool3, bool4)), miscel_cons)
@@ -390,34 +525,53 @@ def stacked_x(args: list[float]):
class Regression(Operation):
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float):
x_1ds = [to_1d(i) for i in xs]
fil_x_1ds=[]
for x_1 in x_1ds:
fil_x_1ds.append((x_1[x_1!=MagicNumber]).tolist())
x_1ds = fil_x_1ds
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
if precal_witness is None:
x_1ds = [to_1d(i) for i in xs]
fil_x_1ds=[]
for x_1 in x_1ds:
fil_x_1ds.append((x_1[x_1!=MagicNumber]).tolist())
x_1ds = fil_x_1ds
y_1d = to_1d(y)
y_1d = (y_1d[y_1d!=MagicNumber]).tolist()
y_1d = to_1d(y)
y_1d = (y_1d[y_1d!=MagicNumber]).tolist()
x_one = stacked_x(x_1ds)
result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
result = torch.tensor(result_1d, dtype = torch.float32).reshape(1, -1, 1)
print('result: ', result)
super().__init__(result, error)
x_one = stacked_x(x_1ds)
result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
result = torch.tensor(result_1d, dtype = torch.float32).reshape(1, -1, 1)
# print('result: ', result)
super().__init__(result, error)
else:
if op_dict is None:
result = torch.tensor(precal_witness['Regression_0']).reshape(1,-1,1)
elif 'Regression' not in op_dict:
result = torch.tensor(precal_witness['Regression_0']).reshape(1,-1,1)
else:
result = torch.tensor(precal_witness['Regression_'+str(op_dict['Regression'])]).reshape(1,-1,1)
# for ele in precal_witness['Regression']:
# precal_witness_arr.append(torch.tensor(ele))
# print('resultopppp: ', result)
super().__init__(result,error)
@classmethod
def create(cls, args: list[torch.Tensor], error: float) -> 'Regression':
def create(cls, args: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Regression':
xs = args[:-1]
y = args[-1]
return cls(xs, y, error)
return cls(xs, y, error, precal_witness, op_dict)
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
# infer y from the last parameter
y = args[-1]
y = torch.where(y==MagicNumber, torch.tensor(0.0), y)
y = torch.where(y==MagicNumber,0.0, y)
x_one = torch.cat((*args[:-1], torch.ones_like(args[0])), dim=2)
x_one = torch.where((x_one[:,:,0] ==MagicNumber).unsqueeze(-1), torch.tensor([0.0]*x_one.size()[2]), x_one)
x_t = torch.transpose(x_one, 1, 2)
return torch.sum(torch.abs(x_t @ x_one @ self.result - x_t @ y)) <= self.error * torch.sum(torch.abs(x_t @ y))
left = x_t @ x_one @ self.result - x_t @ y
right = self.error*x_t @ y
abs_left = torch.where(left>=0, left, -left)
abs_right = torch.where(right>=0, right, -right)
return torch.where(torch.sum(torch.where(abs_left<=abs_right, 1.0, 0.0))==torch.tensor(2.0), 1.0, 0.0)
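Finally, a minimal sketch (made-up data, assuming `Mean` as defined above is importable from `zkstats.ops`) of the prover/verifier symmetry these `create` signatures enable: the verifier rebuilds exactly the same op from the witness values without recomputing from the data:
```python
import torch
from zkstats.ops import Mean

x = [torch.tensor([[[1.0], [2.0], [3.0]]])]      # shape [1, 3, 1], made-up data
prover_op = Mean.create(x, 0.01)                 # prover: result computed from the data
witness = {"Mean_0": [prover_op.result.item()]}  # what the prover writes to the witness file
verifier_op = Mean.create(x, 0.01, witness)      # verifier: result rebuilt from the witness
assert torch.isclose(prover_op.result, verifier_op.result)
```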