Dedup params in Optimizer (#1047)
* Dedup params in optimizer

  Passing the same tensor multiple times in the set of learnable params passed to an optimizer can result in models completely failing to learn, but no errors are produced. This dedups the tensors to avoid the problem.

* Fix types
* Use new variable to satisfy linter
* Use `helpers.dedup` instead of `set()` to dedup params
* Add test for duped params in optimizers
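For context, a minimal sketch of an order-preserving dedup with the behaviour the fix relies on is shown below. The exact body of tinygrad's `helpers.dedup` may differ; the key property is that duplicates are dropped while the first-seen order of the params is preserved, and it works on `Tensor` objects as long as they hash by object identity.

```python
# Sketch of an order-preserving dedup; the real helpers.dedup may be implemented
# differently, this only illustrates the behaviour the optimizer fix depends on.
from typing import List, TypeVar

T = TypeVar("T")

def dedup(items: List[T]) -> List[T]:
  # dict.fromkeys keeps only the first occurrence of each key and preserves
  # insertion order, so duplicates are dropped without reordering the params
  return list(dict.fromkeys(items))

assert dedup([1, 2, 2, 3, 1]) == [1, 2, 3]
```

This is also why `helpers.dedup` is preferred over `set()` in the commit: a plain `set()` would discard duplicates but not guarantee a stable parameter order.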
@@ -1,5 +1,6 @@
 # sorted in order of increasing complexity
 from typing import List
+from tinygrad.helpers import dedup
 from tinygrad.tensor import Tensor

 class Optimizer:
@@ -8,8 +9,8 @@ class Optimizer:
     for x in params:
       if x.requires_grad is None: x.requires_grad = True

-    self.params: List[Tensor] = [x for x in params if x.requires_grad]
-    self.buffers: List[Tensor] = [x for x in params if not x.requires_grad] # buffers are still realized
+    self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
+    self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized

   def zero_grad(self):
     for param in self.params: param.grad = None
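The commit message also mentions a test for duped params. A minimal sketch of such a regression test is below; it assumes the `Optimizer` class from the diff is importable from `tinygrad.nn.optim` and that `Tensor.ones` accepts a `requires_grad` keyword (both module path and constructor are assumptions, not confirmed by this page).

```python
# Sketch of a regression test for duplicated params, under the assumptions above.
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import Optimizer  # assumed module path

def test_duped_params_are_deduped():
  w = Tensor.ones(3, 3, requires_grad=True)
  # the same tensor passed twice, e.g. via weight tying in a model's parameter list
  opt = Optimizer([w, w])
  # after the fix, the optimizer holds the tensor exactly once
  assert len(opt.params) == 1

test_duped_params_are_deduped()
```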