remove expr_idxs [run_process_replay] (#6567)

* remove expr_idxs [run_process_replay]

* goodbye that test
This commit is contained in:
George Hotz
2024-09-17 18:34:51 +08:00
committed by GitHub
parent 9ebbedc37f
commit 67a03e72bb
6 changed files with 18 additions and 81 deletions

View File

@@ -51,10 +51,6 @@ class TestConvShapetracker(unittest.TestCase):
print(i, i1, i2, si.inputs[0].size, i1==i2)
#self.assertEqual(i1, i2)
for stt in [st, test_st]:
s,va = stt.expr_idxs()
print(s)
print(va)
with self.assertRaises(AssertionError):
assert len(st.views) <= 2

View File

@@ -10,18 +10,6 @@ class TestSymbolic(unittest.TestCase):
assert st.shape == (x, 3)
assert st.real_strides() == (3, 1)
def test_expr_idxs(self):
x = Variable("x", 1, 100)
st = ShapeTracker.from_shape((x, 3))
idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]
e1, e2 = st.expr_idxs(idxs)
assert e1.render() == "((x*3)+y)"
assert e2.render() == "1"
st = st.permute((1, 0))
e1, e2 = st.expr_idxs(idxs)
assert e1.render() == "((y*3)+x)"
assert e2.render() == "1"
@unittest.expectedFailure
def test_real_strides_0(self):
st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
@@ -230,22 +218,6 @@ class TestSymbolicPad(unittest.TestCase):
assert t.shape == (9,)
st = t.lazydata.st
print(st)
# TODO: fix this, required for symbolic arange
with self.assertRaises(RuntimeError):
st.expr_idxs()
class TestSymbolicShapeExpr(unittest.TestCase):
def test_symbolic_expr_idxs(self):
# taken from symbolic shape llama
i = Variable("i", 1, 120)
gidx0 = Variable("gidx0", 0, i)
lidx1 = Variable("lidx1", 0, 7)
idx = (gidx0, lidx1, NumNode(1))
shape = (i+1, 8, 4)
strides = (1, (i*4)+4, i+1)
st = ShapeTracker((View.create(shape, strides), ))
idx, _valid = st.expr_idxs(idx)
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
if __name__ == '__main__':
unittest.main()

View File

@@ -1,16 +1,19 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.helpers import prod, DEBUG
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.shape.symbolic import Variable, NumNode
from tinygrad.ops import UOp, UOps, graph_rewrite
from tinygrad.codegen.uopgraph import constant_folder
from itertools import product
def shapetracker_getitem(st, val):
_locals = {"idx0": val, "valid": 1}
idx, valid = st.reshape((st.size,)).expr_idxs()
exec(f"valid={valid.render()};idx0={idx.render()}", None, _locals)
return _locals["idx0"] if _locals["valid"] else -1
def shapetracker_getitem(st:ShapeTracker, val:int):
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.pyint, val)])
idx, valid = graph_rewrite(idx, constant_folder), graph_rewrite(valid, constant_folder)
assert idx.op is UOps.CONST and valid.op is UOps.CONST
return idx.arg, valid.arg
class CheckingShapeTracker:
def __init__(self, shape):
@@ -70,10 +73,8 @@ class CheckingShapeTracker:
def contiguous(self): return self.st.contiguous
def assert_same(self):
x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))]
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))]
y = [self[i] for i in range(prod(self.shape))]
idx, valid = self.st.expr_idxs()
if DEBUG >= 1: print(x, y, self.st.shape, self.shape, idx.render(), valid.render(), self.st)
assert self.st.shape == self.shape
assert x == y, f"mismatch shapetracker:{x} real:{y}"
@@ -163,7 +164,6 @@ class TestIndexExpressions2d(unittest.TestCase):
def tearDown(self):
for st, offset, shape, idxs_expr in zip(self.sts, self.offset, self.shapes, self.idxs_exprs):
numel = prod(shape)
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0]
self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel)
idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)]
idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)]
@@ -171,7 +171,6 @@ class TestIndexExpressions2d(unittest.TestCase):
(st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s]
for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s):
idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None]
assert idxs_expr(idxs) == st.expr_idxs(idxs)[0]
self.check_bounds(idxs_expr(idxs), offset, numel)
def default_idx(self, shape):
@@ -786,14 +785,6 @@ class TestShapeTrackerSize(unittest.TestCase):
strides=(0, 128, 0, 4096, 1), offset=0, mask=None, contiguous=False)))
self.assertEqual(st.real_size(), 8389632)
class TestIdxs(unittest.TestCase):
def test_check_idx_range(self):
# generated from: (Tensor.rand(4096,599*64) @ Tensor.rand(599*64,1024)).realize()
# TODO: use int64
st = ShapeTracker(views=(View(shape=(4096, 1024, 599, 1), strides=(613376, 599, 1, 0), offset=0, mask=None, contiguous=True),))
with self.assertRaises(AssertionError):
st.expr_idxs()
class TestConsecutive(unittest.TestCase):
@classmethod
def setUpClass(self):

View File

@@ -3,7 +3,8 @@ from typing import List
from tinygrad.helpers import prod
from tinygrad.shape.view import View
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sym_infer
from tinygrad.shape.symbolic import Variable
from test.unit.test_shapetracker import shapetracker_getitem
class MultiShapeTracker:
def __init__(self, sts:List[ShapeTracker]): self.sts = sts
@@ -19,14 +20,9 @@ class MultiShapeTracker:
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
if st1.shape != st2.shape: return False
if st1 == st2: return True
idx = Variable("idx", 0, prod(st1.shape)-1)
st1_idx, st1_valid = st1.reshape((st1.size,)).expr_idxs([idx])
st2_idx, st2_valid = st2.reshape((st2.size,)).expr_idxs([idx])
for i in range(idx.min, idx.max + 1):
st1_off = sym_infer(st1_idx, {idx: i})
st2_off = sym_infer(st2_idx, {idx: i})
st1_v = sym_infer(st1_valid, {idx: i})
st2_v = sym_infer(st2_valid, {idx: i})
for i in range(0, prod(st1.shape)):
st1_off, st1_v = shapetracker_getitem(st1, i)
st2_off, st2_v = shapetracker_getitem(st2, i)
if st1_v != st2_v or (st1_off != st2_off and st1_v):
print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}")
print(st1)