mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cleanup hack tricks
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
from tinygrad.nn.optim import Optimizer
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -9,11 +9,11 @@ class LR_Scheduler:
|
||||
self.optimizer = optimizer
|
||||
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device, dtype=dtypes.float32)
|
||||
|
||||
def get_lr(self): pass
|
||||
def get_lr(self) -> Union[Tensor, int]: pass
|
||||
|
||||
def step(self) -> None:
|
||||
self.epoch_counter.assign(self.epoch_counter + 1).realize()
|
||||
self.optimizer.lr.assign(self.get_lr()).realize()
|
||||
self.optimizer.lr.assign(lr.cast(self.optimizer.lr.dtype) if isinstance(lr := self.get_lr(), Tensor) else lr).realize()
|
||||
|
||||
class MultiStepLR(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
|
||||
|
||||
@@ -140,7 +140,7 @@ class ResNet:
|
||||
if is_feature_only: features.append(out)
|
||||
if not is_feature_only:
|
||||
out = out.mean([2,3])
|
||||
out = self.fc(out).float().log_softmax()
|
||||
out = self.fc(out).log_softmax()
|
||||
return out
|
||||
return features
|
||||
|
||||
|
||||
@@ -285,7 +285,6 @@ class CompiledASTRunner(JITRunner):
|
||||
self.lib, self.clprg = lib, self.device.runtime(self.name, lib)
|
||||
self.vars: List[Variable] = []
|
||||
if ast:
|
||||
self.ast = ast
|
||||
info = get_lazyop_info(ast)
|
||||
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
|
||||
self.vars = ast.vars()
|
||||
@@ -297,10 +296,6 @@ class CompiledASTRunner(JITRunner):
|
||||
return global_size, local_size
|
||||
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False, do_update_stats=True) -> Optional[float]:
|
||||
if GlobalCounters.kernel_count+1 == getenv("DEBUGK", -1):
|
||||
from tinygrad.graph import print_tree
|
||||
print_tree(self.ast)
|
||||
print(self.prg)
|
||||
global_size, local_size = self.launch_dims(var_vals)
|
||||
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
|
||||
# TODO: this is copied from get_program
|
||||
|
||||
@@ -39,7 +39,7 @@ class SGD(Optimizer):
|
||||
if self.momentum:
|
||||
self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
|
||||
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
|
||||
t.assign(t.detach() - g * self.lr.cast(g.dtype))
|
||||
t.assign(t.detach() - g * self.lr)
|
||||
self.realize(self.b)
|
||||
|
||||
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W, so if we just set the trust ratio to 1.0 it's just Adam/W.
|
||||
@@ -92,11 +92,9 @@ class LARS(Optimizer):
|
||||
self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps), 1.0
|
||||
), 1.0
|
||||
)
|
||||
|
||||
scaled_lr = self.lr * trust_ratio
|
||||
|
||||
g = (t.grad + self.weight_decay * t.detach())
|
||||
g = g * scaled_lr.cast(g.dtype)
|
||||
g = (t.grad + self.weight_decay * t.detach()) * scaled_lr
|
||||
if self.momentum:
|
||||
self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
|
||||
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
|
||||
|
||||
Reference in New Issue
Block a user