From 2d4f01fda02fbfb162a822cb53494e4635f5b6f8 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 5 Nov 2025 10:18:33 -0800
Subject: [PATCH] move mixins to mixin dir (#13105)

* move mixins to mixin dir

* math
---
 extra/gemm/simple_matmul.py               |  2 +-
 setup.py                                  |  1 +
 test/test_uops.py                         |  2 +-
 tinygrad/mixin/__init__.py                |  4 ++
 tinygrad/{uop/mixins.py => mixin/math.py} | 80 +----------------------
 tinygrad/mixin/movement.py                | 80 +++++++++++++++++++++++
 tinygrad/tensor.py                        |  4 +-
 tinygrad/uop/ops.py                       |  6 +-
 8 files changed, 93 insertions(+), 86 deletions(-)
 create mode 100644 tinygrad/mixin/__init__.py
 rename tinygrad/{uop/mixins.py => mixin/math.py} (73%)
 create mode 100644 tinygrad/mixin/movement.py

diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py
index 5a9f2da940..45a359be38 100644
--- a/extra/gemm/simple_matmul.py
+++ b/extra/gemm/simple_matmul.py
@@ -17,7 +17,7 @@ M = getenv("M", N)
 K = getenv("K", N)
 CNT = getenv("CNT", 10)
 
-atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
+atol, rtol = {dtypes.half:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
 ATOL, RTOL = getenv("ATOL", atol), getenv("RTOL", rtol)
 INT_LOW = getenv("INT_LOW", 0)
 
diff --git a/setup.py b/setup.py
index 2624d21c34..412209a8de 100644
--- a/setup.py
+++ b/setup.py
@@ -32,6 +32,7 @@ setup(name='tinygrad',
                   'tinygrad.codegen.opt',
                   'tinygrad.codegen.late',
                   'tinygrad.engine',
+                  'tinygrad.mixin',
                   'tinygrad.nn',
                   'tinygrad.renderer',
                   'tinygrad.runtime',
diff --git a/test/test_uops.py b/test/test_uops.py
index b29137a015..64fd5bee6f 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -517,7 +517,7 @@ class TestUOpStr(unittest.TestCase):
 
 class TestUPatHelpers(unittest.TestCase):
   def test_location(self):
-    self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "mixins.py")
+    self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "math.py")
     self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
     test_upat = UPat(Ops.CONST, dtypes.bool)
     self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py
new file mode 100644
index 0000000000..d33a9eb479
--- /dev/null
+++ b/tinygrad/mixin/__init__.py
@@ -0,0 +1,4 @@
+from tinygrad.mixin.math import MathMixin
+from tinygrad.mixin.movement import MovementMixin
+
+class OpMixin(MathMixin, MovementMixin): pass
\ No newline at end of file
diff --git a/tinygrad/uop/mixins.py b/tinygrad/mixin/math.py
similarity index 73%
rename from tinygrad/uop/mixins.py
rename to tinygrad/mixin/math.py
index e2279146a6..10cfa3a5b5 100644
--- a/tinygrad/uop/mixins.py
+++ b/tinygrad/mixin/math.py
@@ -1,11 +1,6 @@
-# mixins add syntactic sugar to Tensor and UOp
-from typing import TypeAlias, TYPE_CHECKING, Self
+from typing import Self
 from tinygrad.uop import Ops
 from tinygrad.dtype import dtypes, ConstType
-from tinygrad.helpers import prod, argfix
-if TYPE_CHECKING:
-  from tinygrad.uop.ops import UOp
-  sint:TypeAlias = UOp|int
 
 class MathMixin:
   # required to implement
@@ -175,76 +170,3 @@ class MathMixin:
   def exp2(self): return self.alu(Ops.EXP2)
   def pow(self, x:Self|ConstType): return self.alu(Ops.POW, self.ufix(x))
   def __pow__(self, x:Self|ConstType): return self.pow(x)
-
-class MovementMixin:
-  # required to implement
-  def _mop(self, op:Ops, arg) -> Self: raise NotImplementedError
-  @property
-  def shape(self) -> tuple["sint", ...]: raise NotImplementedError
-
-  # great functions you get!
-  @property
-  def ndim(self) -> int:
-    """
-    Returns the number of dimensions in the tensor.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([[1, 2], [3, 4]])
-    print(t.ndim)
-    ```
-    """
-    return len(self.shape)
-
-  def numel(self) -> "sint":
-    """
-    Returns the total number of elements in the tensor.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
-    print(t.numel())
-    ```
-    """
-    return prod(self.shape)
-
-  def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
-    total = self.ndim + int(extra)
-    if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
-    return dim + total if dim < 0 else dim
-
-  def view(self, shape, *args) -> Self:
-    """`.view` is an alias for `.reshape`."""
-    return self.reshape(shape, *args)
-
-  def reshape(self, shape, *args) -> Self:
-    """
-    Returns a tensor with the same data as the original tensor but with a different shape.
-    `shape` can be passed as a tuple or as separate arguments.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(6)
-    print(t.reshape(2, 3).numpy())
-    ```
-    """
-    # resolve None and args
-    new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
-    # resolve -1
-    if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
-    if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
-    if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
-    return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
-
-  def flatten(self, start_dim=0, end_dim=-1) -> Self:
-    """
-    Flattens the tensor by reshaping it into a one-dimensional tensor.
-    If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(8).reshape(2, 2, 2)
-    print(t.flatten().numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.flatten(start_dim=1).numpy())
-    ```
-    """
-    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
-    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
\ No newline at end of file
diff --git a/tinygrad/mixin/movement.py b/tinygrad/mixin/movement.py
new file mode 100644
index 0000000000..c6b9fba19b
--- /dev/null
+++ b/tinygrad/mixin/movement.py
@@ -0,0 +1,80 @@
+# mixins add syntactic sugar to Tensor and UOp
+from typing import TypeAlias, TYPE_CHECKING, Self
+from tinygrad.uop import Ops
+from tinygrad.helpers import prod, argfix
+if TYPE_CHECKING:
+  from tinygrad.uop.ops import UOp
+  sint:TypeAlias = UOp|int
+
+class MovementMixin:
+  # required to implement
+  def _mop(self, op:Ops, arg) -> Self: raise NotImplementedError
+  @property
+  def shape(self) -> tuple["sint", ...]: raise NotImplementedError
+
+  # great functions you get!
+  @property
+  def ndim(self) -> int:
+    """
+    Returns the number of dimensions in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2], [3, 4]])
+    print(t.ndim)
+    ```
+    """
+    return len(self.shape)
+
+  def numel(self) -> "sint":
+    """
+    Returns the total number of elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    print(t.numel())
+    ```
+    """
+    return prod(self.shape)
+
+  def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
+    total = self.ndim + int(extra)
+    if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
+    return dim + total if dim < 0 else dim
+
+  def view(self, shape, *args) -> Self:
+    """`.view` is an alias for `.reshape`."""
+    return self.reshape(shape, *args)
+
+  def reshape(self, shape, *args) -> Self:
+    """
+    Returns a tensor with the same data as the original tensor but with a different shape.
+    `shape` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6)
+    print(t.reshape(2, 3).numpy())
+    ```
+    """
+    # resolve None and args
+    new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
+    # resolve -1
+    if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
+    if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
+    if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
+    return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
+
+  def flatten(self, start_dim=0, end_dim=-1) -> Self:
+    """
+    Flattens the tensor by reshaping it into a one-dimensional tensor.
+    If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(8).reshape(2, 2, 2)
+    print(t.flatten().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flatten(start_dim=1).numpy())
+    ```
+    """
+    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
+    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
\ No newline at end of file
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d8c0476d0b..c49c510f21 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -9,7 +9,7 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC
 from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
-from tinygrad.uop.mixins import MathMixin, MovementMixin
+from tinygrad.mixin import OpMixin
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Device, Buffer
@@ -100,7 +100,7 @@ def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: r
 
 ReductionStr = Literal["mean", "sum", "none"]
 
-class Tensor(MathMixin, MovementMixin):
+class Tensor(OpMixin):
   """
   A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
 
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 34f48e0750..bee1477955 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -4,7 +4,7 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
 from dataclasses import dataclass
 from enum import Enum, auto
 from tinygrad.uop import Ops, GroupOp
-from tinygrad.uop.mixins import MathMixin, MovementMixin
+from tinygrad.mixin import OpMixin
 from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
 from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
@@ -104,7 +104,7 @@ class recursive_property(property):
 
 # NOTE: this should be frozen, but frozen is slower
 @dataclass(eq=False, slots=True)
-class UOp(MathMixin, MovementMixin, metaclass=UOpMetaClass):
+class UOp(OpMixin, metaclass=UOpMetaClass):
   op:Ops
   dtype:DType = dtypes.void
   src:tuple[UOp, ...] = tuple()
@@ -867,7 +867,7 @@ def printable(loc:tuple[str, int]) -> str:
   try: return lines(loc[0])[loc[1]-1].strip()
   except FileNotFoundError: return ""
 
-class UPat(MathMixin, MovementMixin):
+class UPat(OpMixin):
   __slots__ = ("op", "dtype", "arg", "name", "src")
   def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|None=None,
                src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
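For reviewers, a minimal sketch (not part of the patch) of what the new layout buys: any class that supplies `MovementMixin`'s two required members, `_mop` and `shape`, inherits `ndim`, `numel`, `view`, `reshape`, and `flatten` from `tinygrad/mixin/movement.py`. The `ShapeOnly` class below is hypothetical, invented purely for illustration; `Tensor` and `UOp` are the real consumers.

```python
from tinygrad.uop import Ops
from tinygrad.mixin.movement import MovementMixin

class ShapeOnly(MovementMixin):
  # hypothetical carrier class: tracks only a shape, no data
  def __init__(self, shape): self._shape = tuple(shape)
  @property
  def shape(self) -> tuple[int, ...]: return self._shape  # required member #1
  def _mop(self, op:Ops, arg) -> "ShapeOnly":             # required member #2
    assert op is Ops.RESHAPE  # reshape/view/flatten all bottom out in one RESHAPE mop
    return ShapeOnly(arg)

t = ShapeOnly((2, 2, 2))
print(t.ndim, t.numel())             # 3 8
print(t.reshape(-1, 4).shape)        # (2, 4): the -1 is inferred
print(t.flatten(start_dim=1).shape)  # (2, 4)
```

Since `Tensor`, `UOp`, and `UPat` now all derive from `OpMixin`, the single import `from tinygrad.mixin import OpMixin` replaces the old `MathMixin, MovementMixin` pair at every use site.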
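The math side composes the same way. A second hypothetical sketch: the tail of `mixin/math.py` visible in this diff builds `exp2` and `pow` out of `self.alu(...)` and `self.ufix(...)`, so a class providing those two gets the sugar; the `alu(op, *src)` signature here is an assumption inferred from the two calls shown, not confirmed by the patch.

```python
from tinygrad.uop import Ops
from tinygrad.mixin.math import MathMixin

class Expr(MathMixin):
  # hypothetical symbolic-expression class, for illustration only
  def __init__(self, desc:str): self.desc = desc
  def alu(self, op:Ops, *src:"Expr") -> "Expr":  # assumed signature, see lead-in
    return Expr(f"{op.name}({', '.join([self.desc] + [s.desc for s in src])})")
  def ufix(self, x) -> "Expr":  # lift a Python constant into an Expr
    return x if isinstance(x, Expr) else Expr(repr(x))

print(Expr("x").exp2().desc)     # EXP2(x)
print((Expr("x") ** 2.0).desc)   # POW(x, 2.0), via the inherited __pow__
```

Because the two mixins require disjoint members (`alu`/`ufix` here, `_mop`/`shape` for movement), `class OpMixin(MathMixin, MovementMixin): pass` in `tinygrad/mixin/__init__.py` can compose them with a bare `pass`.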