Fix division by zero and mask bug in add views (#11088)

* merge view infinite loop test

* adjust condition in `x//d -> x//(-d)*-1`

* Fix division by zero in add views

* adjust offset end

* fix typo in comment

* add target to test_merge_views_variable

* fix view incorrectly being masked

* ssimplify strides and offset of the new view to canonicalize

* remove print in test

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
Sieds Lykles
2025-07-07 19:05:47 +02:00
committed by GitHub
parent 71377cd233
commit 584fd6af5a
2 changed files with 20 additions and 6 deletions

View File

@@ -185,6 +185,17 @@ class TestMergeViews(unittest.TestCase):
self.assertIsNotNone(v)
self.assertEqual(v, target)
def test_merge_views_variable(self):
# Checks that adding two views still merges into a single view when the second view's offset is symbolic.
from tinygrad import Variable
N = 100
# symbolic offset component, bounded to [1, N-1]
start_pos = Variable("start_pos", 1, N-1)
# v0: the last axis has stride 0 and its mask (0, 1) only admits index 0 on that axis
v0 = View(shape=(N, 32, 2), strides=(32, 1, 0), offset=0, mask=((0, N), (0, 32), (0, 1)), contiguous=False)
# v1: offset start_pos*64 plus stride 2 on the last axis, so every index it produces is even
# NOTE(review): presumably this means v1 only ever lands on index 0 of v0's last axis (inside v0's mask) -- confirm
v1 = View(shape=(1, 8, 1, 32), strides=(0, 0, 0, 2), offset=start_pos*64, mask=None, contiguous=False)
# expected merged view: no mask, stride 1 on the last axis, offset start_pos*32
target = View(shape=(1, 8, 1, 32), strides=(0,0,0,1), offset=start_pos*32, mask=None, contiguous=False)
v = v0 + v1
# the merge must succeed (not None) and be exactly equal to the expected view
self.assertIsNotNone(v)
self.assertEqual(v, target)
def test_view_padded_area1(self):
# test_multinomial
v0 = View(shape=(2,), strides=(0,), offset=0, mask=((1, 2),), contiguous=False)

View File

@@ -192,15 +192,15 @@ class View:
return None
# Project vm1's offset and strides on to vm2.
origin = unravel(vm2.shape, vm1.offset)
origin = [ssimplify(o) for o in unravel(vm2.shape, vm1.offset)]
terms: list[list[tuple[int, sint]]] = [[] for _ in vm2.shape]
strides: list[sint] = [0] * len(vm1.shape)
for d1, st in enumerate(vm1.strides):
if st == 0: continue
for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))):
if (s1 := s1 - o) == 0: continue
if not resolve((s1 := s1 - o)!=0): continue # if s1 can possibly be 0
terms[d2].append((d1, s1))
strides[d1] += s1 * vm2.strides[d2]
strides[d1] += ssimplify(s1 * vm2.strides[d2])
# Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
@@ -223,9 +223,12 @@ class View:
# Try to project vm2's mask on to vm1.
newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
if resolve(b <= t.vmin and t.vmax < e, False): continue
if resolve(b <= (t := t.simplify()).vmin and t.vmax < e, False): continue
if len(term) != 1:
if not term and newe: newe[0] = 0
if not term and newe:
# t should be a constant if no terms contribute to this dimension, but it might not be simplified
if t.vmin != t.vmax: return None
newe[0] = 0
else: bad = True
continue
d1, s1 = term[0]
@@ -238,7 +241,7 @@ class View:
# Otherwise if vm2's mask was violated, then cannot merge.
if bad: return None
return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
return View.create(vm1.shape, tuple(strides), ssimplify(sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset))
@functools.cache # pylint: disable=method-cache-max-size-none
def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]: