Faster UOp hashing (#6447)

* Faster hashing of Enums and UOp

* NOp should not define __eq__

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Tim Becker
2024-09-09 16:16:04 -07:00
committed by GitHub
parent 92e4126793
commit 58a1b4f427
2 changed files with 15 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
import itertools, urllib.request, subprocess, shutil, math, json, contextvars
from enum import Enum
from dataclasses import dataclass
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@@ -353,3 +354,8 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
# Enum subclass that precomputes its hash once at member-creation time.
# Enum.__hash__ is comparatively slow, and these members are used heavily as dict
# keys; members are singletons, so caching the value at __init__ is safe.
class HashEnum(Enum):
  def __init__(self, *args):
    # ignore the member's value args; just snapshot the base Enum hash one time
    self._cached_hash = Enum.__hash__(self)
  def __hash__(self):
    return self._cached_hash

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, Sequence, DefaultDict
import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
from enum import Enum, auto
from enum import auto
from collections import defaultdict
from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
from tinygrad.helpers import pretty_print, prod, getenv, all_same
from tinygrad.helpers import pretty_print, prod, getenv, all_same, HashEnum
from tinygrad.shape.symbolic import Variable, sint
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
@@ -13,20 +13,20 @@ if TYPE_CHECKING:
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
class UnaryOps(Enum):
class UnaryOps(HashEnum):
  """A -> A (elementwise)"""
  # one member per line; auto() assigns the same values in the same declaration order
  EXP2 = auto()
  LOG2 = auto()
  CAST = auto()
  BITCAST = auto()
  SIN = auto()
  SQRT = auto()
  RECIP = auto()
class BinaryOps(Enum):
class BinaryOps(HashEnum):
  """A + A -> A (elementwise)"""
  # one member per line; auto() assigns the same values in the same declaration order
  ADD = auto()
  MUL = auto()
  IDIV = auto()
  MAX = auto()
  MOD = auto()
  CMPLT = auto()
  CMPNE = auto()
  XOR = auto()
  SHL = auto()
  SHR = auto()
  OR = auto()
  AND = auto()
  THREEFRY = auto()
class TernaryOps(Enum):
class TernaryOps(HashEnum):
  """A + A + A -> A (elementwise)"""
  # one member per line; auto() assigns the same values in the same declaration order
  WHERE = auto()
  MULACC = auto()
class ReduceOps(Enum):
class ReduceOps(HashEnum):
  """A -> B (reduce)"""
  # one member per line; auto() assigns the same values in the same declaration order
  SUM = auto()
  PROD = auto()
  MAX = auto()
class MetaOps(Enum):
class MetaOps(HashEnum):
  # scheduling/buffer-level ops rather than elementwise math
  # one member per line; auto() assigns the same values in the same declaration order
  EMPTY = auto()
  CONST = auto()
  COPY = auto()
  CONTIGUOUS = auto()
  CUSTOM = auto()
  ASSIGN = auto()
  VIEW = auto()
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
@@ -75,7 +75,7 @@ REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps
def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
# the order of these UOps controls the order of the toposort
class UOps(Enum):
class UOps(HashEnum):
# uops that aren't rendered
SINK = auto()
"""
@@ -330,7 +330,6 @@ class UOp(MathTrait):
dtype: Optional[DType] = None
src: Tuple[UOp, ...] = tuple()
arg: Any = None
def __hash__(self): return id(self)
@functools.cached_property
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
@@ -580,7 +579,7 @@ def get_location() -> Tuple[str, int]:
@functools.lru_cache(None)
def lines(fn) -> List[str]: return open(fn).readlines()
@dataclass(frozen=True, repr=False) # reuse repr from UOp
@dataclass(frozen=True, eq=False, repr=False) # reuse repr from UOp
class NOp(UOp):
name: Optional[str] = None
src: Tuple[NOp, ...] = tuple()