Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 23:48:01 -05:00
GlobalCounters cache + assign in optim
@@ -40,8 +40,7 @@ if __name__ == "__main__":
       st = time.monotonic()
       out = model.forward(x_train)
       loss = out.logsoftmax().mul(y_train).mean()
-      if ADAM: optimizer.t = 0  # TODO: fixing this requires optional constant folding
-      if i == 2 and CLCACHE: CL.CACHE = []
+      if i == 2 and CLCACHE: GlobalCounters.cache = []
       if BACKWARD:
         optimizer.zero_grad()
         loss.backward()
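The benchmark now arms kernel capture through GlobalCounters instead of the CL backend. A minimal sketch of how the toggle is assumed to work: capture is off while GlobalCounters.cache is None and on once a list is installed (the middle step is a placeholder for whatever work should be recorded):

  # sketch: record the kernels launched during one step
  from tinygrad.ops import GlobalCounters

  GlobalCounters.cache = []        # start recording (prg, bufs) pairs
  # ... run one forward/backward step here ...
  captured = GlobalCounters.cache  # everything launched while recording
  GlobalCounters.cache = None      # back to normal execution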
@@ -51,15 +50,13 @@ if __name__ == "__main__":
       for p in parameters:
         p.realize()
       et = time.monotonic()
       ops = GlobalCounters.global_ops
     else:
       st = mt = time.monotonic()
       for prg, args in cl_cache: prg(*args)
       et = time.monotonic()

     if i == 2 and CLCACHE:
-      cl_cache = CL.CACHE
-      CL.CACHE = None
+      cl_cache = GlobalCounters.cache

     mem_used = CL.mem_used
     loss_cpu = loss.detach().numpy()[0]

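Once iteration 2 has filled the list, later iterations skip graph construction and replay the recorded programs directly, which is what the else branch above does. A sketch of the replay, assuming each entry is a (program, buffers) pair as appended in the GPUBuffer hunk below:

  # sketch: replay previously captured kernels verbatim
  for prg, args in captured:   # `captured` as in the sketch above
    prg(*args)                 # relaunch the compiled program on the same buffers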
@@ -2,7 +2,7 @@ from __future__ import annotations
 import numpy as np
 from typing import List, Tuple, Optional, Dict, Union, Set, Final, Callable
 from tinygrad.helpers import prod
-from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST
+from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
 from tinygrad.ast import ASTKernel, Token, Types
 from tinygrad.lazy import IMAGE
 from tinygrad.shape import ShapeTracker
@@ -361,7 +361,9 @@ class GPUBuffer(ExplicitExecAST):
     if KOPT:
       from extra.kernel_search import apply_optimization
       apply_optimization(k, ast, max_interventions=KOPT)
-    k.codegen()(*k.bufs)
+    prg = k.codegen()
+    if GlobalCounters.cache is not None: GlobalCounters.cache.append((prg, k.bufs))
+    prg(*k.bufs)
     if PRINT_AST == "1" or (hasattr(k, "fxn") and PRINT_AST == k.fxn.name):
       print(k.fxn.name)
       k.print()
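The exec path now holds on to the compiled program instead of calling k.codegen()(*k.bufs) in one shot, so an installed capture list can record exactly what a replay needs. A hedged sketch of the same hook in isolation; exec_kernel and record_list are illustrative stand-ins, not tinygrad API:

  # sketch: compile once, optionally record, then execute
  def exec_kernel(k, record_list=None):
    prg = k.codegen()                    # build the callable program
    if record_list is not None:
      record_list.append((prg, k.bufs))  # program + buffers is enough to replay later
    prg(*k.bufs)                         # run it now, exactly as before
    return prg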
@@ -48,7 +48,7 @@ class RMSprop(Optimizer):
   def step(self) -> None:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      self.v[i] = self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad)
+      self.v[i].assign(self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad))
       t.assign(t.detach() - (t.grad * self.lr).div(self.v[i].sqrt() + self.eps))
     self.realize(self.v)

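The substance of the optimizer change: `self.v[i] = ...` rebinds the list slot to a brand-new lazy Tensor each step, whereas `.assign(...)` writes the update into the Tensor the optimizer already holds, which is what `self.realize(self.v)` then realizes. A small illustration, assuming Tensor.assign updates the receiver in place:

  from tinygrad.tensor import Tensor

  v = Tensor.zeros(4)
  g = Tensor.ones(4)

  v_new = 0.9 * v + 0.1 * (g * g)    # a fresh Tensor; `v` itself is untouched
  v.assign(0.9 * v + 0.1 * (g * g))  # the same Tensor object now carries the update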
@@ -65,8 +65,8 @@ class Adam(Optimizer):
     a = self.lr * ((1.0 - self.b2**self.t)**0.5) / (1.0 - self.b1**self.t)
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * t.grad
-      self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
       t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
     self.realize([self.t] + self.m + self.v)

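For reference, the scalar a at the top of this hunk is Adam's bias-corrected step size. A quick numeric check with illustrative values (b1=0.9, b2=0.999, lr=0.001 are assumptions, not read from the diff):

  lr, b1, b2, t = 0.001, 0.9, 0.999, 1
  a = lr * ((1.0 - b2**t)**0.5) / (1.0 - b1**t)
  print(a)  # ~0.000316, i.e. roughly lr / 3.16 on the very first step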
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import numpy as np
 from enum import Enum, auto
-from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar
+from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional
 import functools, operator
 from tinygrad.helpers import prod, shape_to_axis
 from tinygrad.shape import ShapeTracker
@@ -89,8 +89,9 @@ class GlobalCounters:
   global_mem : ClassVar[int] = 0
   time_sum : ClassVar[int] = 0
   kernel_count : ClassVar[int] = 0
+  cache : ClassVar[Optional[list]] = None
   @staticmethod
-  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count = 0,0,0,0
+  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0,0,None

 class GenericShape(GenericExecAST): # pylint: disable=abstract-method
   def __init__(self, shape, flops=0): self.shape, self.flops = shape, flops
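With the cache folded into GlobalCounters, reset() now also drops any captured kernel list, so a fresh measurement starts with counters at zero and capture off. A minimal usage sketch (the print line is illustrative):

  from tinygrad.ops import GlobalCounters

  GlobalCounters.reset()     # ops/mem/time/kernel counters back to 0, cache back to None
  GlobalCounters.cache = []  # opt back into capture only if this run should be recorded
  # ... run some tinygrad work ...
  print(GlobalCounters.global_ops, GlobalCounters.kernel_count)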