mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Fix division by zero and mask bug in add views (#11088)
* merge view infinite loop test * adjust condition in `x//d -> x//(-d)*-1` * Fix division by zero in add views * adjust offset end * fix typo in comment * add target to test_merge_views_variable * fix view incorrectly being masked * ssimplify strides and offset of the new view to canonicalize * remove print in test --------- Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
@@ -185,6 +185,17 @@ class TestMergeViews(unittest.TestCase):
|
||||
self.assertIsNotNone(v)
|
||||
self.assertEqual(v, target)
|
||||
|
||||
def test_merge_views_variable(self):
|
||||
from tinygrad import Variable
|
||||
N = 100
|
||||
start_pos = Variable("start_pos", 1, N-1)
|
||||
v0 = View(shape=(N, 32, 2), strides=(32, 1, 0), offset=0, mask=((0, N), (0, 32), (0, 1)), contiguous=False)
|
||||
v1 = View(shape=(1, 8, 1, 32), strides=(0, 0, 0, 2), offset=start_pos*64, mask=None, contiguous=False)
|
||||
target = View(shape=(1, 8, 1, 32), strides=(0,0,0,1), offset=start_pos*32, mask=None, contiguous=False)
|
||||
v = v0 + v1
|
||||
self.assertIsNotNone(v)
|
||||
self.assertEqual(v, target)
|
||||
|
||||
def test_view_padded_area1(self):
|
||||
# test_multinomial
|
||||
v0 = View(shape=(2,), strides=(0,), offset=0, mask=((1, 2),), contiguous=False)
|
||||
|
||||
@@ -192,15 +192,15 @@ class View:
|
||||
return None
|
||||
|
||||
# Project vm1's offset and strides on to vm2.
|
||||
origin = unravel(vm2.shape, vm1.offset)
|
||||
origin = [ssimplify(o) for o in unravel(vm2.shape, vm1.offset)]
|
||||
terms: list[list[tuple[int, sint]]] = [[] for _ in vm2.shape]
|
||||
strides: list[sint] = [0] * len(vm1.shape)
|
||||
for d1, st in enumerate(vm1.strides):
|
||||
if st == 0: continue
|
||||
for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))):
|
||||
if (s1 := s1 - o) == 0: continue
|
||||
if not resolve((s1 := s1 - o)!=0): continue # if s1 can possibly be 0
|
||||
terms[d2].append((d1, s1))
|
||||
strides[d1] += s1 * vm2.strides[d2]
|
||||
strides[d1] += ssimplify(s1 * vm2.strides[d2])
|
||||
|
||||
# Merge dimensions in vm2 if required.
|
||||
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
|
||||
@@ -223,9 +223,12 @@ class View:
|
||||
# Try to project vm2's mask on to vm1.
|
||||
newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
|
||||
for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
|
||||
if resolve(b <= t.vmin and t.vmax < e, False): continue
|
||||
if resolve(b <= (t := t.simplify()).vmin and t.vmax < e, False): continue
|
||||
if len(term) != 1:
|
||||
if not term and newe: newe[0] = 0
|
||||
if not term and newe:
|
||||
# t should be a constant if no terms contribute to this dimension, but it might not be simplified
|
||||
if t.vmin != t.vmax: return None
|
||||
newe[0] = 0
|
||||
else: bad = True
|
||||
continue
|
||||
d1, s1 = term[0]
|
||||
@@ -238,7 +241,7 @@ class View:
|
||||
# Otherwise if vm2's mask was violated, then cannot merge.
|
||||
if bad: return None
|
||||
|
||||
return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
|
||||
return View.create(vm1.shape, tuple(strides), ssimplify(sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset))
|
||||
|
||||
@functools.cache # pylint: disable=method-cache-max-size-none
|
||||
def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
|
||||
|
||||
Reference in New Issue
Block a user