mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
decay LR, little bugfix
This commit is contained in:
@@ -59,7 +59,9 @@ def train_step_jitted(model, optimizer, X, Y):
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss
|
||||
optimizer.lr *= 0.995 # decay LR
|
||||
optimizer.lr.realize()
|
||||
return loss.realize()
|
||||
|
||||
def fetch_batch(X_train, Y_train, BS):
|
||||
# fetch a batch
|
||||
@@ -88,7 +90,7 @@ def train_cifar():
|
||||
optimizer = optim.Adam(get_parameters(model), lr=3e-4)
|
||||
else:
|
||||
#optimizer = optim.SGD(get_parameters(model), lr=0.001)
|
||||
optimizer = optim.SGD(get_parameters(model), lr=0.003, momentum=0.85, nesterov=True)
|
||||
optimizer = optim.SGD(get_parameters(model), lr=Tensor([0.003]).realize(), momentum=0.85, nesterov=True)
|
||||
|
||||
# 97 steps in 2 seconds = 20ms / step
|
||||
# step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
|
||||
|
||||
@@ -74,7 +74,7 @@ class CLASTKernel(ASTKernel):
|
||||
const = None
|
||||
if self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing is not None:
|
||||
assert self.buftokens[buf_index].typ == Types.FLOAT
|
||||
self.bufs_to_delete.add(buf_index)
|
||||
if buf_index != 0: self.bufs_to_delete.add(buf_index)
|
||||
const = Token(f"({self.bufs[buf_index]._backing[0]}f)", self.buftokens[buf_index].typ)
|
||||
|
||||
tokens = []
|
||||
|
||||
Reference in New Issue
Block a user