add canonicalization to View.create (#3280)

* Reapply "take merge views from corsix branch" (#3278)

This reverts commit d298916232.

* reintroduce merge views

* update second any

* isinstance -> not

* 25% less same but unequal
This commit is contained in:
George Hotz
2024-01-30 10:26:48 -08:00
committed by GitHub
parent d8f6280ffb
commit 247a8a2a6c
2 changed files with 8 additions and 4 deletions

View File

@@ -31,6 +31,7 @@ if __name__ == "__main__":
# random.seed(42)
total = getenv("CNT", 1000)
for fuzz in [globals()[f'fuzz_{x}'] for x in getenv("FUZZ", "invert,plus").split(",")]:
same_but_neq = 0
for _ in trange(total, desc=f"{fuzz}"):
st1, st2 = fuzz()
eq = st_equal(st1, st2)
@@ -38,8 +39,10 @@ if __name__ == "__main__":
print(colored("same but unequal", "yellow"))
print(st1.simplify())
print(st2.simplify())
same_but_neq += 1
if DEBUG >= 1:
print(f"EXP: {st1}")
print(f"GOT: {st2}")
print(colored("****", "green" if eq else "red"))
if not eq: exit(0)
if getenv("CHECK_NEQ"): print(f"same but unequal {(same_but_neq/total)*100:.2f}%")

View File

@@ -85,10 +85,11 @@ class View:
contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
# if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
# then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
#if mask and any(elim := [isinstance(b, int) and isinstance(e, int) and b+1 >= e for b,e in mask]):
# if any(b >= e for b,e in mask): strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
# offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
# strides = tuple(0 if e else st for st,e in zip(strides, elim))
if mask and any(elim := [not (b+1 < e) for b,e in mask]):
if any(not (b < e) for b,e in mask):
strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
strides = tuple(0 if e else st for st,e in zip(strides, elim))
return View(shape, strides, offset, mask, contiguous)
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none