Refactor nn.optim (#1091)

* Refactor: nn.optim.py

* Refactor: nn.optim.py; Fix all tests

* Refactor: Replace all optim.get_parameters()

* Refactor: Revert list comp.

* Refactor: Replace optim.get_state_dict

* Refactor: Change quickstart.md
This commit is contained in:
Reza Rezvan
2023-07-03 00:07:30 +02:00
committed by GitHub
parent 10f1aeb144
commit 8ae9a054ae
15 changed files with 44 additions and 36 deletions

View File

@@ -3,6 +3,7 @@ import numpy as np
 from tqdm import trange
 import torch
 from torchvision.utils import make_grid, save_image
+from tinygrad.state import get_parameters
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import getenv
 from tinygrad.nn import optim
@@ -84,8 +85,8 @@ if __name__ == "__main__":
   output_dir = Path(".").resolve() / "outputs"
   output_dir.mkdir(exist_ok=True)
   # optimizers
-  optim_g = optim.Adam(optim.get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
-  optim_d = optim.Adam(optim.get_parameters(discriminator),lr=0.0002, b1=0.5)
+  optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
+  optim_d = optim.Adam(get_parameters(discriminator),lr=0.0002, b1=0.5)
   # training loop
   for epoch in (t := trange(epochs)):
     loss_g, loss_d = 0.0, 0.0