From a5a55ac19e30eb146195ee535e797324400df02f Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Wed, 8 Feb 2023 17:10:55 -0600
Subject: [PATCH] GlobalCounters cache + assign in optim

---
 examples/benchmark_train_efficientnet.py | 7 ++-----
 tinygrad/llops/ops_gpu.py                | 6 ++++--
 tinygrad/nn/optim.py                     | 6 +++---
 tinygrad/ops.py                          | 5 +++--
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/examples/benchmark_train_efficientnet.py b/examples/benchmark_train_efficientnet.py
index 5d31e5cb0d..5eb128be9f 100644
--- a/examples/benchmark_train_efficientnet.py
+++ b/examples/benchmark_train_efficientnet.py
@@ -40,8 +40,7 @@ if __name__ == "__main__":
       st = time.monotonic()
       out = model.forward(x_train)
       loss = out.logsoftmax().mul(y_train).mean()
-      if ADAM: optimizer.t = 0  # TODO: fixing this requires optional constant folding
-      if i == 2 and CLCACHE: CL.CACHE = []
+      if i == 2 and CLCACHE: GlobalCounters.cache = []
       if BACKWARD:
         optimizer.zero_grad()
         loss.backward()
@@ -51,15 +50,13 @@ if __name__ == "__main__":
       for p in parameters:
         p.realize()
       et = time.monotonic()
-      ops = GlobalCounters.global_ops
     else:
       st = mt = time.monotonic()
       for prg, args in cl_cache: prg(*args)
       et = time.monotonic()
 
     if i == 2 and CLCACHE:
-      cl_cache = CL.CACHE
-      CL.CACHE = None
+      cl_cache = GlobalCounters.cache
       mem_used = CL.mem_used
 
     loss_cpu = loss.detach().numpy()[0]
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index c91e1286c4..3a508f8cfa 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import numpy as np
 from typing import List, Tuple, Optional, Dict, Union, Set, Final, Callable
 from tinygrad.helpers import prod
-from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST
+from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
 from tinygrad.ast import ASTKernel, Token, Types
 from tinygrad.lazy import IMAGE
 from tinygrad.shape import ShapeTracker
@@ -361,7 +361,9 @@ class GPUBuffer(ExplicitExecAST):
     if KOPT:
       from extra.kernel_search import apply_optimization
       apply_optimization(k, ast, max_interventions=KOPT)
-    k.codegen()(*k.bufs)
+    prg = k.codegen()
+    if GlobalCounters.cache is not None: GlobalCounters.cache.append((prg, k.bufs))
+    prg(*k.bufs)
     if PRINT_AST == "1" or (hasattr(k, "fxn") and PRINT_AST == k.fxn.name):
       print(k.fxn.name)
       k.print()
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index b5d0e81a51..0c5eed2e37 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -48,7 +48,7 @@ class RMSprop(Optimizer):
   def step(self) -> None:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      self.v[i] = self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad)
+      self.v[i].assign(self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad))
       t.assign(t.detach() - (t.grad * self.lr).div(self.v[i].sqrt() + self.eps))
     self.realize(self.v)
 
@@ -65,8 +65,8 @@ class Adam(Optimizer):
     a = self.lr * ((1.0 - self.b2**self.t)**0.5) / (1.0 - self.b1**self.t)
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * t.grad
-      self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
       t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
     self.realize([self.t] + self.m + self.v)
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index b4e98db32c..67c0aefdb8 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import numpy as np
 from enum import Enum, auto
-from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar
+from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional
 import functools, operator
 from tinygrad.helpers import prod, shape_to_axis
 from tinygrad.shape import ShapeTracker
@@ -89,8 +89,9 @@ class GlobalCounters:
   global_mem : ClassVar[int] = 0
   time_sum : ClassVar[int] = 0
   kernel_count : ClassVar[int] = 0
+  cache : ClassVar[Optional[list]] = None
   @staticmethod
-  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count = 0,0,0,0
+  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0,0,None
 
 class GenericShape(GenericExecAST):  # pylint: disable=abstract-method
   def __init__(self, shape, flops=0): self.shape, self.flops = shape, flops
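
Note (not part of the diff): the two halves of this patch work together. Arming
GlobalCounters.cache makes the GPU backend record every compiled program with its
buffers as it runs; the optimizer switch from rebinding (self.m[i] = ...) to
in-place .assign(...) means the recorded kernels keep writing into the same
buffers on every step, so a captured kernel list can be replayed verbatim. Below
is a minimal sketch of that capture/replay pattern, mirroring the benchmark's
CLCACHE path; it assumes the GPU backend with this patch applied, and `model`,
`x_train`, `y_train` are hypothetical stand-ins for the benchmark's own variables:

  from tinygrad.ops import GlobalCounters

  # arm the cache: ops_gpu.py now appends (prg, k.bufs) on every kernel launch
  GlobalCounters.cache = []
  out = model.forward(x_train)                  # hypothetical forward pass
  loss = out.logsoftmax().mul(y_train).mean()
  loss.realize()                                # realizing runs (and records) the kernels

  # grab the recording and disarm; GlobalCounters.reset() also sets cache to None
  cl_cache, GlobalCounters.cache = GlobalCounters.cache, None

  # replay: call the compiled programs directly, skipping graph build and codegen
  for prg, args in cl_cache: prg(*args)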