remove Tensor.no_grad, it's meaningless now [pr] (#10556)

This commit is contained in:
George Hotz
2025-05-28 22:20:02 -07:00
committed by GitHub
parent e4e7b5d7e1
commit b3b43a82c4
35 changed files with 17 additions and 80 deletions

View File

@@ -267,13 +267,10 @@ def train_cifar():
@TinyJit
def update(self, net, decay):
  """Blend `net`'s parameters into `self.net_ema` as an exponential moving average.

  For each matching parameter pair, computes ema = ema*decay + param*(1-decay)
  and assigns it in place. Both sides are detached so no gradient graph is
  built for the EMA bookkeeping (the old global `Tensor.no_grad` toggle that
  used to wrap this loop was removed — it is meaningless now and, worse, an
  exception inside the loop would have left the flag set process-wide).
  """
  for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
    # batchnorm currently is not being tracked: skip the running statistics
    # ("running_mean"/"running_var") and the "num_batches_tracked" counters
    if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
      net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
# NOTE(review): stray diff-context line from another hunk of the same commit;
# presumably seeds the RNG from the SEED env var, defaulting to hyp['seed'] — confirm against full file
set_seed(getenv('SEED', hyp['seed']))