mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
limit real_size to the size of first View of ShapeTracker (#8628)
* fix real_size * add fuzzer; typing * spacing --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
13
test/external/fuzz_shapetracker_size.py
vendored
Normal file
13
test/external/fuzz_shapetracker_size.py
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from test.external.fuzz_shapetracker import shapetracker_ops as st_ops
|
||||
from test.unit.test_shapetracker_math import MultiShapeTracker
|
||||
from tinygrad.helpers import getenv
|
||||
import random
|
||||
|
||||
random.seed(getenv("SEED", 42))
|
||||
for i in range(getenv("CNT", 2000)):
|
||||
if getenv("DEBUG", 0) >= 1: print()
|
||||
N = random.randint(1, 10000)
|
||||
mst = MultiShapeTracker([ShapeTracker.from_shape((N,))]) # st_ops don't mutate regular shapetrackers for some reason
|
||||
for j in range(20): random.choice(st_ops)(mst)
|
||||
assert mst.sts[0].real_size() <= N, f"{N=}, real_size={mst.sts[0].real_size()}, st={mst.sts[0]}"
|
||||
@@ -833,6 +833,30 @@ class TestShapeTrackerSize(unittest.TestCase):
|
||||
strides=(0, 128, 0, 4096, 1), offset=0, mask=None, contiguous=False)))
|
||||
self.assertEqual(st.real_size(), 8389632)
|
||||
|
||||
def test_pad_size_simple(self):
|
||||
st = ShapeTracker.from_shape((10,)).pad(((2,4),))
|
||||
self.assertEqual(st.real_size(), 10)
|
||||
|
||||
def test_pad_size_multiview(self):
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
self.assertEqual(st.real_size(), 100)
|
||||
|
||||
# TODO improve real_size accuracy in cases like this?
|
||||
@unittest.expectedFailure
|
||||
def test_stride_size(self):
|
||||
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
|
||||
self.assertEqual(st1.real_size(), 78)
|
||||
self.assertEqual(st2.real_size(), 65)
|
||||
|
||||
def test_stride_size_bounds(self):
|
||||
# lower bound checks that real_size doesn't give false positive for fitting in a buffer
|
||||
# upper bound checks that real_size doesn't exceed N when movementops were applied to from_shape((N,))
|
||||
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
|
||||
self.assertTrue(78 <= st1.real_size() <= 100)
|
||||
self.assertTrue(65 <= st2.real_size() <= 100)
|
||||
|
||||
class TestConsecutive(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(self):
|
||||
|
||||
@@ -72,12 +72,10 @@ class ShapeTracker:
|
||||
def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
|
||||
return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
|
||||
|
||||
# upper bound on buffer size required to fit this shapetracker
|
||||
def real_size(self) -> int:
|
||||
if 0 in self.shape: return 0
|
||||
idx, valid = self.to_indexed_uops()
|
||||
if not valid.vmax: return 0
|
||||
assert idx.vmax < 1e12, f"real_size broken for {self}"
|
||||
return int(idx.vmax+1)
|
||||
return int((v.shrink(v.mask) if (v:=self.views[0]).mask else v).to_indexed_uops()[0].vmax + 1)
|
||||
|
||||
def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user