Mirror of https://github.com/MPCStats/zk-stats-lib.git, synced 2026-01-10 05:57:55 -05:00
feat: remove unused models and add comments
@@ -60,11 +60,19 @@ class IModel(nn.Module):
# A computation function. Example:
# def computation(state: State, x: list[torch.Tensor]):
#     b_0, out_0 = state.median(x[0])
#     b_1, out_1 = state.median(x[1])
#     b_2, out_2 = state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
#     return torch.logical_and(torch.logical_and(b_0, b_1), b_2), out_2
TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torch.Tensor]]


def create_model(computation: TComputation) -> Type[nn.Module]:
    """
    Create a torch model from a `computation` function defined by the user.
    """
    state = State()
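A minimal usage sketch of the pattern documented in this hunk (hypothetical: the import path `zkstats.computation`, the constructor of the generated model, and its forward signature are assumptions, not confirmed by this diff):

import torch
from zkstats.computation import State, create_model  # assumed import path

# User-defined computation: median of two columns, then the mean of the two medians.
def computation(state: State, x: list[torch.Tensor]):
    b_0, out_0 = state.median(x[0])
    b_1, out_1 = state.median(x[1])
    b_2, out_2 = state.mean(torch.tensor([out_0, out_1]).reshape(1, -1, 1))
    return torch.logical_and(torch.logical_and(b_0, b_1), b_2), out_2

Model = create_model(computation)       # returns an nn.Module subclass
model = Model()                         # assumed: no constructor arguments
columns = [torch.randn(1, 10, 1), torch.randn(1, 10, 1)]
is_precise, result = model(*columns)    # assumed: forward takes the data columns as arguments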
@@ -1,49 +0,0 @@
from typing import Any
from abc import ABC, abstractmethod
from torch import nn
import torch


class BaseZKStatsModel(ABC, nn.Module):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, X: Any) -> Any:
        """
        :param X: a tensor of shape (1, n, 1)
        :return: a tuple of (bool, float)
        """


class NoDivisionModel(BaseZKStatsModel):
    def __init__(self):
        super().__init__()
        # w represents mean in this case

    @abstractmethod
    def prepare(expected_output: Any):
        ...

    @abstractmethod
    def forward(self, X: Any) -> tuple[float, float]:
        # some expression of tolerance to error in the inference
        # must have w first!
        ...


class MeanModel(NoDivisionModel):
    def __init__(self):
        super().__init__()

    def prepare(self, X: Any):
        expected_output = torch.mean(X[0])
        # w represents mean in this case
        self.w = nn.Parameter(data=expected_output, requires_grad=False)

    def forward(self, X: Any) -> tuple[float, float]:
        # some expression of tolerance to error in the inference
        # must have w first!
        return (torch.abs(torch.sum(X) - X.size()[1] * self.w) < 0.01 * X.size()[1] * self.w, self.w)


# TODO: Copy the rest of models here
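For context on what the removed MeanModel encoded: its forward pass checks that the stored parameter w matches the true mean to within 1% relative error without performing a division. A standalone sketch of that check (illustrative only; mean_within_tolerance is a hypothetical helper, not part of the library):

import torch

def mean_within_tolerance(X: torch.Tensor, claimed_mean: torch.Tensor, rel_err: float = 0.01) -> torch.Tensor:
    # X has shape (1, n, 1). Instead of dividing sum(X) by n, the check compares
    # |sum(X) - n * claimed_mean| against rel_err * n * claimed_mean, mirroring
    # the removed MeanModel.forward. As written, it assumes a positive mean:
    # a zero or negative claimed_mean makes the right-hand side non-positive.
    n = X.size()[1]
    return torch.abs(torch.sum(X) - n * claimed_mean) < rel_err * n * claimed_mean

X = torch.ones(1, 5, 1) * 2.0
print(mean_within_tolerance(X, torch.tensor(2.0)))  # tensor(True)
print(mean_within_tolerance(X, torch.tensor(3.0)))  # tensor(False)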