cleanup hack tricks

This commit is contained in:
David Hou
2024-02-08 14:39:37 -08:00
parent f05f08e721
commit 605cdef164
4 changed files with 6 additions and 13 deletions

View File

@@ -1,5 +1,5 @@
import math
from typing import List
from typing import List, Union
from tinygrad.nn.optim import Optimizer
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
@@ -9,11 +9,11 @@ class LR_Scheduler:
self.optimizer = optimizer
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device, dtype=dtypes.float32)
def get_lr(self): pass
def get_lr(self) -> Union[Tensor, int]: pass
def step(self) -> None:
self.epoch_counter.assign(self.epoch_counter + 1).realize()
self.optimizer.lr.assign(self.get_lr()).realize()
self.optimizer.lr.assign(lr.cast(self.optimizer.lr.dtype) if isinstance(lr := self.get_lr(), Tensor) else lr).realize()
class MultiStepLR(LR_Scheduler):
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):

View File

@@ -140,7 +140,7 @@ class ResNet:
if is_feature_only: features.append(out)
if not is_feature_only:
out = out.mean([2,3])
out = self.fc(out).float().log_softmax()
out = self.fc(out).log_softmax()
return out
return features

View File

@@ -285,7 +285,6 @@ class CompiledASTRunner(JITRunner):
self.lib, self.clprg = lib, self.device.runtime(self.name, lib)
self.vars: List[Variable] = []
if ast:
self.ast = ast
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
self.vars = ast.vars()
@@ -297,10 +296,6 @@ class CompiledASTRunner(JITRunner):
return global_size, local_size
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False, do_update_stats=True) -> Optional[float]:
if GlobalCounters.kernel_count+1 == getenv("DEBUGK", -1):
from tinygrad.graph import print_tree
print_tree(self.ast)
print(self.prg)
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program

View File

@@ -39,7 +39,7 @@ class SGD(Optimizer):
if self.momentum:
self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
t.assign(t.detach() - g * self.lr.cast(g.dtype))
t.assign(t.detach() - g * self.lr)
self.realize(self.b)
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
@@ -92,11 +92,9 @@ class LARS(Optimizer):
self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps), 1.0
), 1.0
)
scaled_lr = self.lr * trust_ratio
g = (t.grad + self.weight_decay * t.detach())
g = g * scaled_lr.cast(g.dtype)
g = (t.grad + self.weight_decay * t.detach()) * scaled_lr
if self.momentum:
self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]