diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 8cb727a9cb..10222984e9 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -2,7 +2,7 @@ import numpy as np import unittest from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores -from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node +from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs from tinygrad.device import Compiled, Device, Buffer from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps from tinygrad.shape.shapetracker import ShapeTracker @@ -699,5 +699,16 @@ class TestLinearizerHelper(unittest.TestCase): # this behavior was just copied from before, no idea why this should be true assert expand_node(s1, (a, b)) == [NumNode(x + y) for x in range(b.min, b.max + 1) for y in range(a.min, a.max + 1)] + def test_expand_nonpresent_var(self): + a = Variable("a", 1, 3) + n = NumNode(3) * Variable("b", 1, 3) + assert expand_node(n, (a,)) == [n, n, n] + + def test_expand_idxs(self): + uidx0 = Variable("_uidx0", 0, 6) + uidx1 = Variable("_uidx1", 0, 1) + idxs = (uidx0 // 5, uidx0 * 5, uidx1) + assert expand_idxs(idxs) == (uidx0, NumNode(0), uidx1) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 68b80f02ac..6bb4b55fa5 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Set, Iterator +from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Set, Iterator, Sequence import itertools, math, functools from collections import defaultdict @@ -25,6 +25,9 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0): return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)] def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0)) +def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]: + eidxs = [expand_idx(node) for node in nodes] + return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]: yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]])) @@ -73,7 +76,7 @@ class Linearizer(Kernel): localtype = self.get_base_dtype(buf.dtype if acc is None else get_lazyop_info(self.reduceop).dtype) const = buf.val if isinstance(buf, ConstBuffer) else acc - expand_vars = tuple([expand_idx(idx) for j, idx in enumerate(idxs)]) + expand_vars = expand_idxs(idxs) dim, amt = None, 1 # float 4 grouping @@ -131,12 +134,12 @@ class Linearizer(Kernel): buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" - expanded_nodes = [expand_node(idx) for idx in idxs] - _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])] + expand_vars = expand_idxs(idxs) + _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()] # transpose store_offset = dict(zip(_idxs, store)) # float4 grouping - if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expanded_nodes[upcast_dim[0]]) in [2,4]: + if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]: grouped_store_offset = defaultdict(list) for k in store_offset: _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]