mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
update readme
This commit is contained in:
19
README.md
19
README.md
@@ -65,26 +65,26 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
|
||||
|
||||
Aside from the ZKStats operations, you can also use PyTorch functions like (`torch.abs`, `torch.max`, ...etc).
|
||||
|
||||
**Caveats**: Not all PyTorch functions are supported. For example, filtering data from a list by `X[X > 0]` is not supported because the zk circuit needs to be of a predetermined size, hence we cannot arbitrarily reshape our X into a new shape based on the filter condition inside the circuit.
|
||||
|
||||
TODO: We should have a list for all supported PyTorch functions.
|
||||
|
||||
**Caveats**: Not all PyTorch functions are supported. For example, filtering data from a list by `X[X > 0]` is not supported because the zk circuit needs to be of a predetermined size, hence we cannot arbitrarily reshape our X into a new shape based on the filter condition inside the circuit. To filter data based on condition, we can use s.where as follows.
|
||||
|
||||
#### Data Filtering
|
||||
|
||||
Since we cannot filter data into any arbitrary shape using just condition + index (e.g. `X[X > 0]`), we need to filter data while still preserving the shape. We use condition + `torch.where` instead.
|
||||
Although we cannot filter data into any arbitrary shape using just condition + index (e.g. `X[X > 0]`), we implemented State.where operation that allows users to filter data by their own choice of condition as follows.
|
||||
|
||||
```python
|
||||
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
|
||||
# Compute the mean of the absolute values
|
||||
x = data[0]
|
||||
condition = x > 20
|
||||
# Filter out data that is greater than 20. For the data that is greater than 20, we will use 0.0
|
||||
fil_X = torch.where(condition=condition, input=x, other=0.0)
|
||||
return s.mean(fil_X)
|
||||
# Here condition can be chained as shown below, and can have many variables if we have more than just x: e.g. filter = torch.logical_and(x>20, y<2) in case of regression for example.
|
||||
filter = torch.logical_and(x > 20, x<50)
|
||||
# call our where function
|
||||
filtered_x = s.where(filter, x)
|
||||
# Then, can use the stats operation as usual
|
||||
return s.mean(filtered_x)
|
||||
```
|
||||
|
||||
**Caveats**: Currently, this 'where' operation still doesn't work correctly, since we cannot just plug fil_X into our current s.mean() due to incompatible shape of fil_X and X in reality, we will update the compatible implementation of how to do data filtering soon. Keep posted!
|
||||
|
||||
### Proof Generation and Verification
|
||||
|
||||
The flow between data providers and users is as follows:
|
||||
@@ -204,6 +204,7 @@ See our jupyter notebook for [examples](./examples/).
|
||||
## Benchmarks
|
||||
|
||||
See our jupyter notebook for [benchmarks](./benchmark/).
|
||||
TODO: clean benchmark
|
||||
|
||||
## Note
|
||||
|
||||
|
||||
Reference in New Issue
Block a user