support same uidx in multiple shape positions (#3205)

* support same uidx in multiple shape positions

* rename var

* update comment

* add contiguous index check to global_store too

* update comment

* small change

* is this better?

* smh

* smaller change?

* get rid of more changes

* get rid of more changes

* is this even making anything better

* comment

* fix test

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
David Hou
2024-02-21 10:37:03 -08:00
committed by GitHub
parent 1eb24af63b
commit f513c37e64
2 changed files with 20 additions and 6 deletions

test/test_linearizer.py

@@ -2,7 +2,7 @@ import numpy as np
 import unittest
 from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
-from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node
+from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
 from tinygrad.device import Compiled, Device, Buffer
 from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -699,5 +699,16 @@ class TestLinearizerHelper(unittest.TestCase):
     # this behavior was just copied from before, no idea why this should be true
     assert expand_node(s1, (a, b)) == [NumNode(x + y) for x in range(b.min, b.max + 1) for y in range(a.min, a.max + 1)]
 
+  def test_expand_nonpresent_var(self):
+    a = Variable("a", 1, 3)
+    n = NumNode(3) * Variable("b", 1, 3)
+    assert expand_node(n, (a,)) == [n, n, n]
+
+  def test_expand_idxs(self):
+    uidx0 = Variable("_uidx0", 0, 6)
+    uidx1 = Variable("_uidx1", 0, 1)
+    idxs = (uidx0 // 5, uidx0 * 5, uidx1)
+    assert expand_idxs(idxs) == (uidx0, NumNode(0), uidx1)
+
 if __name__ == '__main__':
   unittest.main()
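
A minimal usage sketch of the new helper outside the test suite (assuming a tinygrad checkout at this commit; the variable names and ranges below are illustrative, not taken from the diff):

from tinygrad.shape.symbolic import Variable, NumNode
from tinygrad.codegen.linearizer import expand_idxs, expand_node

# both index expressions are driven by the same expand variable _uidx0
uidx0 = Variable("_uidx0", 0, 6)
idxs = (uidx0 // 5, uidx0 * 5)

# expand_idxs keeps only the first occurrence of each expand variable,
# so the second slot collapses to NumNode(0)
evars = expand_idxs(idxs)
assert evars == (uidx0, NumNode(0))

# expanding every index over the deduplicated variables keeps the two shape
# positions in lockstep: 7 index pairs (one per value of _uidx0), not 7*7
pairs = list(zip(*[expand_node(idx, evars) for idx in idxs]))
assert len(pairs) == 7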

tinygrad/codegen/linearizer.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Set, Iterator
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Set, Iterator, Sequence
 import itertools, math, functools
 from collections import defaultdict
@@ -25,6 +25,9 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
   return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
 
 def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
+def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
+  eidxs = [expand_idx(node) for node in nodes]
+  return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable
 def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
   yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
@@ -73,7 +76,7 @@ class Linearizer(Kernel):
     localtype = self.get_base_dtype(buf.dtype if acc is None else get_lazyop_info(self.reduceop).dtype)
     const = buf.val if isinstance(buf, ConstBuffer) else acc
 
-    expand_vars = tuple([expand_idx(idx) for j, idx in enumerate(idxs)])
+    expand_vars = expand_idxs(idxs)
 
     dim, amt = None, 1
     # float 4 grouping
@@ -131,12 +134,12 @@ class Linearizer(Kernel):
       buf_uop = self.buf_uops[i]
       assert buf_uop is not None, f"buffer {i} wasn't UOped"
 
-      expanded_nodes = [expand_node(idx) for idx in idxs]
-      _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
+      expand_vars = expand_idxs(idxs)
+      _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()]  # transpose
       store_offset = dict(zip(_idxs, store))
 
       # float4 grouping
-      if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expanded_nodes[upcast_dim[0]]) in [2,4]:
+      if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
         grouped_store_offset = defaultdict(list)
         for k in store_offset:
           _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
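
For intuition on the global_store change above, a plain-Python sketch (with hypothetical index values, not taken from the diff) of why the transpose replaces the old Cartesian product when one expand variable drives two shape positions:

import itertools

# expansions of (_uidx0 // 3, _uidx0 * 3) for a hypothetical _uidx0 in 0..5
expanded_nodes = [[0, 0, 0, 1, 1, 1], [0, 3, 6, 9, 12, 15]]

# before: the full Cartesian product gives 36 index tuples, even though the
# two positions can only take 6 joint values
old_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
assert len(old_idxs) == 36

# after: expand each index once per value of _uidx0 and transpose, so both
# positions move together and line up one-to-one with the stored values
new_idxs = list(zip(*expanded_nodes))
assert new_idxs == [(0, 0), (0, 3), (0, 6), (1, 9), (1, 12), (1, 15)]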