mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-10 14:45:35 -05:00
Refactor nn.optim (#1091)
* Refactor: nn.optim.py
* Refactor: nn.optim.py; Fix all tests
* Refactor: Replace all optim.get_parameters()
* Refactor: Revert list comp.
* Refactor: Replace optim.get_state_dict
* Refactor: Change quickstart.md
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import Optional, Tuple
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
@@ -152,10 +153,10 @@ class DeepDeterministicPolicyGradient:
|
||||
self.target_actor = Actor(self.num_actions, self.num_states, hidden_size)
|
||||
self.target_critic = Critic(self.num_actions + self.num_states, hidden_size)
|
||||
|
||||
actor_params = optim.get_parameters(self.actor)
|
||||
critic_params = optim.get_parameters(self.critic)
|
||||
target_actor_params = optim.get_parameters(self.target_actor)
|
||||
target_critic_params = optim.get_parameters(self.target_critic)
|
||||
actor_params = get_parameters(self.actor)
|
||||
critic_params = get_parameters(self.critic)
|
||||
target_actor_params = get_parameters(self.target_actor)
|
||||
target_critic_params = get_parameters(self.target_critic)
|
||||
|
||||
if DEVICE == "GPU":
|
||||
[x.gpu_() for x in actor_params + critic_params + target_actor_params + target_critic_params]
|
||||
@@ -171,12 +172,12 @@ class DeepDeterministicPolicyGradient:
|
||||
tau = self.tau
|
||||
|
||||
for param, target_param in zip(
|
||||
optim.get_parameters(self.actor), optim.get_parameters(self.target_actor)
|
||||
get_parameters(self.actor), get_parameters(self.target_actor)
|
||||
):
|
||||
target_param.assign(param.detach() * tau + target_param * (1.0 - tau))
|
||||
|
||||
for param, target_param in zip(
|
||||
optim.get_parameters(self.critic), optim.get_parameters(self.target_critic)
|
||||
get_parameters(self.critic), get_parameters(self.target_critic)
|
||||
):
|
||||
target_param.assign(param.detach() * tau + target_param * (1.0 - tau))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user