feat: remove unused models and add comments

This commit is contained in:
mhchia
2024-01-15 23:40:24 +08:00
parent 44f8db99bc
commit ea2612d939
2 changed files with 8 additions and 49 deletions

View File

@@ -60,11 +60,19 @@ class IModel(nn.Module):
# An computation function. Example:
# def computation(state: State, x: list[torch.Tensor]):
# b_0, out_0 = state.median(x[0])
# b_1, out_1 = state.median(x[1])
# b_2, out_2 = state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
# return torch.logical_and(torch.logical_and(b_0, b_1), b_2), out_2
TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torch.Tensor]]
def create_model(computation: TComputation) -> Type[nn.Module]:
"""
Create a torch model from a `computation` function defined by user
"""
state = State()

View File

@@ -1,49 +0,0 @@
from typing import Any
from abc import ABC, abstractmethod
from torch import nn
import torch
class BaseZKStatsModel(ABC, nn.Module):
def __init__(self):
super().__init__()
@abstractmethod
def forward(self, X: Any) -> Any:
"""
:param X: a tensor of shape (1, n, 1)
:return: a tuple of (bool, float)
"""
class NoDivisionModel(BaseZKStatsModel):
def __init__(self):
super().__init__()
# w represents mean in this case
@abstractmethod
def prepare(expected_output: Any):
...
@abstractmethod
def forward(self, X: Any) -> tuple[float, float]:
# some expression of tolerance to error in the inference
# must have w first!
...
class MeanModel(NoDivisionModel):
def __init__(self):
super().__init__()
def prepare(self, X: Any):
expected_output = torch.mean(X[0])
# w represents mean in this case
self.w = nn.Parameter(data = expected_output, requires_grad = False)
def forward(self, X: Any) -> tuple[float, float]:
# some expression of tolerance to error in the inference
# must have w first!
return (torch.abs(torch.sum(X)-X.size()[1]*(self.w))<0.01*X.size()[1]*(self.w), self.w)
# TODO: Copy the rest of models here