From c9a1e35b1e60c09e8ba26cb84d5b2e64f9368954 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 3 Nov 2025 12:00:45 +0800 Subject: [PATCH] slicing + allclose --- test/test_custom_kernel.py | 29 +++++++++++++++++---- tinygrad/codegen/__init__.py | 10 ++++++-- tinygrad/tensor.py | 50 +++++------------------------------- tinygrad/uop/mixins.py | 44 +++++++++++++++++++++++++++++++ tinygrad/uop/ops.py | 18 +++++++++++-- 5 files changed, 98 insertions(+), 53 deletions(-) diff --git a/test/test_custom_kernel.py b/test/test_custom_kernel.py index b779ab3868..32bd0366fd 100644 --- a/test/test_custom_kernel.py +++ b/test/test_custom_kernel.py @@ -1,5 +1,6 @@ import unittest from tinygrad import Tensor, UOp, Context +from tinygrad.dtype import AddrSpace from tinygrad.uop.ops import KernelInfo, AxisType # **** kernels **** @@ -9,15 +10,18 @@ def custom_arange_kernel(C:UOp) -> UOp: return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}")) def custom_add_one_kernel(B:UOp, A:UOp) -> UOp: + A,B = A.flatten(), B.flatten() assert B.size == A.size i = UOp.range(A.size, 0) return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.size}")) def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp) -> UOp: + C,A,B = C.flatten(), A.flatten(), B.flatten() i = UOp.range(C.size, 0) return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}")).simplify() def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp: + C,D,A,B = C.flatten(), D.flatten(), A.flatten(), B.flatten() assert C.size == D.size i = UOp.range(C.size, 0) store_c = C[i].store(A[i]+B[i]) @@ -39,13 +43,22 @@ def custom_sum(B:UOp, A:UOp) -> UOp: return B.sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=())) def flip_contract_kernel(dest:UOp, src:UOp): - assert dest.size%4 == 0 - i = UOp.range(dest.size//4, 0) - j = UOp.range(4, 1, AxisType.UPCAST) - vec = src[i*4+j].contract(j) - store = UOp.group(*[dest[i*4+k].store(vec.gep(3-k)) for k in range(4)]) + i = UOp.range(dest.shape[0], 0) + j = UOp.range(dest.shape[1], 1, AxisType.UPCAST) + vec = src[i, j].contract(j) + store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)]) return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=())) +def slice_sum_kernel(dest:UOp, src:UOp): + G = UOp.range(src.shape[0], 0, AxisType.GLOBAL) + slice_src = src[G, :] + reg = UOp.placeholder((1,), dest.dtype.base, 0, addrspace=AddrSpace.REG) + reg = reg.after(G)[0].set(0) + R = UOp.range(src.shape[1], 1, AxisType.REDUCE) + reg = reg[0].set(reg[0] + slice_src[R], end=R) + ast = dest[G].set(reg[0], end=G) + return ast.sink(arg=KernelInfo(name=f"slice_sum_{src.shape[0]}_{src.shape[1]}", opts_to_apply=())) + # **** backward callbacks **** def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]: @@ -111,6 +124,12 @@ class TestCustomKernel(unittest.TestCase): b = Tensor.custom_kernel(tst, a, fxn=custom_sum)[0] self.assertEqual(b.item(), 15) + def test_slice_sum(self): + A = Tensor.randn(16, 16) + B = Tensor.empty(16) + B = Tensor.custom_kernel(B, A, fxn=slice_sum_kernel)[0] + self.assertTrue(B.allclose(A.sum(1))) + def test_gemm(self): N = 16 a = Tensor.randn(N, N) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 5d6b750cc9..b70aaf30fc 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -4,7 +4,7 @@ from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC from tinygrad.uop.ops import 
PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
 from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
 from tinygrad.renderer import Renderer
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.helpers import panic
 
 # import all pattern matchers here
@@ -19,13 +19,19 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
 
+pm_syntactic_sugar = PatternMatcher([
+  # an INDEX on a ptr INDEX concats the indices
+  (UPat(Ops.INDEX, name="i1").f(Ops.INDEX, name="i2", allow_any_len=True),
+   lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
+])
+
 def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
   if ren is None: ren = Renderer()
   if SPEC: type_verify(sink, kernel_spec)
 
   # preprocess
-  sink = graph_rewrite(sink, pm_mops, name="early movement ops")
+  sink = graph_rewrite(sink, pm_mops+pm_syntactic_sugar, name="early movement ops", bottom_up=True)
 
   # first we optimize
   if optimize:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 664198fbe9..d8c0476d0b 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1435,11 +1435,6 @@ class Tensor(MathMixin, MovementMixin):
     final_shape = [r*s for r,s in zip(repeats, base_shape)]
     return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
 
-  def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
-    total = self.ndim + int(extra)
-    if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
-    return dim + total if dim < 0 else dim
-
   def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
     """
     Splits the tensor into chunks along the dimension specified by `dim`.
@@ -1597,22 +1592,6 @@ class Tensor(MathMixin, MovementMixin):
     order[dim0], order[dim1] = order[dim1], order[dim0]
     return self.permute(order)
 
-  def flatten(self, start_dim=0, end_dim=-1) -> Tensor:
-    """
-    Flattens the tensor by reshaping it into a one-dimensional tensor.
-    If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(8).reshape(2, 2, 2)
-    print(t.flatten().numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.flatten(start_dim=1).numpy())
-    ```
-    """
-    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
-    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
-
   def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
     """
     Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
@@ -1927,6 +1906,16 @@ class Tensor(MathMixin, MovementMixin):
     is_nan_close = (self.isnan() & other.isnan()) & equal_nan
     return is_finite_close | is_infinite_close | is_nan_close
 
+  def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> bool:
+    """
+    Returns `True` if every element of `self` is close to the corresponding element of `other`, `False` otherwise (see `Tensor.isclose`).
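+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([1.0, 2.0, 3.0]).allclose(Tensor([1.0, 2.0, 3.0]) + 1e-9))
+    ```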
+ """ + return bool(self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all().item()) + def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Returns the mean value of the tensor along the specified axis or axes. @@ -4194,29 +4179,6 @@ class Tensor(MathMixin, MovementMixin): # ***** Tensor Properties ***** - @property - def ndim(self) -> int: - """ - Returns the number of dimensions in the tensor. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([[1, 2], [3, 4]]) - print(t.ndim) - ``` - """ - return len(self.shape) - - def numel(self) -> sint: - """ - Returns the total number of elements in the tensor. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) - print(t.numel()) - ``` - """ - return prod(self.shape) - def element_size(self) -> int: """ Returns the size in bytes of an individual element in the tensor. diff --git a/tinygrad/uop/mixins.py b/tinygrad/uop/mixins.py index 536a4a09ba..e2279146a6 100644 --- a/tinygrad/uop/mixins.py +++ b/tinygrad/uop/mixins.py @@ -183,6 +183,34 @@ class MovementMixin: def shape(self) -> tuple["sint", ...]: raise NotImplementedError # great functions you get! + @property + def ndim(self) -> int: + """ + Returns the number of dimensions in the tensor. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([[1, 2], [3, 4]]) + print(t.ndim) + ``` + """ + return len(self.shape) + + def numel(self) -> "sint": + """ + Returns the total number of elements in the tensor. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + print(t.numel()) + ``` + """ + return prod(self.shape) + + def _resolve_dim(self, dim:int, *, extra:bool=False) -> int: + total = self.ndim + int(extra) + if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}") + return dim + total if dim < 0 else dim + def view(self, shape, *args) -> Self: """`.view` is an alias for `.reshape`.""" return self.reshape(shape, *args) @@ -204,3 +232,19 @@ class MovementMixin: if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self + + def flatten(self, start_dim=0, end_dim=-1) -> Self: + """ + Flattens the tensor by reshaping it into a one-dimensional tensor. + If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened. 
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(8).reshape(2, 2, 2)
+    print(t.flatten().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flatten(start_dim=1).numpy())
+    ```
+    """
+    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
+    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 5452af7a13..6d5b42f6c1 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -187,10 +187,18 @@ class UOp(MathMixin, MovementMixin, metaclass=UOpMetaClass):
   def _shape(self) -> tuple[sint, ...]|None:
     match self.op:
       # late ops don't have shape
-      case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
+      case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
         Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT: return None
+      case Ops.INDEX:
+        # a non-pointer INDEX doesn't have a shape
+        if not isinstance(self.dtype, PtrDType): return None
+        # a fully indexed pointer doesn't have a shape. TODO: remove this
+        if len(self.src[1:]) == len(self.src[0].shape): return None
+        # a partially indexed pointer keeps the remaining (unindexed) dims as its shape
+        return self.src[0].shape[len(self.src[1:]):]
+
       # some ops init the shape
       case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
       case Ops.BUFFER: return (self.arg,)
@@ -344,7 +352,14 @@ class UOp(MathMixin, MovementMixin, metaclass=UOpMetaClass):
   def index(self, *srcs:UOp|None, ptr=False, **kwargs):
     return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
   def __getitem__(self, idx):
-    return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in argfix(idx)])
+    idx = argfix(idx)
+    assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
+    # slices keep whole axes (start/stop are ignored): permute them to the end, then pointer-index the leading dims
+    if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
+      perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
+      return perm.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
+    else:
+      return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx])
   def const_like(self, b:ConstLike):
     # constants can optionally have a DEVICE source
     return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
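
Appended sketch (not part of the diff): how the new UOp slicing and `Tensor.allclose` compose in a custom kernel. `row_max_kernel` is hypothetical and mirrors `slice_sum_kernel` above; it assumes `UOp.maximum` is available from the math mixin, as `Tensor.maximum` is.

```python
from tinygrad import Tensor, UOp
from tinygrad.dtype import AddrSpace
from tinygrad.uop.ops import KernelInfo, AxisType

def row_max_kernel(dest:UOp, src:UOp):
  G = UOp.range(src.shape[0], 0, AxisType.GLOBAL)
  row = src[G, :]   # slicing: a pointer INDEX with shape (src.shape[1],)
  reg = UOp.placeholder((1,), dest.dtype.base, 0, addrspace=AddrSpace.REG)
  reg = reg.after(G)[0].set(row[0])   # init the accumulator with the row's first element (no -inf needed; max is idempotent)
  R = UOp.range(src.shape[1], 1, AxisType.REDUCE)
  reg = reg[0].set(reg[0].maximum(row[R]), end=R)   # row[R] concats onto src[G, :] -> src[G, R] via pm_syntactic_sugar
  ast = dest[G].set(reg[0], end=G)
  return ast.sink(arg=KernelInfo(name=f"row_max_{src.shape[0]}_{src.shape[1]}", opts_to_apply=()))

A = Tensor.randn(16, 16)
B = Tensor.custom_kernel(Tensor.empty(16), A, fxn=row_max_kernel)[0]
assert B.allclose(A.max(1))   # allclose: the new bool-returning check from this patch
```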