mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Faster UOp hashing (#6447)
* Faster hashing of Enums and UOp
* NOp should not define __eq__
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
|
||||
import itertools, urllib.request, subprocess, shutil, math, json, contextvars
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
@@ -353,3 +354,8 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
||||
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
|
||||
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
|
||||
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
||||
|
||||
# small wrapper around Enum that caches the __hash__ method
|
||||
class HashEnum(Enum):
  """Enum base class whose members compute their hash once and reuse it.

  `Enum.__hash__` is evaluated a single time while each member is being
  initialised; later `hash(member)` calls only read the stored value.
  """
  def __init__(self, *_):
    # the member's value arguments are accepted but ignored; only the hash is stored
    self._cached_hash = Enum.__hash__(self)
  def __hash__(self): return self._cached_hash
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, Sequence, DefaultDict
|
||||
import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
|
||||
from enum import Enum, auto
|
||||
from enum import auto
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
|
||||
from tinygrad.helpers import pretty_print, prod, getenv, all_same
|
||||
from tinygrad.helpers import pretty_print, prod, getenv, all_same, HashEnum
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -13,20 +13,20 @@ if TYPE_CHECKING:
|
||||
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
|
||||
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
|
||||
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
|
||||
class UnaryOps(Enum):
|
||||
class UnaryOps(HashEnum):
|
||||
"""A -> A (elementwise)"""
|
||||
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
|
||||
class BinaryOps(Enum):
|
||||
class BinaryOps(HashEnum):
|
||||
"""A + A -> A (elementwise)"""
|
||||
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
|
||||
class TernaryOps(Enum):
|
||||
class TernaryOps(HashEnum):
|
||||
"""A + A + A -> A (elementwise)"""
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
class ReduceOps(Enum):
|
||||
class ReduceOps(HashEnum):
|
||||
"""A -> B (reduce)"""
|
||||
SUM = auto(); PROD = auto(); MAX = auto() # noqa: E702
|
||||
class MetaOps(Enum):
|
||||
class MetaOps(HashEnum):
|
||||
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
|
||||
|
||||
@@ -75,7 +75,7 @@ REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps
|
||||
def identity_element(op:BinaryOps, dt:DType):
  """Look up the constant associated with `op` and convert it for `dt` via `dtypes.as_const`.

  Only ADD, MUL and MAX have an entry; any other `op` raises KeyError from the table lookup.
  """
  # the whole table is built on every call, so dtypes.min(dt) is evaluated regardless of op
  neutral = {BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}
  return dtypes.as_const(neutral[op], dt)
|
||||
|
||||
# the order of these UOps controls the order of the toposort
|
||||
class UOps(Enum):
|
||||
class UOps(HashEnum):
|
||||
# uops that aren't rendered
|
||||
SINK = auto()
|
||||
"""
|
||||
@@ -330,7 +330,6 @@ class UOp(MathTrait):
|
||||
dtype: Optional[DType] = None
|
||||
src: Tuple[UOp, ...] = tuple()
|
||||
arg: Any = None
|
||||
def __hash__(self): return id(self)
|
||||
@functools.cached_property
|
||||
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
|
||||
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
||||
@@ -580,7 +579,7 @@ def get_location() -> Tuple[str, int]:
|
||||
@functools.lru_cache(None)
def lines(fn) -> List[str]:
  """Return the lines of the file at `fn`, line terminators included.

  Results are memoised per `fn` by `functools.lru_cache`, so repeated calls with the same
  argument return the same list object without reading the file again.
  """
  # a context manager closes the handle deterministically instead of leaving it to the GC
  with open(fn) as f: return f.readlines()
|
||||
|
||||
@dataclass(frozen=True, repr=False) # reuse repr from UOp
|
||||
@dataclass(frozen=True, eq=False, repr=False) # reuse repr from UOp
|
||||
class NOp(UOp):
|
||||
name: Optional[str] = None
|
||||
src: Tuple[NOp, ...] = tuple()
|
||||
|
||||
Reference in New Issue
Block a user