support same uidx in multiple shape positions (#3205)
* support same uidx in multiple shape positions

* rename var

* update comment

* add contiguous index check to global_store too

* update comment

* small change

* is this better?

* smh

* smaller change?

* get rid of more changes

* get rid of more changes

* is this even making anything better

* comment

* fix test

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
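Editor's note (not part of the commit): the gist of the change is that when the same unrolled index variable (`_uidx`) appears in more than one shape position, only its first occurrence should drive the expansion; later occurrences are zeroed out so the product over expand variables is not enumerated twice. A minimal standalone sketch of that dedup step, using strings and ints as hypothetical stand-ins for tinygrad's Variable/NumNode:

def dedup_expand_vars(eidxs):
  # keep the first occurrence of each expand variable, zero out repeats
  return tuple(v if v not in eidxs[:j] else 0 for j, v in enumerate(eidxs))

assert dedup_expand_vars(("_uidx0", "_uidx0", "_uidx1")) == ("_uidx0", 0, "_uidx1")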
@@ -2,7 +2,7 @@ import numpy as np
 import unittest
 from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
-from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node
+from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
 from tinygrad.device import Compiled, Device, Buffer
 from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -699,5 +699,16 @@ class TestLinearizerHelper(unittest.TestCase):
     # this behavior was just copied from before, no idea why this should be true
     assert expand_node(s1, (a, b)) == [NumNode(x + y) for x in range(b.min, b.max + 1) for y in range(a.min, a.max + 1)]
 
+  def test_expand_nonpresent_var(self):
+    a = Variable("a", 1, 3)
+    n = NumNode(3) * Variable("b", 1, 3)
+    assert expand_node(n, (a,)) == [n, n, n]
+
+  def test_expand_idxs(self):
+    uidx0 = Variable("_uidx0", 0, 6)
+    uidx1 = Variable("_uidx1", 0, 1)
+    idxs = (uidx0 // 5, uidx0 * 5, uidx1)
+    assert expand_idxs(idxs) == (uidx0, NumNode(0), uidx1)
+
 if __name__ == '__main__':
   unittest.main()
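Editor's note (not part of the commit): in test_expand_idxs above, both uidx0 // 5 and uidx0 * 5 expand over _uidx0, so only the first tuple slot keeps the variable and the duplicate collapses to NumNode(0); iterating the deduped tuple then enumerates 7 * 1 * 2 = 14 index combinations rather than 7 * 7 * 2 = 98.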
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Set, Iterator
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Set, Iterator, Sequence
 import itertools, math, functools
 from collections import defaultdict
@@ -25,6 +25,9 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
   return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
 
 def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
+def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
+  eidxs = [expand_idx(node) for node in nodes]
+  return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)])  # take only first occurrence of expand variable
 def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
   yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
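Editor's note (not part of the commit): a rough sketch of how iter_idxs enumerates the deduped expand variables, using plain (min, max) pairs as hypothetical stand-ins for Variable/NumNode bounds:

import itertools

def iter_idxs_sketch(ranges):
  # ranges: one (min, max) pair per expand slot; a deduped repeat slot becomes (0, 0)
  # mirror iter_idxs: product over reversed slots, then reverse each tuple,
  # so the first slot varies fastest in the output
  yield from (x[::-1] for x in itertools.product(*[range(lo, hi + 1) for lo, hi in ranges[::-1]]))

# _uidx0 in 0..6, its repeat slot collapsed to (0, 0), _uidx1 in 0..1 -> 7 * 1 * 2 = 14 tuples
assert len(list(iter_idxs_sketch(((0, 6), (0, 0), (0, 1))))) == 14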
@@ -73,7 +76,7 @@ class Linearizer(Kernel):
     localtype = self.get_base_dtype(buf.dtype if acc is None else get_lazyop_info(self.reduceop).dtype)
     const = buf.val if isinstance(buf, ConstBuffer) else acc
 
-    expand_vars = tuple([expand_idx(idx) for j, idx in enumerate(idxs)])
+    expand_vars = expand_idxs(idxs)
 
     dim, amt = None, 1
     # float 4 grouping
@@ -131,12 +134,12 @@ class Linearizer(Kernel):
       buf_uop = self.buf_uops[i]
       assert buf_uop is not None, f"buffer {i} wasn't UOped"
 
-      expanded_nodes = [expand_node(idx) for idx in idxs]
-      _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
+      expand_vars = expand_idxs(idxs)
+      _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()]  # transpose
       store_offset = dict(zip(_idxs, store))
 
       # float4 grouping
-      if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expanded_nodes[upcast_dim[0]]) in [2,4]:
+      if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
         grouped_store_offset = defaultdict(list)
         for k in store_offset:
           _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
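Editor's note (not part of the commit): the zip(*...) in the new _idxs line is a plain transpose. Each expand_node(idx, expand_vars) call returns one concrete index per expansion step, so zipping the per-dimension lists yields one index tuple per store. A tiny illustration with stand-in integer values:

per_dim = [[0, 1, 2], [10, 11, 12]]   # stand-in for [expand_node(idx, expand_vars) for idx in idxs]
per_step = list(zip(*per_dim))        # transpose: one index tuple per expansion step
assert per_step == [(0, 10), (1, 11), (2, 12)]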