fix wrong type and use correct types

This commit is contained in:
mhchia
2024-01-16 14:18:17 +08:00
parent ea2612d939
commit dcf8576edc

View File

@@ -55,7 +55,7 @@ class IModel(nn.Module):
...
@abstractmethod
def forward(self, *x: torch.Tensor) -> tuple[IsResultPrecise, torch.Tensor]:
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
...
@@ -70,18 +70,18 @@ TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torc
def create_model(computation: TComputation) -> Type[nn.Module]:
def create_model(computation: TComputation) -> Type[IModel]:
"""
Create a torch model from a `computation` function defined by user
"""
state = State()
class Model(nn.Module):
class Model(IModel):
def preprocess(self, x: list[torch.Tensor]) -> None:
computation(state, x)
state.set_ready_for_exporting_onnx()
def forward(self, *x: torch.Tensor) -> tuple[IsResultPrecise, torch.Tensor]:
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return computation(state, x)
return Model