Improve type hints for optimizer (#3583)

* Improve type hints for optimizer

* lint fix
This commit is contained in:
Szymon Ożóg
2024-03-02 16:35:44 +01:00
committed by GitHub
parent 83530a585f
commit 6c36264790

View File

@@ -22,6 +22,8 @@ class Optimizer:
# NOTE: in extra is too late for most of the params due to issues with assign
Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)
def step(self) -> None: raise NotImplementedError
class SGD(Optimizer):
def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
super().__init__(params, lr)