From 605cdef16442f1c6e99a356dbed6c25df8348659 Mon Sep 17 00:00:00 2001 From: David Hou Date: Thu, 8 Feb 2024 14:39:37 -0800 Subject: [PATCH] cleanup hack tricks --- extra/lr_scheduler.py | 6 +++--- extra/models/resnet.py | 2 +- tinygrad/device.py | 5 ----- tinygrad/nn/optim.py | 6 ++---- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/extra/lr_scheduler.py b/extra/lr_scheduler.py index 3aee279862..e177e519a4 100644 --- a/extra/lr_scheduler.py +++ b/extra/lr_scheduler.py @@ -1,5 +1,5 @@ import math -from typing import List +from typing import List, Union from tinygrad.nn.optim import Optimizer from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes @@ -9,11 +9,11 @@ class LR_Scheduler: self.optimizer = optimizer self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device, dtype=dtypes.float32) - def get_lr(self): pass + def get_lr(self) -> Union[Tensor, int]: pass def step(self) -> None: self.epoch_counter.assign(self.epoch_counter + 1).realize() - self.optimizer.lr.assign(self.get_lr()).realize() + self.optimizer.lr.assign(lr.cast(self.optimizer.lr.dtype) if isinstance(lr := self.get_lr(), Tensor) else lr).realize() class MultiStepLR(LR_Scheduler): def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1): diff --git a/extra/models/resnet.py b/extra/models/resnet.py index 011c410fc3..d23187b338 100644 --- a/extra/models/resnet.py +++ b/extra/models/resnet.py @@ -140,7 +140,7 @@ class ResNet: if is_feature_only: features.append(out) if not is_feature_only: out = out.mean([2,3]) - out = self.fc(out).float().log_softmax() + out = self.fc(out).log_softmax() return out return features diff --git a/tinygrad/device.py b/tinygrad/device.py index a5a0ad0599..cb6143b972 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -285,7 +285,6 @@ class CompiledASTRunner(JITRunner): self.lib, self.clprg = lib, self.device.runtime(self.name, lib) self.vars: List[Variable] = [] if ast: - self.ast = ast info = get_lazyop_info(ast) self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate self.vars = ast.vars() @@ -297,10 +296,6 @@ class CompiledASTRunner(JITRunner): return global_size, local_size def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False, do_update_stats=True) -> Optional[float]: - if GlobalCounters.kernel_count+1 == getenv("DEBUGK", -1): - from tinygrad.graph import print_tree - print_tree(self.ast) - print(self.prg) global_size, local_size = self.launch_dims(var_vals) if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type] # TODO: this is copied from get_program diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index 1a8532fdd9..1aa03a8166 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -39,7 +39,7 @@ class SGD(Optimizer): if self.momentum: self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i] - t.assign(t.detach() - g * self.lr.cast(g.dtype)) + t.assign(t.detach() - g * self.lr) self.realize(self.b) # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W. @@ -92,11 +92,9 @@ class LARS(Optimizer): self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps), 1.0 ), 1.0 ) - scaled_lr = self.lr * trust_ratio - g = (t.grad + self.weight_decay * t.detach()) - g = g * scaled_lr.cast(g.dtype) + g = (t.grad + self.weight_decay * t.detach()) * scaled_lr if self.momentum: self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]