Refactor nn.optim (#1091)

* Refactor: nn.optim.py

* Refactor: nn.optim.py; Fix all tests

* Refactor: Replace all optim.get_parameters()

* Refactor: Revert list comp.

* Refactor: Replace optim.get_state_dict

* Refactor: Change quickstart.md
This commit is contained in:
Reza Rezvan
2023-07-03 00:07:30 +02:00
committed by GitHub
parent 10f1aeb144
commit 8ae9a054ae
15 changed files with 44 additions and 36 deletions

View File

@@ -1,6 +1,7 @@
from typing import Optional, Tuple
from numpy.typing import NDArray
+from tinygrad.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.helpers import getenv
@@ -152,10 +153,10 @@ class DeepDeterministicPolicyGradient:
self.target_actor = Actor(self.num_actions, self.num_states, hidden_size)
self.target_critic = Critic(self.num_actions + self.num_states, hidden_size)
-actor_params = optim.get_parameters(self.actor)
-critic_params = optim.get_parameters(self.critic)
-target_actor_params = optim.get_parameters(self.target_actor)
-target_critic_params = optim.get_parameters(self.target_critic)
+actor_params = get_parameters(self.actor)
+critic_params = get_parameters(self.critic)
+target_actor_params = get_parameters(self.target_actor)
+target_critic_params = get_parameters(self.target_critic)
if DEVICE == "GPU":
[x.gpu_() for x in actor_params + critic_params + target_actor_params + target_critic_params]
@@ -171,12 +172,12 @@ class DeepDeterministicPolicyGradient:
tau = self.tau
for param, target_param in zip(
-optim.get_parameters(self.actor), optim.get_parameters(self.target_actor)
+get_parameters(self.actor), get_parameters(self.target_actor)
):
target_param.assign(param.detach() * tau + target_param * (1.0 - tau))
for param, target_param in zip(
-optim.get_parameters(self.critic), optim.get_parameters(self.target_critic)
+get_parameters(self.critic), get_parameters(self.target_critic)
):
target_param.assign(param.detach() * tau + target_param * (1.0 - tau))