mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
remove Tensor.no_grad, it's meaningless now [pr] (#10556)
This commit is contained in:
@@ -267,13 +267,10 @@ def train_cifar():
|
||||
|
||||
@TinyJit
|
||||
def update(self, net, decay):
|
||||
# TODO with Tensor.no_grad()
|
||||
Tensor.no_grad = True
|
||||
for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
|
||||
# batchnorm currently is not being tracked
|
||||
if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
|
||||
net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
|
||||
Tensor.no_grad = False
|
||||
|
||||
set_seed(getenv('SEED', hyp['seed']))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user