mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
real strides with uops (#6365)
* real strides with uops [run_process_replay]
* compare with old
* Revert "compare with old"
This reverts commit f53a8d4276.
* make those @unittest.expectedFailure
This commit is contained in:
@@ -22,14 +22,17 @@ class TestSymbolic(unittest.TestCase):
|
||||
assert e1.render() == "((y*3)+x)"
|
||||
assert e2.render() == "1"
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_real_strides_0(self):
|
||||
st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
|
||||
self.assertEqual(st.real_strides(), (8, None))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_real_strides_1(self):
|
||||
st = ShapeTracker(views=(View(shape=(3, (NumNode(2)+Variable('i', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
|
||||
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_real_strides_2(self):
|
||||
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
|
||||
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast, Any
|
||||
from typing import Tuple, List, Optional, Dict, Set, Iterable, Any
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps
|
||||
from tinygrad.ops import graph_rewrite
|
||||
from tinygrad.codegen.uopgraph import constant_folder
|
||||
from tinygrad.codegen.uopgraph import constant_folder, _get_add_chain
|
||||
|
||||
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
|
||||
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x.render(render_ops, ctx)
|
||||
@@ -101,18 +101,18 @@ class ShapeTracker:
|
||||
# NOTE: if a stride is not always valid, it will be None
|
||||
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
|
||||
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
|
||||
idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
|
||||
idx, valid = self.expr_idxs(idxs)
|
||||
ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
|
||||
bad_idx_vars: Set[Variable] = set()
|
||||
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
|
||||
idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
|
||||
try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
|
||||
except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
|
||||
idx_vars, valid_vars = idx.vars(), valid.vars()
|
||||
for i,tidx in enumerate(idxs):
|
||||
if tidx in bad_idx_vars or (tidx in valid_vars and not ignore_valid): ret[i] = None
|
||||
elif tidx not in idx_vars: ret[i] = 0
|
||||
ret: List[Optional[sint]] = [None] * len(self.shape)
|
||||
idx, valid = self.to_indexed_uops()
|
||||
idx = graph_rewrite(idx, pm=constant_folder)
|
||||
for c in _get_add_chain(idx):
|
||||
if c.op is UOps.RANGE: ret[c.arg] = 1
|
||||
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg
|
||||
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg
|
||||
used_ranges = [x.arg for x in graph_rewrite(idx, pm=constant_folder).sparents if x.op is UOps.RANGE]
|
||||
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
|
||||
if not ignore_valid:
|
||||
masked_axis = [x.arg for x in graph_rewrite(valid, pm=constant_folder).sparents if x.op is UOps.RANGE]
|
||||
ret = [None if i in masked_axis else x for i,x in enumerate(ret)]
|
||||
return tuple(ret)
|
||||
|
||||
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
||||
|
||||
Reference in New Issue
Block a user