move GlobalCounters to helpers (#4002)

break circular import between ops and buffer
This commit is contained in:
chenyu
2024-03-30 00:30:30 -04:00
committed by GitHub
parent 9eef44521b
commit c71627fee6
19 changed files with 40 additions and 66 deletions

View File

@@ -2,5 +2,5 @@ from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.engine.jit import TinyJit # noqa: F401
from tinygrad.shape.symbolic import Variable # noqa: F401
from tinygrad.dtype import dtypes # noqa: F401
from tinygrad.ops import GlobalCounters # noqa: F401
from tinygrad.helpers import GlobalCounters # noqa: F401
from tinygrad.device import Device # noqa: F401

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
from typing import Any, Optional
from dataclasses import dataclass
from tinygrad.helpers import flat_mv
from tinygrad.helpers import GlobalCounters, flat_mv
from tinygrad.dtype import DType, ImageDType
from tinygrad.ops import GlobalCounters
@dataclass(frozen=True, eq=True)
class BufferOptions:

View File

@@ -2,10 +2,10 @@ from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, NamedTuple
import importlib, inspect, functools, pathlib, time, ctypes
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.helpers import prod, CACHECOLLECTING
from tinygrad.helpers import ansilen, prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.helpers import DEBUG, CACHECOLLECTING, BEAM, NOOPT, GlobalCounters
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, get_lazyop_info, GlobalCounters
from tinygrad.ops import LazyOp, get_lazyop_info
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.codegen.uops import UOpGraph

View File

@@ -2,9 +2,9 @@ import sys
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps, GlobalCounters
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int
from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.lazy import LazyBuffer

View File

@@ -1,9 +1,9 @@
import os, atexit, functools
from collections import defaultdict
from typing import List, Any, DefaultDict, TYPE_CHECKING
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp, GlobalCounters
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPHPATH, DEBUG, getenv
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.shape.symbolic import NumNode
if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer

View File

@@ -98,6 +98,19 @@ DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), Cont
WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
GRAPH, GRAPHPATH, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("RING", 1)
# **************** global state Counters ****************
class GlobalCounters:
  """Global counters stored as class-level attributes.

  `reset()` zeroes every counter except `mem_used`.
  """
  global_ops: ClassVar[int] = 0
  global_mem: ClassVar[int] = 0
  time_sum_s: ClassVar[float] = 0.0
  kernel_count: ClassVar[int] = 0
  mem_used: ClassVar[int] = 0  # NOTE: this is not reset

  @staticmethod
  def reset():
    """Set the resettable counters back to zero; `mem_used` is left untouched."""
    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0
    GlobalCounters.time_sum_s = 0.0
    GlobalCounters.kernel_count = 0
# **************** timer and profiler ****************
class Timing(contextlib.ContextDecorator):
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
def __enter__(self): self.st = time.perf_counter_ns()

View File

@@ -2,9 +2,8 @@ import os, json, pathlib, zipfile, pickle, tarfile, struct
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters
from tinygrad.shape.view import strides_for_shape
from tinygrad.features.multi import MultiLazyBuffer

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable, ClassVar
from typing import Union, Type, Tuple, Any, List, Dict, Callable
import functools, hashlib
from enum import Enum, auto
from dataclasses import dataclass
@@ -7,6 +7,7 @@ from tinygrad.helpers import prod, dedup
from tinygrad.dtype import dtypes, DType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.buffer import Buffer
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
@@ -23,9 +24,6 @@ class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS =
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
if TYPE_CHECKING:
from tinygrad.buffer import Buffer
@dataclass(frozen=True)
class MemBuffer:
idx: int
@@ -98,14 +96,3 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
@functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
return run_ast(ast)
# **************** global state Counters ****************
class GlobalCounters:
  """Global counters stored as class-level attributes.

  `reset()` zeroes every counter except `mem_used`.
  """
  global_ops: ClassVar[int] = 0
  global_mem: ClassVar[int] = 0
  time_sum_s: ClassVar[float] = 0.0
  kernel_count: ClassVar[int] = 0
  mem_used: ClassVar[int] = 0  # NOTE: this is not reset

  @staticmethod
  def reset():
    """Set the resettable counters back to zero; `mem_used` is left untouched."""
    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0
    GlobalCounters.time_sum_s = 0.0
    GlobalCounters.kernel_count = 0