diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8c48585724..f82d878d49 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -130,8 +130,6 @@ jobs: run: | PYTHONPATH="." python test/external/fuzz_shapetracker.py PYTHONPATH="." python test/external/fuzz_shapetracker_math.py - - name: Test to_movement_ops - run: PYTHONPATH="." python extra/to_movement_ops.py - name: Use as an external package run: | mkdir $HOME/test_external_dir diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py index 2a7bb85ced..cb186a6d9b 100644 --- a/test/test_conv_shapetracker.py +++ b/test/test_conv_shapetracker.py @@ -51,10 +51,6 @@ class TestConvShapetracker(unittest.TestCase): print(i, i1, i2, si.inputs[0].size, i1==i2) #self.assertEqual(i1, i2) - for stt in [st, test_st]: - s,va = stt.expr_idxs() - print(s) - print(va) with self.assertRaises(AssertionError): assert len(st.views) <= 2 diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index afbcff30a3..3112e6fcb2 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -10,18 +10,6 @@ class TestSymbolic(unittest.TestCase): assert st.shape == (x, 3) assert st.real_strides() == (3, 1) - def test_expr_idxs(self): - x = Variable("x", 1, 100) - st = ShapeTracker.from_shape((x, 3)) - idxs = [Variable("x", 0, 100), Variable("y", 0, 100)] - e1, e2 = st.expr_idxs(idxs) - assert e1.render() == "((x*3)+y)" - assert e2.render() == "1" - st = st.permute((1, 0)) - e1, e2 = st.expr_idxs(idxs) - assert e1.render() == "((y*3)+x)" - assert e2.render() == "1" - @unittest.expectedFailure def test_real_strides_0(self): st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501 @@ -230,22 +218,6 @@ class TestSymbolicPad(unittest.TestCase): assert t.shape == (9,) st = t.lazydata.st print(st) - # TODO: fix this, required for symbolic arange - with self.assertRaises(RuntimeError): - st.expr_idxs() - -class TestSymbolicShapeExpr(unittest.TestCase): - def test_symbolic_expr_idxs(self): - # taken from symbolic shape llama - i = Variable("i", 1, 120) - gidx0 = Variable("gidx0", 0, i) - lidx1 = Variable("lidx1", 0, 7) - idx = (gidx0, lidx1, NumNode(1)) - shape = (i+1, 8, 4) - strides = (1, (i*4)+4, i+1) - st = ShapeTracker((View.create(shape, strides), )) - idx, _valid = st.expr_idxs(idx) - assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)" if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index a58bd5f61b..27f897b621 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -1,16 +1,19 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad.helpers import prod, DEBUG +from tinygrad.dtype import dtypes +from tinygrad.helpers import prod from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.shape.symbolic import Variable, NumNode +from tinygrad.ops import UOp, UOps, graph_rewrite +from tinygrad.codegen.uopgraph import constant_folder from itertools import product -def shapetracker_getitem(st, val): - _locals = {"idx0": val, "valid": 1} - idx, valid = st.reshape((st.size,)).expr_idxs() - exec(f"valid={valid.render()};idx0={idx.render()}", None, _locals) - return _locals["idx0"] if _locals["valid"] else -1 +def shapetracker_getitem(st:ShapeTracker, val:int): + idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.pyint, val)]) + idx, valid = graph_rewrite(idx, constant_folder), graph_rewrite(valid, constant_folder) + assert idx.op is UOps.CONST and valid.op is UOps.CONST + return idx.arg, valid.arg class CheckingShapeTracker: def __init__(self, shape): @@ -70,10 +73,8 @@ class CheckingShapeTracker: def contiguous(self): return self.st.contiguous def assert_same(self): - x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))] + x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))] y = [self[i] for i in range(prod(self.shape))] - idx, valid = self.st.expr_idxs() - if DEBUG >= 1: print(x, y, self.st.shape, self.shape, idx.render(), valid.render(), self.st) assert self.st.shape == self.shape assert x == y, f"mismatch shapetracker:{x} real:{y}" @@ -163,7 +164,6 @@ class TestIndexExpressions2d(unittest.TestCase): def tearDown(self): for st, offset, shape, idxs_expr in zip(self.sts, self.offset, self.shapes, self.idxs_exprs): numel = prod(shape) - assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0] self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel) idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)] idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)] @@ -171,7 +171,6 @@ class TestIndexExpressions2d(unittest.TestCase): (st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s] for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s): idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None] - assert idxs_expr(idxs) == st.expr_idxs(idxs)[0] self.check_bounds(idxs_expr(idxs), offset, numel) def default_idx(self, shape): @@ -786,14 +785,6 @@ class TestShapeTrackerSize(unittest.TestCase): strides=(0, 128, 0, 4096, 1), offset=0, mask=None, contiguous=False))) self.assertEqual(st.real_size(), 8389632) -class TestIdxs(unittest.TestCase): - def test_check_idx_range(self): - # generated from: (Tensor.rand(4096,599*64) @ Tensor.rand(599*64,1024)).realize() - # TODO: use int64 - st = ShapeTracker(views=(View(shape=(4096, 1024, 599, 1), strides=(613376, 599, 1, 0), offset=0, mask=None, contiguous=True),)) - with self.assertRaises(AssertionError): - st.expr_idxs() - class TestConsecutive(unittest.TestCase): @classmethod def setUpClass(self): diff --git a/test/unit/test_shapetracker_math.py b/test/unit/test_shapetracker_math.py index 5d9be0523b..cf9a7561a4 100644 --- a/test/unit/test_shapetracker_math.py +++ b/test/unit/test_shapetracker_math.py @@ -3,7 +3,8 @@ from typing import List from tinygrad.helpers import prod from tinygrad.shape.view import View from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.symbolic import Variable, sym_infer +from tinygrad.shape.symbolic import Variable +from test.unit.test_shapetracker import shapetracker_getitem class MultiShapeTracker: def __init__(self, sts:List[ShapeTracker]): self.sts = sts @@ -19,14 +20,9 @@ class MultiShapeTracker: def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool: if st1.shape != st2.shape: return False if st1 == st2: return True - idx = Variable("idx", 0, prod(st1.shape)-1) - st1_idx, st1_valid = st1.reshape((st1.size,)).expr_idxs([idx]) - st2_idx, st2_valid = st2.reshape((st2.size,)).expr_idxs([idx]) - for i in range(idx.min, idx.max + 1): - st1_off = sym_infer(st1_idx, {idx: i}) - st2_off = sym_infer(st2_idx, {idx: i}) - st1_v = sym_infer(st1_valid, {idx: i}) - st2_v = sym_infer(st2_valid, {idx: i}) + for i in range(0, prod(st1.shape)): + st1_off, st1_v = shapetracker_getitem(st1, i) + st2_off, st2_v = shapetracker_getitem(st2, i) if st1_v != st2_v or (st1_off != st2_off and st1_v): print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}") print(st1) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 9f57c6c528..52d2a50f34 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -2,13 +2,12 @@ from __future__ import annotations import functools from dataclasses import dataclass -from typing import Tuple, List, Optional, Dict, Set, Iterable, Any +from typing import Tuple, List, Optional, Dict, Set, Any from tinygrad.helpers import merge_dicts, getenv -from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint +from tinygrad.shape.symbolic import Variable, MulNode, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, UOps, BinaryOps -from tinygrad.ops import graph_rewrite +from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite from tinygrad.codegen.uopgraph import constant_folder, _get_chain # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps @@ -117,21 +116,6 @@ class ShapeTracker: def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1] - def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]: - idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs) - idx, valid = self.views[-1].expr(idxs) - for view in reversed(self.views[0:-1]): - if valid.max == 0: return NumNode(-1), valid - view = view.minify() - acc, idxs = 1, [] - for d in reversed(view.shape): - idxs.append((idx//acc)%d) - acc *= d - idx, valid = view.expr(idxs[::-1], valid) - assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}" - assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}" - return idx, valid - def axis_is_masked(self, axis:int) -> bool: _, valid = self.to_indexed_uops() return axis in [x.arg for x in graph_rewrite(valid, constant_folder).sparents if x.op is UOps.RANGE]