clipnorm support

George Hotz
2022-09-24 13:26:38 -04:00
parent 271446e3eb
commit acae9a20c1


@@ -10,6 +10,12 @@ class Optimizer:
     self.params = [x for x in params if x.requires_grad]
 
+  # TODO: this probably shouldn't change the gradients, just the ones used by the optimizer
+  def clipnorm(self, amount=1):
+    for param in self.params:
+      # clipnorm is the L2 norm, not value: is this right?
+      param.grad.assign(param.grad.clip(-(amount**2), (amount**2)))
+
   def zero_grad(self):
     for param in self.params:
       param.grad = None
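
As the in-diff comment notes, "clipnorm" conventionally means clipping by the gradient's L2 norm (as in Keras's clipnorm argument): the whole gradient is rescaled when its norm exceeds the threshold, which preserves its direction. The committed code instead clips each value elementwise to the range ±amount². For comparison, a minimal NumPy sketch of the norm-based variant (illustrative only; clip_by_norm is a hypothetical helper, not tinygrad's Tensor API):

import numpy as np

def clip_by_norm(grad: np.ndarray, amount: float = 1.0) -> np.ndarray:
    # Rescale the whole gradient so its L2 norm is at most `amount`;
    # unlike value clipping, this keeps the gradient's direction intact.
    norm = np.linalg.norm(grad)
    return grad * (amount / norm) if norm > amount else grad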