decay LR, little bugfix

This commit is contained in:
George Hotz
2023-02-11 17:34:15 -08:00
parent ba3bf5bdf7
commit dd7accb9cc
2 changed files with 5 additions and 3 deletions

View File

@@ -59,7 +59,9 @@ def train_step_jitted(model, optimizer, X, Y):
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss
optimizer.lr *= 0.995 # decay LR
optimizer.lr.realize()
return loss.realize()
def fetch_batch(X_train, Y_train, BS):
# fetch a batch
@@ -88,7 +90,7 @@ def train_cifar():
optimizer = optim.Adam(get_parameters(model), lr=3e-4)
else:
#optimizer = optim.SGD(get_parameters(model), lr=0.001)
optimizer = optim.SGD(get_parameters(model), lr=0.003, momentum=0.85, nesterov=True)
optimizer = optim.SGD(get_parameters(model), lr=Tensor([0.003]).realize(), momentum=0.85, nesterov=True)
# 97 steps in 2 seconds = 20ms / step
# step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136

View File

@@ -74,7 +74,7 @@ class CLASTKernel(ASTKernel):
const = None
if self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing is not None:
assert self.buftokens[buf_index].typ == Types.FLOAT
self.bufs_to_delete.add(buf_index)
if buf_index != 0: self.bufs_to_delete.add(buf_index)
const = Token(f"({self.bufs[buf_index]._backing[0]}f)", self.buftokens[buf_index].typ)
tokens = []