patch merge_views (#3311)

This commit is contained in:
Jyotirmaya Mahanta
2024-02-12 02:53:55 -08:00
committed by GitHub
parent b6a2600c86
commit d55f99e881

View File

@@ -18,8 +18,9 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
if vm1.contiguous and vm1.shape == vm2.shape: return vm2
if vm2.contiguous: return vm1
if vm1.contiguous and vm1.shape == vm2.shape: return vm2
if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
if not vm2.mask and vm1.offset == 0 and None not in (rstrides := ShapeTracker((vm2, vm1)).real_strides()):
return View.create(vm1.shape, cast(Tuple[sint, ...], rstrides), vm2.offset, vm1.mask)
if vm1.mask: