From 012ee7d162a809481f31d60c4fba4ddbdb52ded3 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sun, 20 Aug 2023 10:24:58 -0700
Subject: [PATCH] not worth the speed (#1584)

* not worth the speed

* no slots

* uops comments

* bump to python 3.11 for speed

* add critical slots back
---
 .github/workflows/test.yml     |  24 ++++----
 test/unit/test_weak.py         | 103 ---------------------------------
 tinygrad/codegen/assembly.py   |   2 -
 tinygrad/codegen/linearizer.py |  10 +++-
 tinygrad/helpers.py            |  37 ------
 tinygrad/lazy.py               |   9 ++-
 tinygrad/mlops.py              |  18 ------
 tinygrad/ops.py                |   1 -
 8 files changed, 24 insertions(+), 180 deletions(-)
 delete mode 100644 test/unit/test_weak.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 001d2ffc7a..37ca1b3111 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -14,10 +14,10 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v3
-    - name: Set up Python 3.8
+    - name: Set up Python 3.11
       uses: actions/setup-python@v4
       with:
-        python-version: 3.8
+        python-version: 3.11
     - name: Cache pip
       uses: actions/cache@v3
       with:
@@ -48,10 +48,10 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v3
-    - name: Set up Python 3.8
+    - name: Set up Python 3.11
       uses: actions/setup-python@v4
       with:
-        python-version: 3.8
+        python-version: 3.11
     - name: Cache pip
       uses: actions/cache@v3
       with:
@@ -84,10 +84,10 @@ jobs:
     steps:
     - name: Checkout Code
      uses: actions/checkout@v3
-    - name: Set up Python 3.8
+    - name: Set up Python 3.11
       uses: actions/setup-python@v4
       with:
-        python-version: 3.8
+        python-version: 3.11
     - name: Cache pip
       uses: actions/cache@v3
       with:
@@ -121,10 +121,10 @@ jobs:
         echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list
         sudo apt update
         sudo apt install -y --no-install-recommends intel-oneapi-runtime-compilers intel-oneapi-runtime-opencl
-    - name: Set up Python 3.8
+    - name: Set up Python 3.11
       uses: actions/setup-python@v4
       with:
-        python-version: 3.8
+        python-version: 3.11
     - name: Cache pip
       uses: actions/cache@v3
       with:
@@ -209,10 +209,10 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v3
-    - name: Set up Python 3.8
+    - name: Set up Python 3.11
       uses: actions/setup-python@v4
       with:
-        python-version: 3.8
+        python-version: 3.11
     - name: Cache pip
       uses: actions/cache@v3
       with:
@@ -279,10 +279,10 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v3
-    - name: Set up Python 3.8
+    - name: Set up Python 3.11
       uses: actions/setup-python@v4
       with:
-        python-version: 3.8
+        python-version: 3.11
     - name: Cache pip
       uses: actions/cache@v3
       with:
diff --git a/test/unit/test_weak.py b/test/unit/test_weak.py
deleted file mode 100644
index 3251257d8a..0000000000
--- a/test/unit/test_weak.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from tinygrad.helpers import LightWeakSet, LightWeakValueDictionary
-import unittest
-import time
-
-CNT = 1000
-
-cnt = 0
-class MyObject:
-  def __init__(self):
-    global cnt
-    self.cnt = cnt
-    cnt += 1
-    #print(f"object {self.cnt} created")
-  #def __del__(self): print(f"object {self.cnt} destroyed")
-
-class TestWeak(unittest.TestCase):
-  def test_set_drops(self):
-    ss = LightWeakSet()
-    ss.add(MyObject())
-    assert len(ss) == 0
-
-  def test_set_holds(self):
-    ss = LightWeakSet()
-    obj = MyObject()
-    ss.add(obj)
-    assert len(ss) == 1
-
-  def test_set_late_drops(self):
-    ss = LightWeakSet()
-    obj = MyObject()
-    ss.add(obj)
-    assert len(ss) == 1
-    del obj
-    assert len(ss) == 0
-
-  def test_dict_drops(self):
-    dd = LightWeakValueDictionary()
-    dd[0] = MyObject()
-    assert 0 not in dd
-
-  def test_dict_holds(self):
-    dd = LightWeakValueDictionary()
-    dd[0] = ret = MyObject()
-    assert 0 in dd
-
-  def test_a_myobj_microbench(self):
-    for _ in range(3):
-      st = time.perf_counter_ns()
-      for _ in range(CNT):
-        obj = MyObject()
-      et = (time.perf_counter_ns() - st)/CNT
-      print(f"{et:.2f} ns to create MyObject")
-
-  def test_set_add_microbench(self):
-    for _ in range(3):
-      ss = LightWeakSet()
-      st = time.perf_counter_ns()
-      for _ in range(CNT):
-        obj = MyObject()
-        ss.add(obj)
-        assert len(ss) == 1
-      et = (time.perf_counter_ns() - st)/CNT
-      print(f"{et:.2f} ns to add to LightWeakSet")
-
-  def test_set_del_microbench(self):
-    for _ in range(3):
-      ss = LightWeakSet()
-      st = time.perf_counter_ns()
-      for _ in range(CNT):
-        obj = MyObject()
-        ss.add(obj)
-        ss.discard(obj)
-        assert len(ss) == 0
-      et = (time.perf_counter_ns() - st)/CNT
-      print(f"{et:.2f} ns to add/del from LightWeakSet")
-
-  def test_dict_add_microbench(self):
-    for _ in range(3):
-      dd = LightWeakValueDictionary()
-      st = time.perf_counter_ns()
-      for i in range(CNT):
-        obj = MyObject()
-        dd[i] = obj
-        assert len(dd) == 1
-      et = (time.perf_counter_ns() - st)/CNT
-      print(f"{et:.2f} ns to add to LightWeakDict")
-
-  def test_dict_check_microbench(self):
-    for _ in range(3):
-      dd = LightWeakValueDictionary()
-      st = time.perf_counter_ns()
-      for i in range(CNT):
-        obj = MyObject()
-        dd[i] = obj
-        assert i in dd
-        tst = dd[i]
-        del obj,tst
-        assert len(dd) == 0
-      et = (time.perf_counter_ns() - st)/CNT
-      print(f"{et:.2f} ns to add/del from LightWeakDict")
-
-if __name__ == '__main__':
-  unittest.main()
\ No newline at end of file
diff --git a/tinygrad/codegen/assembly.py b/tinygrad/codegen/assembly.py
index 5ab2179359..666a297343 100644
--- a/tinygrad/codegen/assembly.py
+++ b/tinygrad/codegen/assembly.py
@@ -184,8 +184,6 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
     elif uop == UOps.STORE:
       idx, treg, off = lang.addr_w_offset(args)
       lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
-  # define registers
-  lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, lang.type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
 
   if DEBUG >= 4:
     for tins in lang.ins: print(tins)
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 0da9ee00d5..09687ff8c8 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -13,8 +13,14 @@ from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, s
 VariableOrNum = Union[Variable, NumNode, Node]
 
 # bottom ones are asm only
-class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); DEFINE_GLOBAL = auto(); LOAD = auto(); ALU = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); BARRIER = auto(); WMMA = auto(); \
-  SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
+class UOps(Enum):
+  LOOP = auto(); ENDLOOP = auto() # loops can be global, local, or other # noqa: E702
+  DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto() # this defines buffers # noqa: E702
+  LOAD = auto(); STORE = auto(); BARRIER = auto() # noqa: E702
+  ALU = auto(); WMMA = auto(); CAST = auto() # noqa: E702
+  # TODO: add CONST. use ALU WHERE for gated load
+  # *** assembly only UOps ***
+  SPECIAL = auto(); LABEL = auto(); COND_BRANCH = auto() # TODO: replace these with LOOP and ENDLOOP # noqa: E702
 
 def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
   idy = (idxy//(4*base_shape[1]))
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index a54304cef3..c25bfc6a23 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib
-from weakref import KeyedRef, ref
-from _weakref import _remove_dead_weakref # type: ignore
 import numpy as np
 from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
 from math import prod # noqa: F401 # pylint:disable=unused-import
@@ -42,7 +40,6 @@ class Context(contextlib.ContextDecorator):
 
 class ContextVar:
   _cache: ClassVar[Dict[str, ContextVar]] = {}
-  __slots__ = "value"
   value: int
   def __new__(cls, key, default_value):
     if key in ContextVar._cache: return ContextVar._cache[key]
@@ -134,37 +131,3 @@ class GlobalCounters:
   cache: ClassVar[Optional[List[Tuple[Callable, Any, Dict[Any, int]]]]] = None # List[Tuple[Callable, List[RawBuffer], Dict[Variable, int]]]
   @staticmethod
   def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
-
-# Stripped down version of a WeakSet
-class LightWeakSet:
-  __slots__ = 'data', '_remove', '__weakref__'
-  def __init__(self):
-    self.data = set()
-    def _remove(item, selfref=ref(self)):
-      self = selfref()
-      if self: self.data.discard(item)
-    self._remove = _remove
-
-  def __len__(self): return len(self.data)
-  def add(self, item): self.data.add(ref(item, self._remove))
-  def discard(self, item): self.data.discard(ref(item))
-
-# Stripped down version of a WeakValueDictionary
-class LightWeakValueDictionary:
-  __slots__ = 'data', '_remove', '__weakref__'
-  def __init__(self):
-    def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
-      self = selfref()
-      if self: _atomic_removal(self.data, wr.key)
-    self._remove = remove
-    self.data = {}
-
-  def __getitem__(self, key):
-    o = self.data[key]()
-    if o is None: raise KeyError(key)
-    else: return o
-
-  def __len__(self): return len(self.data)
-  def __delitem__(self, key): del self.data[key]
-  def __setitem__(self, key, value): self.data[key] = KeyedRef(value, self._remove, key)
-  def __contains__(self, key): return key in self.data
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index ceb4d4e12d..1e2dcbc470 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -2,10 +2,10 @@ from __future__ import annotations
 import operator, math
 from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
 import sys, importlib, inspect, functools, pathlib
-from weakref import ref
+from weakref import ref, WeakSet, WeakValueDictionary
 import numpy as np
 
-from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary
+from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
 from tinygrad.runtime.ops_disk import RawDiskBuffer
 from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction
@@ -94,7 +94,7 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(
 def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
 def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
 
-lazycache: LightWeakValueDictionary = LightWeakValueDictionary()
+lazycache: WeakValueDictionary = WeakValueDictionary()
 def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType):
   # fromcpu aren't cached
   if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype)
@@ -109,7 +109,6 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
   return ret
 
 class LazyBuffer:
-  __slots__ = 'st', 'device', 'shape', 'optype', 'dtype', 'op', 'realized', 'output_buffer', 'children', 'node_id', '__weakref__'
   __deletable__ = ('op',)
   def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
     self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
@@ -117,7 +116,7 @@ class LazyBuffer:
     self.realized: Optional[RawBuffer] = src
     self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
     # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
-    self.children: LightWeakSet = LightWeakSet()
+    self.children: WeakSet = WeakSet()
     # NOTE: op should be read only after construction of LazyBuffer
     self.op: LazyOp = op
     for x in op.buffers: x.children.add(self)
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index a6b278c0f3..50d8104683 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -10,7 +10,6 @@ class Contiguous(Function):
   def backward(self, grad_output): return grad_output
 
 class Cast(Function):
-  __slots__ = "input_dtype", "bitcast"
   def forward(self, x:LazyBuffer, dtype:DType, bitcast=False):
     self.input_dtype, self.bitcast = x.dtype, bitcast
     return x.cast((dtype, bitcast))
@@ -20,7 +19,6 @@ class Cast(Function):
 # ************* unary ops *************
 
 class Sin(Function):
-  __slots__ = "x"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.unary_op(UnaryOps.SIN)
@@ -29,7 +27,6 @@ class Sin(Function):
 
 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
-  __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.binary_op(BinaryOps.MAX, 0)
     return self.ret
@@ -38,7 +35,6 @@ class Relu(Function):
     return (0 < self.ret) * grad_output
 
 class Log(Function):
-  __slots__ = "x"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.unary_op(UnaryOps.LOG2) * math.log(2)
@@ -47,7 +43,6 @@ class Log(Function):
     return grad_output / self.x
 
 class Exp(Function):
-  __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = (x * (1/math.log(2))).unary_op(UnaryOps.EXP2)
     return self.ret
@@ -56,7 +51,6 @@ class Exp(Function):
     return self.ret * grad_output
 
 class Sqrt(Function):
-  __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.unary_op(UnaryOps.SQRT)
     return self.ret
@@ -68,7 +62,6 @@ class Sqrt(Function):
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
-  __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = 1 / (1 + (x * (-1/math.log(2))).unary_op(UnaryOps.EXP2))
     return self.ret
@@ -79,7 +72,6 @@ class Sigmoid(Function):
 # ************* reduce ops *************
 
 class Sum(Function):
-  __slots__ = "input_shape"
   def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reduce_op(ReduceOps.SUM, new_shape)
@@ -88,7 +80,6 @@ class Sum(Function):
     return grad_output.expand(self.input_shape)
 
 class Max(Function):
-  __slots__ = "x", "ret"
   def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
     self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
     return self.ret
@@ -122,7 +113,6 @@ class Sub(Function):
            -grad_output if self.needs_input_grad[1] else None
 
 class Mul(Function):
-  __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
     return x * y
@@ -132,7 +122,6 @@ class Mul(Function):
            self.x * grad_output if self.needs_input_grad[1] else None
 
 class Div(Function):
-  __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
     return x / y
@@ -144,7 +133,6 @@ class Div(Function):
 
 # ************* ternary ops *************
 
 class Where(Function):
-  __slots__ = "x"
   def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.ternary_op(TernaryOps.WHERE, y, z)
@@ -158,7 +146,6 @@ class Where(Function):
 
 # NOTE: this is sum in reverse
 class Expand(Function):
-  __slots__ = 'input_shape'
   def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
     self.input_shape = x.shape
     return x.expand(shape)
@@ -167,7 +154,6 @@ class Expand(Function):
     return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
 
 class Reshape(Function):
-  __slots__ = 'input_shape'
   def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reshape(shape)
@@ -176,7 +162,6 @@ class Reshape(Function):
     return grad_output.reshape(self.input_shape)
 
 class Permute(Function):
-  __slots__ = 'input_order'
   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
     self.input_order = order
     return x.permute(order)
@@ -185,7 +170,6 @@ class Permute(Function):
     return grad_output.permute(argsort(self.input_order))
 
 class Pad(Function):
-  __slots__ = 'narg'
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
     return x.pad(arg)
@@ -194,7 +178,6 @@ class Pad(Function):
     return grad_output.shrink(self.narg)
 
 class Shrink(Function):
-  __slots__ = 'narg'
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)
@@ -203,7 +186,6 @@ class Shrink(Function):
     return grad_output.pad(self.narg)
 
 class Flip(Function):
-  __slots__ = 'arg'
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]):
     self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
     return x.stride(self.arg)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 99cdc55840..acc6830fdd 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -23,7 +23,6 @@ Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps]
 OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps]]
 
 class LazyOp:
-  # TODO: add dest to support multiple outputs. on second thought, multiple outputs will have multiple LazyOps.
   __slots__ = "op", "src", "arg", "buffers", "__weakref__"
   op: Op
   src: Tuple[Union[LazyOp, LazyBuffer], ...]
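
Both stdlib containers this patch switches to keep the contract the deleted LightWeakSet/LightWeakValueDictionary implemented: an entry disappears as soon as the last strong reference to the stored object dies. Below is a minimal standalone sketch of that behavior as lazycache and LazyBuffer.children rely on it; the Node class is a hypothetical stand-in for LazyBuffer, not tinygrad code.

# sketch only: stdlib weakref containers, assuming CPython refcounting
# so weakref callbacks fire as soon as the last strong reference dies
from weakref import WeakSet, WeakValueDictionary

class Node: pass  # hypothetical stand-in for LazyBuffer; must be weak-referenceable

cache: WeakValueDictionary = WeakValueDictionary()  # plays the role of lazycache
children: WeakSet = WeakSet()                       # plays the role of LazyBuffer.children

n = Node()
cache["key"] = n
children.add(n)
assert "key" in cache and len(children) == 1

del n  # last strong reference gone: both containers drop the entry
assert "key" not in cache and len(children) == 0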
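
On the "no slots" and "add critical slots back" bullets: __slots__ trades flexibility for lower per-instance memory and slightly faster attribute access, and a slotted class only supports weak references when '__weakref__' is listed in __slots__. That is why the __slots__ kept on LazyOp above still names '__weakref__', while the Function subclasses in mlops.py drop __slots__ entirely. A standalone sketch with hypothetical classes, not tinygrad code:

import weakref

class Plain: pass  # has a per-instance __dict__, so any attribute works

class Slotted:
  __slots__ = ("op", "__weakref__")  # fixed attributes; '__weakref__' keeps weakref support

p, s = Plain(), Slotted()
p.anything = 1  # fine: lands in p.__dict__
weakref.ref(s)  # fine: the '__weakref__' slot is present
try: s.anything = 1
except AttributeError: pass  # rejected: no such slot and no __dict__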