Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Refactor nn.optim (#1091)
* Refactor: nn.optim.py
* Refactor: nn.optim.py; fix all tests
* Refactor: Replace all optim.get_parameters()
* Refactor: Revert list comp.
* Refactor: Replace optim.get_state_dict
* Refactor: Change quickstart.md
@@ -3,6 +3,7 @@ import numpy as np
 from tqdm import trange
 import torch
 from torchvision.utils import make_grid, save_image
+from tinygrad.state import get_parameters
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import getenv
 from tinygrad.nn import optim
@@ -84,8 +85,8 @@ if __name__ == "__main__":
   output_dir = Path(".").resolve() / "outputs"
   output_dir.mkdir(exist_ok=True)
   # optimizers
-  optim_g = optim.Adam(optim.get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
-  optim_d = optim.Adam(optim.get_parameters(discriminator),lr=0.0002, b1=0.5)
+  optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
+  optim_d = optim.Adam(get_parameters(discriminator),lr=0.0002, b1=0.5)
   # training loop
   for epoch in (t := trange(epochs)):
     loss_g, loss_d = 0.0, 0.0
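The visible change at call sites is mechanical: parameter collection now lives in tinygrad.state rather than nn.optim, so examples import get_parameters directly and pass the result to the optimizer constructor unchanged. Below is a minimal sketch of the call pattern after this refactor; the model, input shapes, and loss are hypothetical, and it assumes the tinygrad API at this commit (nn.Linear, Tensor, and the optimizer's zero_grad()/step() methods).

# Minimal sketch of the refactored usage (hypothetical model and loss).
from tinygrad.tensor import Tensor
from tinygrad.state import get_parameters  # moved out of nn.optim by this refactor
from tinygrad.nn import Linear, optim

model = Linear(784, 10)                                     # hypothetical single-layer model
opt = optim.Adam(get_parameters(model), lr=0.0002, b1=0.5)  # same constructor call as before

x = Tensor.randn(32, 784)  # dummy batch
loss = model(x).mean()     # stand-in loss for illustration
opt.zero_grad()
loss.backward()
opt.step()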