update readme

This commit is contained in:
JernKunpittaya
2024-03-08 18:39:23 +07:00
parent 516b520bb3
commit 447435896d

View File

@@ -65,26 +65,26 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
Aside from the ZKStats operations, you can also use PyTorch functions like (`torch.abs`, `torch.max`, ...etc).
**Caveats**: Not all PyTorch functions are supported. For example, filtering data from a list by `X[X > 0]` is not supported because the zk circuit needs to be of a predetermined size, hence we cannot arbitrarily reshape our X into a new shape based on the filter condition inside the circuit.
TODO: We should have a list for all supported PyTorch functions.
**Caveats**: Not all PyTorch functions are supported. For example, filtering data from a list by `X[X > 0]` is not supported because the zk circuit needs to be of a predetermined size, hence we cannot arbitrarily reshape our X into a new shape based on the filter condition inside the circuit. To filter data based on condition, we can use s.where as follows.
#### Data Filtering
Since we cannot filter data into any arbitrary shape using just condition + index (e.g. `X[X > 0]`), we need to filter data while still preserving the shape. We use condition + `torch.where` instead.
Although we cannot filter data into any arbitrary shape using just condition + index (e.g. `X[X > 0]`), we implemented State.where operation that allows users to filter data by their own choice of condition as follows.
```python
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
# Compute the mean of the absolute values
x = data[0]
condition = x > 20
# Filter out data that is greater than 20. For the data that is greater than 20, we will use 0.0
fil_X = torch.where(condition=condition, input=x, other=0.0)
return s.mean(fil_X)
# Here condition can be chained as shown below, and can have many variables if we have more than just x: e.g. filter = torch.logical_and(x>20, y<2) in case of regression for example.
filter = torch.logical_and(x > 20, x<50)
# call our where function
filtered_x = s.where(filter, x)
# Then, can use the stats operation as usual
return s.mean(filtered_x)
```
**Caveats**: Currently, this 'where' operation still doesn't work correctly, since we cannot just plug fil_X into our current s.mean() due to incompatible shape of fil_X and X in reality, we will update the compatible implementation of how to do data filtering soon. Keep posted!
### Proof Generation and Verification
The flow between data providers and users is as follows:
@@ -204,6 +204,7 @@ See our jupyter notebook for [examples](./examples/).
## Benchmarks
See our jupyter notebook for [benchmarks](./benchmark/).
TODO: clean benchmark
## Note