mirror of https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00

README.md (20 lines changed)
@@ -56,10 +56,10 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
    # Compute the median of the second column
    median2 = s.median(data[1])
    # Compute the mean of the medians
    return s.mean(torch.Tensor([median1, median2]).reshape(1, -1, 1))
    return s.mean(torch.cat((median1.unsqueeze(0), median2.unsqueeze(0))).reshape(1, -1, 1))
```

> NOTE: `reshape` is required for now since the input must be in the shape `[1, data_size, 1]`. This should be addressed in the future.
> NOTE: `reshape` is required for now since the input must be in the shape `[1, data_size, 1]`. The same applies to `torch.cat()` and `unsqueeze()`; we plan to provide wrappers for these in the future.
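To make the required shape concrete, here is a minimal sketch (plain PyTorch, independent of zkstats) of how `unsqueeze`, `torch.cat`, and `reshape` combine two scalar results into the `[1, data_size, 1]` shape:

```python
import torch

# Two scalar results, as returned by ops such as s.median
median1 = torch.tensor(3.0)
median2 = torch.tensor(4.5)

# unsqueeze(0) turns each scalar into a 1-element tensor so they can be concatenated
stacked = torch.cat((median1.unsqueeze(0), median2.unsqueeze(0)))  # shape: [2]

# reshape(1, -1, 1) produces the required [1, data_size, 1] input shape
reshaped = stacked.reshape(1, -1, 1)
print(reshaped.shape)  # torch.Size([1, 2, 1])
```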
#### Torch Operations
@@ -88,7 +88,7 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:

### Proof Generation and Verification

The flow between data providers and users is as follows:

![The zkstats workflow](./assets/zkstats-flow.png "zkstats workflow")

#### Data Provider: generate data commitments
@@ -109,12 +109,16 @@ When generating a proof, since dataset might contain floating points, data provi

#### Both: derive PyTorch model from the computation

When a user wants to request a data provider to generate a proof for their defined computation, the user must first send the computation to the data provider. Then, the data provider and the user each transform the computation into a model with the necessary settings.
When a user wants to request a data provider to generate a proof for their defined computation, the user must first let the data provider know what the computation is. The data provider, who holds the real dataset, then generates the model from the computation using the `computation_to_model()` method. Since we use the witness approach (described in the Note section below), the data provider is required to send the pre-calculated witness back to the verifier. The verifier then generates the model from the computation with the help of this pre-calculated witness, so that it is exactly the same model as the prover's.

Note that we could also just let the prover generate the model and send it to the verifier directly. However, to make sure that the prover's model actually comes from the verifier's computation, it is better to have the verifier generate the model itself from its own computation, just with the help of the pre-calculated witness.

```python
from zkstats.core import computation_to_model

_, model = computation_to_model(user_computation)
# For prover: generate prover_model, and write to the precal_witness file
_, prover_model = computation_to_model(user_computation, precal_witness_path, True, error)
# For verifier: generate the verifier model (which is the same as prover_model) by reading the precal_witness file
_, verifier_model = computation_to_model(user_computation, precal_witness_path, False, error)
```
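For orientation, here is a rough end-to-end sketch of the prover-side flow, pieced together from the calls used in this repository's tests. All file paths and the example scale/error values here are illustrative assumptions, not fixed conventions:

```python
# A minimal prover-side sketch, assuming illustrative file names.
from zkstats.core import (
    computation_to_model, generate_data_commitment, prover_gen_settings,
    setup, prover_gen_proof, verifier_verify,
)

scales = [7]                      # example scale, as in the tests
selected_columns = ["x", "y"]
# True = prover side; writes the pre-calculated witness to the given file
_, prover_model = computation_to_model(user_computation, "precal_witness.json", True, 0.01)

generate_data_commitment("data.json", scales, "commitments.json")
prover_gen_settings(
    data_path="data.json",
    selected_columns=selected_columns,
    sel_data_path="comb_data.json",
    prover_model=prover_model,
    prover_model_path="model.onnx",
    scale=scales,
    mode="resources",
    settings_path="settings.json",
)
setup("model.onnx", "model.compiled", "settings.json", "vk.key", "pk.key")
prover_gen_proof("model.onnx", "comb_data.json", "witness.json", "model.compiled",
                 "settings.json", "proof.json", "pk.key")
verifier_verify("proof.json", "settings.json", "vk.key", selected_columns, "commitments.json")
```

The verifier side mirrors this, calling `computation_to_model(..., False, ...)` to rebuild the same model from the prover's precal_witness file.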
#### Data Provider: generate settings
@@ -204,14 +208,14 @@ See our jupyter notebook for [examples](./examples/).

## Benchmarks

TOFIX: Update the benchmark method. See more in issues.
See our jupyter notebook for [benchmarks](./benchmark/).
TODO: clean benchmark

## Note

- We implement the witness approach instead of directly calculating the value in the circuit. This sometimes allows us to skip calculating operations like division or exponentiation, which would require a larger scale in the settings. (If we don't use a larger scale in those cases, the accuracy will be very bad.) See the sketch after this list.
- Dummy data to feed into the verifier onnx file needs to have the same shape as the private dataset, but can be filled with any value (we just randomize it to be uniform over 1-10 with 1 decimal).
- For the Mode function, if more than one value is possible, we just output one of them (the one that is encountered first), conforming to the spec of statistics.mode in the Python standard library (https://docs.python.org/3.9/library/statistics.html#statistics.mode).
- For the Mode function, if more than one value is possible, we just output the one that is encountered first, conforming to the spec of statistics.mode in the Python standard library (https://docs.python.org/3.9/library/statistics.html#statistics.mode).
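To illustrate the witness approach in the first note, here is a minimal sketch (assuming a simple mean, mirroring the `Mean.ezkl` check in `zkstats/ops.py`): the result is computed outside the circuit, and the circuit only verifies it with multiplications and comparisons, so no in-circuit division is needed:

```python
import torch

error = 0.01                      # relative error tolerance
x = torch.tensor([2.7, 3.3, 1.1, 2.2])
size = torch.tensor(float(x.numel()))

# Prover computes the mean outside the circuit and passes it in as a witness
claimed_mean = torch.mean(x)

# In-circuit check: |sum(x) - size * mean| <= |error * mean * size|,
# i.e. only multiplication and comparison, no division
is_precise = torch.abs(torch.sum(x) - size * claimed_mean) <= torch.abs(error * claimed_mean * size)
assert bool(is_precise)
```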
## Legacy
BIN assets/zkstats-flow.png (new file; binary file not shown; before: 115 KiB, after: 105 KiB)
examples/1.only_torch/data.json (4 lines, new file)
@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6, 7],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}

examples/1.only_torch/only_torch.ipynb (281 lines, new file)
File diff suppressed because one or more lines are too long
examples/2.torch+state/data.json (4 lines, new file)
@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}

examples/2.torch+state/torch+state.ipynb (294 lines, new file)
File diff suppressed because one or more lines are too long
examples/3.state/data.json (4 lines, new file)
@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}

examples/3.state/state.ipynb (294 lines, new file)
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,8 +0,0 @@
{
"x": [
0, 1, 2, 3, 4
],
"y": [
2.0, 5.2, 47.4, 23.6, 24.8
]
}
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,38 +0,0 @@
{
"col_1": [
23.2, 92.8, 91.0, 37.2, 82.0, 15.5, 79.3, 46.6, 98.1, 75.5, 78.9, 77.6,
33.8, 75.7, 96.8, 12.3, 18.4, 13.4, 6.0, 8.2, 25.8, 41.3, 68.5, 15.2, 74.7,
72.7, 18.0, 42.2, 36.1, 76.7, 1.2, 96.4, 4.9, 92.0, 12.8, 28.2, 61.8, 56.9,
44.3, 50.4, 81.6, 72.5, 12.9, 40.3, 12.8, 28.8, 36.3, 16.1, 68.4, 35.3,
79.2, 48.4, 97.1, 93.7, 77.0, 48.7, 93.7, 54.1, 65.4, 30.8, 34.4, 31.4,
78.7, 12.7, 90.7, 39.4, 86.0, 55.9, 6.8, 22.2, 65.3, 18.8, 7.1, 55.9, 38.6,
15.6, 59.2, 77.3, 76.9, 11.9, 19.9, 19.4, 54.3, 39.4, 4.0, 61.1, 16.8, 81.9,
49.3, 76.9, 19.2, 68.2, 54.4, 70.2, 89.8, 23.4, 67.5, 18.7, 10.8, 80.7,
80.3, 96.2, 62.3, 17.2, 23.0, 98.0, 19.1, 8.1, 36.2, 7.5, 55.9, 1.2, 56.8,
85.1, 18.9, 23.0, 13.5, 64.3, 9.1, 14.1, 14.1, 23.1, 73.2, 86.6, 39.1, 45.5,
85.0, 79.0, 15.8, 5.2, 81.5, 34.3, 24.3, 14.2, 84.6, 33.7, 86.3, 83.3, 62.8,
72.7, 14.7, 36.8, 92.5, 4.7, 30.0, 59.4, 57.6, 37.4, 22.0, 20.9, 61.6, 26.8,
47.1, 63.6, 6.0, 96.6, 61.2, 80.2, 59.3, 23.1, 29.3, 46.3, 89.2, 77.6, 83.2,
87.2, 63.2, 81.8, 55.0, 59.7, 57.8, 43.4, 92.4, 66.9, 82.1, 51.0, 22.1,
29.9, 41.0, 85.2, 61.5, 14.6, 48.0, 52.7, 31.4, 83.9, 35.5, 77.3, 35.8,
32.6, 22.2, 19.3, 49.1, 70.9, 43.9, 88.8, 56.3, 41.8, 90.3, 20.4, 80.4,
36.4, 91.5, 69.6, 75.3, 92.4, 84.8, 17.7, 2.3, 41.3, 91.3, 68.6, 73.3, 62.5,
60.5, 73.5, 70.7, 77.5, 76.8, 98.1, 40.9, 66.3, 8.6, 48.9, 75.4, 14.7, 35.9,
89.6, 15.1, 45.0, 77.6, 30.5, 76.1, 46.9, 34.3, 65.1, 43.9, 91.6, 88.8, 8.9,
42.9, 11.8, 32.1, 20.1, 48.9, 79.7, 15.3, 45.4, 80.1, 73.1, 76.5, 52.4, 9.6,
41.9, 52.7, 55.1, 30.9, 83.7, 46.7, 39.3, 40.5, 52.4, 19.2, 25.8, 52.7,
81.0, 38.0, 54.5, 15.3, 64.3, 88.3, 49.8, 90.5, 90.4, 79.7, 87.3, 32.3,
11.9, 5.7, 33.6, 75.1, 65.9, 29.1, 39.4, 87.5, 3.3, 66.3, 79.0, 97.9, 69.6,
22.0, 62.8, 97.1, 90.4, 39.5, 11.7, 30.3, 18.9, 34.6, 6.6
],
"col_2": [
19.2, 54.1, 16.5, 24.8, 42.7, 18.9, 78.8, 54.4, 27.4, 76.2, 43.4, 20.9, 2.9,
30.4, 21.4, 2.0, 5.6, 33.5, 4.8, 4.7, 57.5, 23.5, 40.1, 83.1, 78.9, 95.1,
41.1, 59.0, 59.2, 91.1, 20.9, 67.6, 44.1, 91.3, 89.9, 85.7, 92.6, 67.1,
90.0, 29.5, 40.9, 96.8, 2.3, 57.9, 93.2, 83.9, 10.4, 75.1, 24.2, 22.9, 21.2,
26.9, 96.8, 89.0, 68.0, 16.1, 90.1, 1.7, 79.6, 98.5, 21.3, 79.5, 9.2, 97.9,
21.6, 4.2, 66.1, 53.8, 79.5, 60.6, 66.9, 39.5, 50.1, 66.1, 96.4, 80.5, 61.9,
44.4, 84.8, 64.8, 23.2, 7.1, 21.1, 90.5, 29.2, 1.4, 54.8, 9.8, 41.1, 45.2,
56.6, 48.2, 61.3, 62.9, 2.7, 33.2, 62.5, 40.9, 33.6, 50.1
]
}
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,28 +1,8 @@
{
"col_name": [
23.2, 92.8, 91.0, 37.2, 82.0, 15.5, 79.3, 46.6, 98.1, 75.5, 78.9, 77.6,
33.8, 75.7, 96.8, 12.3, 18.4, 13.4, 6.0, 8.2, 25.8, 41.3, 68.5, 15.2, 74.7,
72.7, 18.0, 42.2, 36.1, 76.7, 1.2, 96.4, 4.9, 92.0, 12.8, 28.2, 61.8, 56.9,
44.3, 50.4, 81.6, 72.5, 12.9, 40.3, 12.8, 28.8, 36.3, 16.1, 68.4, 35.3,
79.2, 48.4, 97.1, 93.7, 77.0, 48.7, 93.7, 54.1, 65.4, 30.8, 34.4, 31.4,
78.7, 12.7, 90.7, 39.4, 86.0, 55.9, 6.8, 22.2, 65.3, 18.8, 7.1, 55.9, 38.6,
15.6, 59.2, 77.3, 76.9, 11.9, 19.9, 19.4, 54.3, 39.4, 4.0, 61.1, 16.8, 81.9,
49.3, 76.9, 19.2, 68.2, 54.4, 70.2, 89.8, 23.4, 67.5, 18.7, 10.8, 80.7,
80.3, 96.2, 62.3, 17.2, 23.0, 98.0, 19.1, 8.1, 36.2, 7.5, 55.9, 1.2, 56.8,
85.1, 18.9, 23.0, 13.5, 64.3, 9.1, 14.1, 14.1, 23.1, 73.2, 86.6, 39.1, 45.5,
85.0, 79.0, 15.8, 5.2, 81.5, 34.3, 24.3, 14.2, 84.6, 33.7, 86.3, 83.3, 62.8,
72.7, 14.7, 36.8, 92.5, 4.7, 30.0, 59.4, 57.6, 37.4, 22.0, 20.9, 61.6, 26.8,
47.1, 63.6, 6.0, 96.6, 61.2, 80.2, 59.3, 23.1, 29.3, 46.3, 89.2, 77.6, 83.2,
87.2, 63.2, 81.8, 55.0, 59.7, 57.8, 43.4, 92.4, 66.9, 82.1, 51.0, 22.1,
29.9, 41.0, 85.2, 61.5, 14.6, 48.0, 52.7, 31.4, 83.9, 35.5, 77.3, 35.8,
32.6, 22.2, 19.3, 49.1, 70.9, 43.9, 88.8, 56.3, 41.8, 90.3, 20.4, 80.4,
36.4, 91.5, 69.6, 75.3, 92.4, 84.8, 17.7, 2.3, 41.3, 91.3, 68.6, 73.3, 62.5,
60.5, 73.5, 70.7, 77.5, 76.8, 98.1, 40.9, 66.3, 8.6, 48.9, 75.4, 14.7, 35.9,
89.6, 15.1, 45.0, 77.6, 30.5, 76.1, 46.9, 34.3, 65.1, 43.9, 91.6, 88.8, 8.9,
42.9, 11.8, 32.1, 20.1, 48.9, 79.7, 15.3, 45.4, 80.1, 73.1, 76.5, 52.4, 9.6,
41.9, 52.7, 55.1, 30.9, 83.7, 46.7, 39.3, 40.5, 52.4, 19.2, 25.8, 52.7,
81.0, 38.0, 54.5, 15.3, 64.3, 88.3, 49.8, 90.5, 90.4, 79.7, 87.3, 32.3,
11.9, 5.7, 33.6, 75.1, 65.9, 29.1, 39.4, 87.5, 3.3, 66.3, 79.0, 97.9, 69.6,
22.0, 62.8, 97.1, 90.4, 39.5, 11.7, 30.3, 18.9, 34.6, 6.6
46.2, 40.4, 44.8, 48.1, 51.2, 91.9, 38.2, 36.3, 22.2, 11.5, 17.9, 20.2,
99.9, 75.2, 29.8, 19.4, 46.1, 94.8, 6.6, 94.5, 99.7, 1.6, 4.0, 86.7, 28.7,
63.0, 66.7, 2.5, 41.4, 35.6, 45.0, 44.8, 9.6, 16.6, 9.8, 20.3, 25.9, 71.9,
27.5, 30.9, 62.9, 44.8, 45.7, 2.4, 91.4, 16.2, 61.5, 41.4, 77.1, 44.8
]
}
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
examples/where/where+harmomean/where+harmomean.ipynb (250 lines, new file)
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,29 +1,8 @@
{
"col_name": [
15.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0, 41.0,
47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0, 51.0,
48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0, 51.0,
56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0, 62.0,
39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0, 52.0,
45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0, 45.0,
45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0, 34.0,
66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0, 30.0,
33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0, 62.0,
38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0, 52.0,
41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0, 52.0,
58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0, 64.0,
46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0, 69.0,
27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0, 55.0,
58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0, 45.0,
40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0, 39.0,
58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0, 34.0,
50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0, 22.0,
50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0, 44.0,
55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0, 43.0,
50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0, 41.0,
41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0, 64.0,
76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0, 50.0,
31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0, 40.0,
25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
46.2, 40.4, 44.8, 48.1, 51.2, 91.9, 38.2, 36.3, 22.2, 11.5, 17.9, 20.2,
99.9, 75.2, 29.8, 19.4, 46.1, 94.8, 6.6, 94.5, 99.7, 1.6, 4.0, 86.7, 28.7,
63.0, 66.7, 2.5, 41.4, 35.6, 45.0, 44.8, 9.6, 16.6, 9.8, 20.3, 25.9, 71.9,
27.5, 30.9, 62.9, 44.8, 45.7, 2.4, 91.4, 16.2, 61.5, 41.4, 77.1, 44.8
]
}
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -24,4 +24,4 @@ def column_2():

@pytest.fixture
def scales():
    return [6]
    return [7]
@@ -1,11 +1,11 @@
import json
from typing import Type, Sequence, Optional
from typing import Type, Sequence, Optional, Callable
from pathlib import Path

import torch

from zkstats.core import prover_gen_settings, setup, prover_gen_proof, verifier_verify, generate_data_commitment
from zkstats.computation import IModel
from zkstats.core import create_dummy, prover_gen_settings, setup, prover_gen_proof, verifier_verify, generate_data_commitment, verifier_define_calculation
from zkstats.computation import IModel, State, computation_to_model


DEFAULT_POSSIBLE_SCALES = list(range(20))
@@ -22,17 +22,20 @@ def data_to_json_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, li
        column: d.tolist()
        for column, d in zip(column_names, data)
    }
    print('columnnnn: ', column_to_data)
    with open(data_path, "w") as f:
        json.dump(column_to_data, f)
    return column_to_data


TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
def compute(
    basepath: Path,
    data: list[torch.Tensor],
    model: Type[IModel],
    # computation: TComputation,
    scales_params: Optional[Sequence[int]] = None,
    selected_columns_params: Optional[list[str]] = None,
    # error:float = 1.0
) -> None:
    sel_data_path = basepath / "comb_data.json"
    model_path = basepath / "model.onnx"
@@ -60,43 +63,21 @@ def compute(
    else:
        scales = scales_params
        scales_for_commitments = scales_params
    # create_dummy((data_path), (dummy_data_path))
    generate_data_commitment((data_path), scales_for_commitments, (data_commitment_path))
    # _, prover_model = computation_to_model(computation, (precal_witness_path), True, error)

    generate_data_commitment(data_path, scales_for_commitments, data_commitment_path)
    prover_gen_settings((data_path), selected_columns, (sel_data_path), model, (model_path), scales, "resources", (settings_path))

    prover_gen_settings(
        data_path=data_path,
        selected_columns=selected_columns,
        sel_data_path=str(sel_data_path),
        prover_model=model,
        prover_model_path=str(model_path),
        scale=scales,
        mode="resources",
        settings_path=str(settings_path),
    )
    # No need, since verifier & prover share the same onnx
    # _, verifier_model = computation_to_model(computation, (precal_witness_path), False, error)
    # verifier_define_calculation((dummy_data_path), selected_columns, (sel_dummy_data_path), verifier_model, (verifier_model_path))

    setup(
        str(model_path),
        str(compiled_model_path),
        str(settings_path),
        str(vk_path),
        str(pk_path),
    )
    prover_gen_proof(
        str(model_path),
        str(sel_data_path),
        str(witness_path),
        str(compiled_model_path),
        str(settings_path),
        str(proof_path),
        str(pk_path),
    )
    verifier_verify(
        str(proof_path),
        str(settings_path),
        str(vk_path),
        selected_columns,
        data_commitment_path,
    )
    setup((model_path), (compiled_model_path), (settings_path), (vk_path), (pk_path))

    prover_gen_proof((model_path), (sel_data_path), (witness_path), (compiled_model_path), (settings_path), (proof_path), (pk_path))
    # print('slett col: ', selected_columns)
    verifier_verify((proof_path), (settings_path), (vk_path), selected_columns, (data_commitment_path))


# Error tolerance between zkstats python implementation and python statistics module
@@ -40,33 +40,35 @@ def nested_computation(state: State, args: list[torch.Tensor]):
    out_9 = state.correlation(y, z)
    out_10 = state.linear_regression(x, y)
    slope, intercept = out_10[0][0][0], out_10[0][1][0]
    reshaped = torch.tensor([
        out_0,
        out_1,
        out_2,
        out_3,
        out_4,
        out_5,
        out_6,
        out_7,
        out_8,
        out_9,
        slope,
        intercept,
    ]).reshape(1,-1,1)
    reshaped = torch.cat((
        out_0.unsqueeze(0),
        out_1.unsqueeze(0),
        out_2.unsqueeze(0),
        out_3.unsqueeze(0),
        out_4.unsqueeze(0),
        out_5.unsqueeze(0),
        out_6.unsqueeze(0),
        out_7.unsqueeze(0),
        out_8.unsqueeze(0),
        out_9.unsqueeze(0),
        slope.unsqueeze(0),
        intercept.unsqueeze(0),
    )).reshape(1,-1,1)
    out_10 = state.mean(reshaped)
    return out_10


@pytest.mark.parametrize(
    "error",
    [0.1],
    [ERROR_CIRCUIT_DEFAULT],
)
def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, column_2: torch.Tensor, error, scales):
    state, model = computation_to_model(nested_computation, error)
    precal_witness_path = tmp_path / "precal_witness_path.json"
    state, model = computation_to_model(nested_computation, precal_witness_path, True, error)
    x, y, z = column_0, column_1, column_2
    compute(tmp_path, [x, y, z], model, scales)
    # There are 11 ops in the computation

    assert state.current_op_index == 12

    ops = state.ops
@@ -157,8 +159,8 @@ def test_computation_with_where_1d(tmp_path, error, column_0, op_type: Callable[
    def where_and_op(state: State, args: list[torch.Tensor]):
        x = args[0]
        return op_type(state, state.where(condition(x), x))

    state, model = computation_to_model(where_and_op, error)
    precal_witness_path = tmp_path / "precal_witness_path.json"
    state, model = computation_to_model(where_and_op, precal_witness_path, True, error)
    compute(tmp_path, [column], model, scales)

    res_op = state.ops[-1]

@@ -185,8 +187,8 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
        filtered_x = state.where(condition_x, x)
        filtered_y = state.where(condition_x, y)
        return op_type(state, filtered_x, filtered_y)

    state, model = computation_to_model(where_and_op, error)
    precal_witness_path = tmp_path / "precal_witness_path.json"
    state, model = computation_to_model(where_and_op, precal_witness_path, True, error)
    compute(tmp_path, [column_0, column_1], model, scales)

    res_op = state.ops[-1]
@@ -71,14 +71,14 @@ def test_integration_select_partial_columns(tmp_path, column_0, column_1, error,

    def simple_computation(state, x):
        return state.mean(x[0])

    _, model = computation_to_model(simple_computation, error)
    precal_witness_path = tmp_path / "precal_witness_path.json"
    _, model = computation_to_model(simple_computation, precal_witness_path, True, error)
    # gen settings, setup, prove, verify
    compute(tmp_path, [column_0, column_1], model, scales, selected_columns)


def test_csv_data(tmp_path, column_0, column_1, error, scales):
    data_json_path = tmp_path / "data.csv"
    data_json_path = tmp_path / "data.json"
    data_csv_path = tmp_path / "data.csv"
    data_json = data_to_json_file(data_json_path, [column_0, column_1])
    json_file_to_csv(data_json_path, data_csv_path)
@@ -92,12 +92,13 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
    model_path = tmp_path / "model.onnx"
    settings_path = tmp_path / "settings.json"
    data_commitment_path = tmp_path / "commitments.json"
    precal_witness_path = tmp_path / "precal_witness.json"

    # Test: `generate_data_commitment` works with csv
    generate_data_commitment(data_csv_path, scales, data_commitment_path)

    # Test: `prover_gen_settings` works with csv
    _, model_for_proving = computation_to_model(simple_computation, error)
    _, model_for_proving = computation_to_model(simple_computation, precal_witness_path, True, error)
    prover_gen_settings(
        data_path=data_csv_path,
        selected_columns=selected_columns,
@@ -111,10 +112,9 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):

    # Test: `prover_gen_settings` works with csv
    # Instantiate the model for verification since the state of `model_for_proving` is changed after `prover_gen_settings`
    _, model_for_verification = computation_to_model(simple_computation, error)
    _, model_for_verification = computation_to_model(simple_computation, precal_witness_path, False, error)
    verifier_define_calculation(data_csv_path, selected_columns, str(sel_data_path), model_for_verification, str(model_path))


def json_file_to_csv(data_json_path, data_csv_path):
    with open(data_json_path, "r") as f:
        data_from_json = json.load(f)
@@ -154,4 +154,4 @@ def test__preprocess_data_file_to_json(tmp_path, column_0, column_1):
    _preprocess_data_file_to_json(data_json_path, new_data_json_path)
    with open(new_data_json_path, "r") as f:
        new_data_from_json = json.load(f)
    assert new_data_from_json == data_from_json
    assert new_data_from_json == data_from_json
@@ -5,7 +5,7 @@ import pytest

import torch
from zkstats.ops import Mean, Median, GeometricMean, HarmonicMean, Mode, PStdev, PVariance, Stdev, Variance, Covariance, Correlation, Operation, Regression
from zkstats.computation import IModel, IsResultPrecise
from zkstats.computation import IModel, IsResultPrecise, State, computation_to_model

from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED

@@ -44,7 +44,7 @@ def test_ops_2_parameters(tmp_path, column_0: torch.Tensor, column_1: torch.Tens
@pytest.mark.parametrize(
    "error",
    [
        ERROR_CIRCUIT_RELAXED
        ERROR_CIRCUIT_DEFAULT
    ]
)
def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float, scales: list[float]):
@@ -61,6 +61,7 @@ def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Ten
    compute(tmp_path, columns, Model, scales)



def run_test_ops(tmp_path, op_type: Type[Operation], expected_func: Callable[[list[float]], float], error: float, scales: list[float], columns: list[torch.Tensor]):
    op = op_type.create(columns, error)
    expected_res = expected_func(*[column.tolist() for column in columns])
@@ -3,6 +3,7 @@ from typing import Callable, Type, Optional, Union

import torch
from torch import nn
import json

from .ops import (
    Operation,
@@ -18,12 +19,12 @@ from .ops import (
    Covariance,
    Correlation,
    Regression,
    Where,
    IsResultPrecise,
)


DEFAULT_ERROR = 0.01
MagicNumber = 99.999


class State:
@@ -43,6 +44,10 @@ class State:
        self.error: float = error
        # Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3.
        self.current_op_index: Optional[int] = None
        self.precal_witness_path: str = None
        self.precal_witness: dict = {}
        self.isProver: bool = None
        self.op_dict: dict = {}

    def set_ready_for_exporting_onnx(self) -> None:
        self.current_op_index = 0
@@ -133,19 +138,82 @@ class State:
        return self._call_op([x, y], Regression)

    # WHERE operation
    def where(self, filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the where operation of x. The behavior should conform to `torch.where` in PyTorch.

        :param filter: A boolean tensor serves as a filter
        :param _filter: A boolean tensor serves as a filter
        :param x: A tensor to be filtered
        :return: filtered tensor
        """
        return self._call_op([filter, x], Where)
        return torch.where(_filter, x, x-x+MagicNumber)

    def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
        if self.current_op_index is None:
            op = op_type.create(x, self.error)
            # for prover
            if self.isProver:
                # print('Prover side create')
                op = op_type.create(x, self.error)

                # Single witness aka result
                if isinstance(op, Mean) or isinstance(op, GeometricMean) or isinstance(op, HarmonicMean) or isinstance(op, Mode):
                    op_class_str = str(type(op)).split('.')[-1].split("'")[0]
                    if op_class_str not in self.op_dict:
                        self.precal_witness[op_class_str+"_0"] = [op.result.data.item()]
                        self.op_dict[op_class_str] = 1
                    else:
                        self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item()]
                        self.op_dict[op_class_str] += 1
                elif isinstance(op, Median):
                    if 'Median' not in self.op_dict:
                        self.precal_witness['Median_0'] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
                        self.op_dict['Median'] = 1
                    else:
                        self.precal_witness['Median_'+str(self.op_dict['Median'])] = [op.result.data.item(), op.lower.data.item(), op.upper.data.item()]
                        self.op_dict['Median'] += 1
                # std + variance stuffs
                elif isinstance(op, PStdev) or isinstance(op, PVariance) or isinstance(op, Stdev) or isinstance(op, Variance):
                    op_class_str = str(type(op)).split('.')[-1].split("'")[0]
                    if op_class_str not in self.op_dict:
                        self.precal_witness[op_class_str+"_0"] = [op.result.data.item(), op.data_mean.data.item()]
                        self.op_dict[op_class_str] = 1
                    else:
                        self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item(), op.data_mean.data.item()]
                        self.op_dict[op_class_str] += 1
                elif isinstance(op, Covariance):
                    if 'Covariance' not in self.op_dict:
                        self.precal_witness['Covariance_0'] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item()]
                        self.op_dict['Covariance'] = 1
                    else:
                        self.precal_witness['Covariance_'+str(self.op_dict['Covariance'])] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item()]
                        self.op_dict['Covariance'] += 1
                elif isinstance(op, Correlation):
                    if 'Correlation' not in self.op_dict:
                        self.precal_witness['Correlation_0'] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item(), op.x_std.data.item(), op.y_std.data.item(), op.cov.data.item()]
                        self.op_dict['Correlation'] = 1
                    else:
                        self.precal_witness['Correlation_'+str(self.op_dict['Correlation'])] = [op.result.data.item(), op.x_mean.data.item(), op.y_mean.data.item(), op.x_std.data.item(), op.y_std.data.item(), op.cov.data.item()]
                        self.op_dict['Correlation'] += 1
                elif isinstance(op, Regression):
                    result_array = []
                    for ele in op.result.data[0]:
                        result_array.append(ele[0].item())
                    if 'Regression' not in self.op_dict:
                        self.precal_witness['Regression_0'] = [result_array]
                        self.op_dict['Regression'] = 1
                    else:
                        self.precal_witness['Regression_'+str(self.op_dict['Regression'])] = [result_array]
                        self.op_dict['Regression'] += 1
            # for verifier
            else:
                # print('Verifier side create')
                precal_witness = json.loads(open(self.precal_witness_path, "r").read())
                op = op_type.create(x, self.error, precal_witness, self.op_dict)
                op_class_str = str(type(op)).split('.')[-1].split("'")[0]
                if op_class_str not in self.op_dict:
                    self.op_dict[op_class_str] = 1
                else:
                    self.op_dict[op_class_str] += 1
            self.ops.append(op)
            return op.result
        else:
@@ -171,7 +239,9 @@ class State:

        # If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result)
        # else, return only result

        if current_op_index == len_ops - 1:
            # print('final op: ', op)
            # Sanity check for length of self.ops and self.bools
            len_bools = len(self.bools)
            if len_ops != len_bools:
@@ -180,13 +250,15 @@ class State:
            for i in range(len_bools):
                res = self.bools[i]()
                is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
            return is_precise_aggregated, op.result
            if self.isProver:
                json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
            return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]

        elif current_op_index > len_ops - 1:
            # Sanity check that current op index does not exceed the length of ops
            raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
        else:
            # It's not the last operation, just return the result
            return op.result
            return op.result+(x[0]-x[0])[0][0][0]
class IModel(nn.Module):
@@ -207,7 +279,7 @@ class IModel(nn.Module):
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]


def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
def computation_to_model(computation: TComputation, precal_witness_path: str, isProver: bool, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
    """
    Create a torch model from a `computation` function defined by user
    :param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
@@ -216,6 +288,10 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
    State is a container for intermediate results of computation, which can be useful when debugging.
    """
    state = State(error)
    # if it's verifier

    state.precal_witness_path = precal_witness_path
    state.isProver = isProver

    class Model(IModel):
        def preprocess(self, x: list[torch.Tensor]) -> None:
@@ -223,6 +299,12 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
            state.set_ready_for_exporting_onnx()

        def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
            return computation(state, x)
            # print('x sy: ')
            result = computation(state, x)
            if len(result) == 1:
                return x[0][0][0][0]-x[0][0][0][0]+torch.tensor(1.0), result
            else:
                return result
    # print('state:: ', state.aggregate_witness_path)
    return state, Model
@@ -451,4 +451,4 @@ def _get_commitment_for_column(column: list[float], scale: int) -> str:
    res_poseidon_hash = ezkl.poseidon_hash(serialized_data)[0]
    # res_hex = ezkl.vecu64_to_felt(res_poseidon_hash[0])

    return res_poseidon_hash
    return res_poseidon_hash

zkstats/ops.py (482 lines changed)
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod, abstractclassmethod
import statistics
from typing import Optional

import numpy as np
import torch

# boolean: either 1.0 or 0.0
IsResultPrecise = torch.Tensor
MagicNumber = 9999999
MagicNumber = 99.999


class Operation(ABC):
@@ -23,26 +24,29 @@ class Operation(ABC):
    ...


class Where(Operation):
    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Where':
        # here error is trivial, but here to conform to other functions
        return cls(torch.where(x[0], x[1], MagicNumber), error)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        bool_array = torch.logical_or(x[1]==self.result, torch.logical_and(torch.logical_not(x[0]), self.result==MagicNumber))
        # print('sellll: ', self.result)
        return torch.sum(bool_array.float())==x[1].size()[1]


class Mean(Operation):
    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Mean':
    def create(cls, x: list[torch.Tensor], error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None) -> 'Mean':
        # support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
        return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
        if precal_witness is None:
            # this is prover
            # print('provvv')
            return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
        else:
            # this is verifier
            # print('verrrr')
            if op_dict is None:
                return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
            elif 'Mean' not in op_dict:
                return cls(torch.tensor(precal_witness['Mean_0'][0]), error)
            else:
                return cls(torch.tensor(precal_witness['Mean_'+str(op_dict['Mean'])][0]), error)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        x = x[0]
        size = torch.sum((x!=MagicNumber).float())
        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        x = torch.where(x==MagicNumber, 0.0, x)
        return torch.abs(torch.sum(x)-size*self.result)<=torch.abs(self.error*self.result*size)
@@ -59,83 +63,113 @@ def to_1d(x: torch.Tensor) -> torch.Tensor:


class Median(Operation):
    def __init__(self, x: torch.Tensor, error: float):
    def __init__(self, x: torch.Tensor, error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None):
        if precal_witness is None:
        # NOTE: To ensure `lower` and `upper` are a scalar, `x` must be a 1d array.
        # Otherwise, if `x` is a 3d array, `lower` and `upper` will be 2d array, which are not what
        # we want in our context. However, we tend to have x as a `[1, len(x), 1]`. In this case,
        # we need to flatten `x` to 1d array to get the correct `lower` and `upper`.
        x_1d = to_1d(x)
        x_1d = x_1d[x_1d!=MagicNumber]
        super().__init__(torch.tensor(np.median(x_1d)), error)
        sorted_x = np.sort(x_1d)
        len_x = len(x_1d)
        self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
        self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
            x_1d = to_1d(x)
            x_1d = x_1d[x_1d!=MagicNumber]
            super().__init__(torch.tensor(np.median(x_1d)), error)
            sorted_x = np.sort(x_1d)
            len_x = len(x_1d)
            self.lower = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)-1], dtype = torch.float32), requires_grad=False)
            self.upper = torch.nn.Parameter(data = torch.tensor(sorted_x[int(len_x/2)], dtype = torch.float32), requires_grad=False)
        else:
            if op_dict is None:
                super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
                self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
                self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
            elif 'Median' not in op_dict:
                super().__init__(torch.tensor(precal_witness['Median_0'][0]), error)
                self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][1]), requires_grad=False)
                self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_0'][2]), requires_grad=False)
            else:
                super().__init__(torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][0]), error)
                self.lower = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][1]), requires_grad=False)
                self.upper = torch.nn.Parameter(data = torch.tensor(precal_witness['Median_'+str(op_dict['Median'])][2]), requires_grad=False)


    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Median':
        return cls(x[0], error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None) -> 'Median':
        return cls(x[0], error, precal_witness, op_dict)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        x = x[0]
        old_size = x.size()[1]
        size = torch.sum((x!=MagicNumber).float())
        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        min_x = torch.min(x)
        x = torch.where(x==MagicNumber, min_x-1, x)

        # since within 1%, we regard as same value
        count_less = torch.sum((x < self.result).float())-(old_size-size)
        count_equal = torch.sum((x==self.result).float())
        count_less = torch.sum(torch.where(x < self.result, 1.0, 0.0))-(old_size-size)
        count_equal = torch.sum(torch.where(x==self.result, 1.0, 0.0))
        half_size = torch.floor(torch.div(size, 2))

        # print('hhhh: ', half_size)
        less_cons = count_less<half_size+size%2
        more_cons = count_less+count_equal>half_size

        # For count_equal == 0
        lower_exist = torch.sum((x==self.lower).float())>0
        lower_cons = torch.sum((x>self.lower).float())==half_size
        upper_exist = torch.sum((x==self.upper).float())>0
        upper_cons = torch.sum((x<self.upper).float())==half_size
        lower_exist = torch.sum(torch.where(x==self.lower, 1.0, 0.0))>0
        lower_cons = torch.sum(torch.where(x>self.lower, 1.0, 0.0))==half_size
        upper_exist = torch.sum(torch.where(x==self.upper, 1.0, 0.0))>0
        upper_cons = torch.sum(torch.where(x<self.upper, 1.0, 0.0))==half_size
        bound = count_less== half_size
        # 0.02 since 2*0.01
        bound_avg = (torch.abs(self.lower+self.upper-2*self.result)<=torch.abs(2*self.error*self.result))

        median_in_cons = torch.logical_and(less_cons, more_cons)
        median_out_cons = torch.logical_and(torch.logical_and(bound, bound_avg), torch.logical_and(torch.logical_and(lower_cons, upper_cons), torch.logical_and(lower_exist, upper_exist)))
        return torch.where(count_equal==0, median_out_cons, median_in_cons)
        return torch.where(count_equal==0.0, median_out_cons, median_in_cons)
class GeometricMean(Operation):
    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'GeometricMean':
        x_1d = to_1d(x[0])
        x_1d = x_1d[x_1d!=MagicNumber]
        result = torch.exp(torch.mean(torch.log(x_1d)))
        return cls(result, error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None) -> 'GeometricMean':
        if precal_witness is None:
            x_1d = to_1d(x[0])
            x_1d = x_1d[x_1d!=MagicNumber]
            result = torch.exp(torch.mean(torch.log(x_1d)))
            return cls(result, error)
        else:
            if op_dict is None:
                return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
            elif 'GeometricMean' not in op_dict:
                return cls(torch.tensor(precal_witness['GeometricMean_0'][0]), error)
            else:
                return cls(torch.tensor(precal_witness['GeometricMean_'+str(op_dict['GeometricMean'])][0]), error)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        # Assume x is [1, n, 1]
        x = x[0]
        size = torch.sum((x!=MagicNumber).float())
        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        x = torch.where(x==MagicNumber, 1.0, x)
        return torch.abs((torch.log(self.result)*size)-torch.sum(torch.log(x)))<=size*torch.log(torch.tensor(1+self.error))


class HarmonicMean(Operation):
    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'HarmonicMean':
        x_1d = to_1d(x[0])
        x_1d = x_1d[x_1d!=MagicNumber]
        result = torch.div(1.0, torch.mean(torch.div(1.0, x_1d)))
        return cls(result, error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None) -> 'HarmonicMean':
        if precal_witness is None:
            x_1d = to_1d(x[0])
            x_1d = x_1d[x_1d!=MagicNumber]
            result = torch.div(1.0, torch.mean(torch.div(1.0, x_1d)))
            return cls(result, error)
        else:
            if op_dict is None:
                return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
            elif 'HarmonicMean' not in op_dict:
                return cls(torch.tensor(precal_witness['HarmonicMean_0'][0]), error)
            else:
                return cls(torch.tensor(precal_witness['HarmonicMean_'+str(op_dict['HarmonicMean'])][0]), error)


    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        # Assume x is [1, n, 1]
        x = x[0]
        size = torch.sum((x!=MagicNumber).float())
        # just make it really big so that 1/x goes to zero for element that gets filtered out
        x = torch.where(x==MagicNumber, x*x, x)
        return torch.abs((self.result*torch.sum(torch.div(1.0, x))) - size)<=torch.abs(self.error*size)
        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        return torch.abs((self.result*torch.sum(torch.where(x==MagicNumber, 0.0, torch.div(1.0, x)))) - size)<=torch.abs(self.error*size)
def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
@@ -186,12 +220,21 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:

class Mode(Operation):
    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Mode':
        x_1d = to_1d(x[0])
        x_1d = x_1d[x_1d!=MagicNumber]
        # Here is traditional definition of Mode, can just put this num_error to be 0
        result = torch.tensor(mode_within(x_1d, 0))
        return cls(result, error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None) -> 'Mode':
        if precal_witness is None:
            x_1d = to_1d(x[0])
            x_1d = x_1d[x_1d!=MagicNumber]
            # Here is traditional definition of Mode, can just put this num_error to be 0
            result = torch.tensor(mode_within(x_1d, 0))
            return cls(result, error)
        else:
            if op_dict is None:
                return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
            elif 'Mode' not in op_dict:
                return cls(torch.tensor(precal_witness['Mode_0'][0]), error)
            else:
                return cls(torch.tensor(precal_witness['Mode_'+str(op_dict['Mode'])][0]), error)


    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        # Assume x is [1, n, 1]
@@ -199,189 +242,281 @@ class Mode(Operation):
        min_x = torch.min(x)
        old_size = x.size()[1]
        x = torch.where(x==MagicNumber, min_x-1, x)
        count_equal = torch.sum((x==self.result).float())
        result = torch.tensor([torch.logical_or(torch.sum((x==ele[0]).float())<=count_equal, min_x-1 ==ele[0]) for ele in x[0]])
        return torch.sum(result) == old_size
        count_equal = torch.sum(torch.where(x==self.result, 1.0, 0.0))

        count_check = 0
        for ele in x[0]:
            bool1 = torch.sum(torch.where(x==ele[0], 1.0, 0.0))<=count_equal
            bool2 = ele[0]==min_x-1
            count_check += torch.logical_or(bool1, bool2)
        return count_check ==old_size
class PStdev(Operation):
    def __init__(self, x: torch.Tensor, error: float):
        x_1d = to_1d(x)
        x_1d = x_1d[x_1d!=MagicNumber]
        self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
        result = torch.sqrt(torch.var(x_1d, correction = 0))
        super().__init__(result, error)
    def __init__(self, x: torch.Tensor, error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None):
        if precal_witness is None:
            x_1d = to_1d(x)
            x_1d = x_1d[x_1d!=MagicNumber]
            self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
            result = torch.sqrt(torch.var(x_1d, correction = 0))
            super().__init__(result, error)
        else:
            if op_dict is None:
                super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
            elif 'PStdev' not in op_dict:
                super().__init__(torch.tensor(precal_witness['PStdev_0'][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_0'][1]), requires_grad=False)
            else:
                super().__init__(torch.tensor(precal_witness['PStdev_'+str(op_dict['PStdev'])][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PStdev_'+str(op_dict['PStdev'])][1]), requires_grad=False)


    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'PStdev':
        return cls(x[0], error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None) -> 'PStdev':
        return cls(x[0], error, precal_witness, op_dict)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        x = x[0]
        x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
        size = torch.sum((x!=MagicNumber).float())
        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
        x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
        return torch.logical_and(
            torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*self.result*size)<=torch.abs(2*self.error*self.result*self.result*size),x_mean_cons
            torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*self.result*size)<=torch.abs(2*self.error*self.result*self.result*size),x_mean_cons
        )


class PVariance(Operation):
    def __init__(self, x: torch.Tensor, error: float):
        x_1d = to_1d(x)
        x_1d = x_1d[x_1d!=MagicNumber]
        self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
        result = torch.var(x_1d, correction = 0)
        super().__init__(result, error)
    def __init__(self, x: torch.Tensor, error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None):
        if precal_witness is None:
            x_1d = to_1d(x)
            x_1d = x_1d[x_1d!=MagicNumber]
            self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
            result = torch.var(x_1d, correction = 0)
            super().__init__(result, error)
        else:
            if op_dict is None:
                super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
            elif 'PVariance' not in op_dict:
                super().__init__(torch.tensor(precal_witness['PVariance_0'][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_0'][1]), requires_grad=False)
            else:
                super().__init__(torch.tensor(precal_witness['PVariance_'+str(op_dict['PVariance'])][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['PVariance_'+str(op_dict['PVariance'])][1]), requires_grad=False)

    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'PVariance':
        return cls(x[0], error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None) -> 'PVariance':
        return cls(x[0], error, precal_witness, op_dict)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        x = x[0]
        x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
        size = torch.sum((x!=MagicNumber).float())
        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
        x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
        return torch.logical_and(
            torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*size)<=torch.abs(self.error*self.result*size), x_mean_cons
            torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*size)<=torch.abs(self.error*self.result*size), x_mean_cons
        )



class Stdev(Operation):
    def __init__(self, x: torch.Tensor, error: float):
        x_1d = to_1d(x)
        x_1d = x_1d[x_1d!=MagicNumber]
        self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
        result = torch.sqrt(torch.var(x_1d, correction = 1))
        super().__init__(result, error)
    def __init__(self, x: torch.Tensor, error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None):
        if precal_witness is None:
            x_1d = to_1d(x)
            x_1d = x_1d[x_1d!=MagicNumber]
            self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
            result = torch.sqrt(torch.var(x_1d, correction = 1))
            super().__init__(result, error)
        else:
            if op_dict is None:
                super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
            elif 'Stdev' not in op_dict:
                super().__init__(torch.tensor(precal_witness['Stdev_0'][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_0'][1]), requires_grad=False)
            else:
                super().__init__(torch.tensor(precal_witness['Stdev_'+str(op_dict['Stdev'])][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Stdev_'+str(op_dict['Stdev'])][1]), requires_grad=False)


    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Stdev':
        return cls(x[0], error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness: Optional[dict] = None, op_dict: Optional[dict[str, int]] = None) -> 'Stdev':
        return cls(x[0], error, precal_witness, op_dict)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        x = x[0]
        x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
        size = torch.sum((x!=MagicNumber).float())
        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
        x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
        return torch.logical_and(
            torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
            torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
        )
class Variance(Operation):
    def __init__(self, x: torch.Tensor, error: float):
        x_1d = to_1d(x)
        x_1d = x_1d[x_1d!=MagicNumber]
        self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
        result = torch.var(x_1d, correction = 1)
        super().__init__(result, error)
    def __init__(self, x: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
        if precal_witness is None:
            x_1d = to_1d(x)
            x_1d = x_1d[x_1d!=MagicNumber]
            self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
            result = torch.var(x_1d, correction = 1)
            super().__init__(result, error)
        else:
            if op_dict is None:
                super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
            elif 'Variance' not in op_dict:
                super().__init__(torch.tensor(precal_witness['Variance_0'][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_0'][1]), requires_grad=False)
            else:
                super().__init__(torch.tensor(precal_witness['Variance_'+str(op_dict['Variance'])][0]), error)
                self.data_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Variance_'+str(op_dict['Variance'])][1]), requires_grad=False)

    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Variance':
        return cls(x[0], error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Variance':
        return cls(x[0], error, precal_witness, op_dict)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        x = x[0]
        x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
        size = torch.sum((x!=MagicNumber).float())
        size = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
        x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.data_mean)
        return torch.logical_and(
            torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
            torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
        )

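# Illustrative note: Stdev above tolerates 2*error while Variance tolerates
# error. A plausible reading: the stdev witness enters its constraint squared,
# and a relative error eps on s becomes about 2*eps on s**2 to first order,
# since (s*(1+eps))**2 = s**2*(1 + 2*eps + eps**2).
s, eps = 1.2910, 0.01
print(abs((s * (1 + eps)) ** 2 - s ** 2) / s ** 2)  # ~0.0201, i.e. about 2*eps
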
class Covariance(Operation):
    def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
        x_1d = to_1d(x)
        x_1d = x_1d[x_1d!=MagicNumber]
        y_1d = to_1d(y)
        y_1d = y_1d[y_1d!=MagicNumber]
        x_1d_list = x_1d.tolist()
        y_1d_list = y_1d.tolist()
    def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
        if precal_witness is None:
            x_1d = to_1d(x)
            x_1d = x_1d[x_1d!=MagicNumber]
            y_1d = to_1d(y)
            y_1d = y_1d[y_1d!=MagicNumber]
            x_1d_list = x_1d.tolist()
            y_1d_list = y_1d.tolist()

        self.x_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(x_1d_list), dtype = torch.float32), requires_grad=False)
        self.y_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(y_1d_list), dtype = torch.float32), requires_grad=False)
        result = torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32)
            self.x_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(x_1d_list), dtype = torch.float32), requires_grad=False)
            self.y_mean = torch.nn.Parameter(data=torch.tensor(statistics.mean(y_1d_list), dtype = torch.float32), requires_grad=False)
            result = torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32)

        super().__init__(result, error)
            super().__init__(result, error)
        else:
            if op_dict is None:
                super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
                self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
                self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
            elif 'Covariance' not in op_dict:
                super().__init__(torch.tensor(precal_witness['Covariance_0'][0]), error)
                self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][1]), requires_grad=False)
                self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_0'][2]), requires_grad=False)
            else:
                super().__init__(torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][0]), error)
                self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][1]), requires_grad=False)
                self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Covariance_'+str(op_dict['Covariance'])][2]), requires_grad=False)

    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Covariance':
        return cls(x[0], x[1], error)
    def create(cls, x: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Covariance':
        return cls(x[0], x[1], error, precal_witness, op_dict)

    def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
        x, y = args[0], args[1]
        x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
        y_fil_0 = torch.where(y==MagicNumber, 0.0, y)
        size_x = torch.sum((x!=MagicNumber).float())
        size_y = torch.sum((y!=MagicNumber).float())
        size_x = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        size_y = torch.sum(torch.where(y!=MagicNumber, 1.0, 0.0))
        x_mean_cons = torch.abs(torch.sum(x_fil_0)-size_x*(self.x_mean))<=torch.abs(self.error*self.x_mean*size_x)
        y_mean_cons = torch.abs(torch.sum(y_fil_0)-size_y*(self.y_mean))<=torch.abs(self.error*self.y_mean*size_y)
        x_fil_mean = torch.where(x==MagicNumber, self.x_mean, x)
        # x_fil_mean alone suffices; y_fil_mean is unnecessary, since those terms are multiplied by 0 anyway
        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.x_mean)
        y_adj_mean = torch.where(y==MagicNumber, 0.0, y-self.y_mean)

        return torch.logical_and(
            torch.logical_and(size_x==size_y,torch.logical_and(x_mean_cons,y_mean_cons)),
            torch.abs(torch.sum((x_fil_mean-self.x_mean)*(y-self.y_mean))-(size_x-1)*self.result)<self.error*self.result*(size_x-1)
            torch.abs(torch.sum((x_adj_mean)*(y_adj_mean))-(size_x-1)*self.result)<=torch.abs(self.error*self.result*(size_x-1))
        )

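# Illustrative sketch (not part of the library): the covariance inequality
# from ezkl above, reproduced outside the circuit. statistics.covariance
# requires Python >= 3.10.
import statistics
import torch

error = 0.01
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([2.0, 4.0, 5.0, 9.0])
cov = torch.tensor(statistics.covariance(x.tolist(), y.tolist()))
n = torch.tensor(float(x.numel()))
lhs = torch.abs(torch.sum((x - x.mean()) * (y - y.mean())) - (n - 1) * cov)
rhs = torch.abs(error * cov * (n - 1))
print(bool(lhs <= rhs))  # expected: True, since cov is the exact witness
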
# the remaining constraints live in the Correlation class; they are not repeated here
def stdev_for_corr(x_fil_mean:torch.Tensor,size_x:torch.Tensor, x_std: torch.Tensor, x_mean: torch.Tensor, error: float) -> torch.Tensor:
def stdev_for_corr(x_adj_mean:torch.Tensor, size_x:torch.Tensor, x_std: torch.Tensor, error: float) -> torch.Tensor:
    return (
        torch.abs(torch.sum((x_fil_mean-x_mean)*(x_fil_mean-x_mean))-x_std*x_std*(size_x - 1))<=torch.abs(2*error*x_std*x_std*(size_x - 1))
        torch.abs(torch.sum((x_adj_mean)*(x_adj_mean))-x_std*x_std*(size_x - 1))<=torch.abs(2*error*x_std*x_std*(size_x - 1))
    , x_std)

# the remaining constraints live in the Correlation class; they are not repeated here
def covariance_for_corr(x_fil_mean: torch.Tensor,y_fil_mean: torch.Tensor,size_x:torch.Tensor, size_y:torch.Tensor, cov: torch.Tensor, x_mean: torch.Tensor, y_mean: torch.Tensor, error: float) -> torch.Tensor:
def covariance_for_corr(x_adj_mean: torch.Tensor,y_adj_mean: torch.Tensor,size_x:torch.Tensor, cov: torch.Tensor, error: float) -> torch.Tensor:
    return (
        torch.abs(torch.sum((x_fil_mean-x_mean)*(y_fil_mean-y_mean))-(size_x-1)*cov)<error*cov*(size_x-1)
        torch.abs(torch.sum((x_adj_mean)*(y_adj_mean))-(size_x-1)*cov)<=torch.abs(error*cov*(size_x-1))
    , cov)

class Correlation(Operation):
    def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
        x_1d = to_1d(x)
        x_1d = x_1d[x_1d!=MagicNumber]
        y_1d = to_1d(y)
        y_1d = y_1d[y_1d!=MagicNumber]
        x_1d_list = x_1d.tolist()
        y_1d_list = y_1d.tolist()
        self.x_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
        self.y_mean = torch.nn.Parameter(data=torch.mean(y_1d), requires_grad = False)
        self.x_std = torch.nn.Parameter(data=torch.sqrt(torch.var(x_1d, correction = 1)), requires_grad = False)
        self.y_std = torch.nn.Parameter(data=torch.sqrt(torch.var(y_1d, correction = 1)), requires_grad=False)
        self.cov = torch.nn.Parameter(data=torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32), requires_grad=False)
        result = torch.tensor(statistics.correlation(x_1d_list, y_1d_list), dtype = torch.float32)
    def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
        if precal_witness is None:
            x_1d = to_1d(x)
            x_1d = x_1d[x_1d!=MagicNumber]
            y_1d = to_1d(y)
            y_1d = y_1d[y_1d!=MagicNumber]
            x_1d_list = x_1d.tolist()
            y_1d_list = y_1d.tolist()
            self.x_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
            self.y_mean = torch.nn.Parameter(data=torch.mean(y_1d), requires_grad = False)
            self.x_std = torch.nn.Parameter(data=torch.sqrt(torch.var(x_1d, correction = 1)), requires_grad = False)
            self.y_std = torch.nn.Parameter(data=torch.sqrt(torch.var(y_1d, correction = 1)), requires_grad=False)
            self.cov = torch.nn.Parameter(data=torch.tensor(statistics.covariance(x_1d_list, y_1d_list), dtype = torch.float32), requires_grad=False)
            result = torch.tensor(statistics.correlation(x_1d_list, y_1d_list), dtype = torch.float32)

            super().__init__(result, error)
        else:
            if op_dict is None:
                super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
                self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
                self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
                self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][3]), requires_grad=False)
                self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][4]), requires_grad=False)
                self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][5]), requires_grad=False)
            elif 'Correlation' not in op_dict:
                super().__init__(torch.tensor(precal_witness['Correlation_0'][0]), error)
                self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][1]), requires_grad=False)
                self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][2]), requires_grad=False)
                self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][3]), requires_grad=False)
                self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][4]), requires_grad=False)
                self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_0'][5]), requires_grad=False)
            else:
                super().__init__(torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][0]), error)
                self.x_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][1]), requires_grad=False)
                self.y_mean = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][2]), requires_grad=False)
                self.x_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][3]), requires_grad=False)
                self.y_std = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][4]), requires_grad=False)
                self.cov = torch.nn.Parameter(data = torch.tensor(precal_witness['Correlation_'+str(op_dict['Correlation'])][5]), requires_grad=False)

        super().__init__(result, error)

    @classmethod
    def create(cls, args: list[torch.Tensor], error: float) -> 'Correlation':
        return cls(args[0], args[1], error)
    def create(cls, args: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Correlation':
        return cls(args[0], args[1], error, precal_witness, op_dict)

    def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
        x, y = args[0], args[1]
        x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
        y_fil_0 = torch.where(y==MagicNumber, 0.0, y)
        size_x = torch.sum((x!=MagicNumber).float())
        size_y = torch.sum((y!=MagicNumber).float())
        size_x = torch.sum(torch.where(x!=MagicNumber, 1.0, 0.0))
        size_y = torch.sum(torch.where(y!=MagicNumber, 1.0, 0.0))
        x_mean_cons = torch.abs(torch.sum(x_fil_0)-size_x*(self.x_mean))<=torch.abs(self.error*self.x_mean*size_x)
        y_mean_cons = torch.abs(torch.sum(y_fil_0)-size_y*(self.y_mean))<=torch.abs(self.error*self.y_mean*size_y)
        x_fil_mean = torch.where(x==MagicNumber, self.x_mean, x)
        y_fil_mean = torch.where(y==MagicNumber, self.y_mean, y)
        x_adj_mean = torch.where(x==MagicNumber, 0.0, x-self.x_mean)
        y_adj_mean = torch.where(y==MagicNumber, 0.0, y-self.y_mean)

        miscel_cons = torch.logical_and(size_x==size_y, torch.logical_and(x_mean_cons, y_mean_cons))
        bool1, cov = covariance_for_corr(x_fil_mean,y_fil_mean,size_x, size_y, self.cov, self.x_mean, self.y_mean, self.error)
        bool2, x_std = stdev_for_corr( x_fil_mean, size_x, self.x_std, self.x_mean, self.error)
        bool3, y_std = stdev_for_corr( y_fil_mean, size_y, self.y_std, self.y_mean, self.error)
        bool4 = torch.abs(cov - self.result*x_std*y_std)<=self.error*cov
        bool1, cov = covariance_for_corr(x_adj_mean,y_adj_mean,size_x, self.cov, self.error)
        bool2, x_std = stdev_for_corr( x_adj_mean, size_x, self.x_std, self.error)
        bool3, y_std = stdev_for_corr( y_adj_mean, size_y, self.y_std, self.error)
        # the correlation constraint itself: cov should equal result*x_std*y_std within tolerance
        bool4 = torch.abs(cov - self.result*x_std*y_std)<=torch.abs(self.error*cov)
        return torch.logical_and(torch.logical_and(torch.logical_and(bool1, bool2),torch.logical_and(bool3, bool4)), miscel_cons)

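# Illustrative sketch (not part of the library): bool4 above ties the three
# witnesses together via |cov - result*x_std*y_std| <= |error*cov|, checked
# here with python's statistics module (Python >= 3.10 for covariance/correlation).
import statistics

error = 0.01
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 5.0, 9.0]
cov = statistics.covariance(xs, ys)
r = statistics.correlation(xs, ys)
sx, sy = statistics.stdev(xs), statistics.stdev(ys)
print(abs(cov - r * sx * sy) <= abs(error * cov))  # expected: True
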
@@ -390,34 +525,53 @@ def stacked_x(args: list[float]):


class Regression(Operation):
    def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float):
        x_1ds = [to_1d(i) for i in xs]
        fil_x_1ds=[]
        for x_1 in x_1ds:
            fil_x_1ds.append((x_1[x_1!=MagicNumber]).tolist())
        x_1ds = fil_x_1ds
    def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None):
        if precal_witness is None:
            x_1ds = [to_1d(i) for i in xs]
            fil_x_1ds=[]
            for x_1 in x_1ds:
                fil_x_1ds.append((x_1[x_1!=MagicNumber]).tolist())
            x_1ds = fil_x_1ds

        y_1d = to_1d(y)
        y_1d = (y_1d[y_1d!=MagicNumber]).tolist()
            y_1d = to_1d(y)
            y_1d = (y_1d[y_1d!=MagicNumber]).tolist()

        x_one = stacked_x(x_1ds)
        result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
        result = torch.tensor(result_1d, dtype = torch.float32).reshape(1, -1, 1)
        print('result: ', result)
        super().__init__(result, error)
            x_one = stacked_x(x_1ds)
            result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
            result = torch.tensor(result_1d, dtype = torch.float32).reshape(1, -1, 1)
            # print('result: ', result)
            super().__init__(result, error)
        else:
            if op_dict is None:
                result = torch.tensor(precal_witness['Regression_0']).reshape(1,-1,1)
            elif 'Regression' not in op_dict:
                result = torch.tensor(precal_witness['Regression_0']).reshape(1,-1,1)
            else:
                result = torch.tensor(precal_witness['Regression_'+str(op_dict['Regression'])]).reshape(1,-1,1)

            super().__init__(result,error)

    @classmethod
    def create(cls, args: list[torch.Tensor], error: float) -> 'Regression':
    def create(cls, args: list[torch.Tensor], error: float, precal_witness:Optional[dict] = None, op_dict:Optional[dict[str,int]] = None) -> 'Regression':
        xs = args[:-1]
        y = args[-1]
        return cls(xs, y, error)
        return cls(xs, y, error, precal_witness, op_dict)

    def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
        # infer y from the last parameter
        y = args[-1]
        y = torch.where(y==MagicNumber, torch.tensor(0.0), y)
        y = torch.where(y==MagicNumber,0.0, y)
        x_one = torch.cat((*args[:-1], torch.ones_like(args[0])), dim=2)
        x_one = torch.where((x_one[:,:,0] ==MagicNumber).unsqueeze(-1), torch.tensor([0.0]*x_one.size()[2]), x_one)
        x_t = torch.transpose(x_one, 1, 2)
        return torch.sum(torch.abs(x_t @ x_one @ self.result - x_t @ y)) <= self.error * torch.sum(torch.abs(x_t @ y))

        left = x_t @ x_one @ self.result - x_t @ y
        right = self.error*x_t @ y
        abs_left = torch.where(left>=0, left, -left)
        abs_right = torch.where(right>=0, right, -right)
        return torch.where(torch.sum(torch.where(abs_left<=abs_right, 1.0, 0.0))==torch.tensor(2.0), 1.0, 0.0)

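# Illustrative sketch (not part of the library): the Regression check above is
# the normal equation X^T X b = X^T y tested coordinate-wise within tolerance.
# The same idea in plain numpy, with one feature plus an intercept column:
import numpy as np

error = 0.01
x = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([[2.1], [3.9], [6.2], [7.8]])
x_one = np.hstack([x, np.ones_like(x)])               # append intercept column
beta = np.linalg.inv(x_one.T @ x_one) @ x_one.T @ y   # least-squares witness
abs_left = np.abs(x_one.T @ x_one @ beta - x_one.T @ y)
abs_right = np.abs(error * (x_one.T @ y))
# both rows (slope and intercept) must pass, mirroring the ==2.0 count above
print(bool(np.all(abs_left <= abs_right)))  # expected: True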