mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-10 05:57:55 -05:00
fix wrong type and use correct types
This commit is contained in:
@@ -55,7 +55,7 @@ class IModel(nn.Module):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *x: torch.Tensor) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
...
|
||||
|
||||
|
||||
@@ -70,18 +70,18 @@ TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torc
|
||||
|
||||
|
||||
|
||||
def create_model(computation: TComputation) -> Type[nn.Module]:
|
||||
def create_model(computation: TComputation) -> Type[IModel]:
|
||||
"""
|
||||
Create a torch model from a `computation` function defined by user
|
||||
"""
|
||||
state = State()
|
||||
|
||||
class Model(nn.Module):
|
||||
class Model(IModel):
|
||||
def preprocess(self, x: list[torch.Tensor]) -> None:
|
||||
computation(state, x)
|
||||
state.set_ready_for_exporting_onnx()
|
||||
|
||||
def forward(self, *x: torch.Tensor) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
return computation(state, x)
|
||||
|
||||
return Model
|
||||
|
||||
Reference in New Issue
Block a user