mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
add canonicalization to View.create (#3280)
* Reapply "take merge views from corsix branch" (#3278)
This reverts commit d298916232.
* reintroduce merge views
* update second any
* isinstance -> not
* 25% less same but unequal
This commit is contained in:
3
test/external/fuzz_shapetracker_math.py
vendored
3
test/external/fuzz_shapetracker_math.py
vendored
@@ -31,6 +31,7 @@ if __name__ == "__main__":
|
||||
# random.seed(42)
|
||||
total = getenv("CNT", 1000)
|
||||
for fuzz in [globals()[f'fuzz_{x}'] for x in getenv("FUZZ", "invert,plus").split(",")]:
|
||||
same_but_neq = 0
|
||||
for _ in trange(total, desc=f"{fuzz}"):
|
||||
st1, st2 = fuzz()
|
||||
eq = st_equal(st1, st2)
|
||||
@@ -38,8 +39,10 @@ if __name__ == "__main__":
|
||||
print(colored("same but unequal", "yellow"))
|
||||
print(st1.simplify())
|
||||
print(st2.simplify())
|
||||
same_but_neq += 1
|
||||
if DEBUG >= 1:
|
||||
print(f"EXP: {st1}")
|
||||
print(f"GOT: {st2}")
|
||||
print(colored("****", "green" if eq else "red"))
|
||||
if not eq: exit(0)
|
||||
if getenv("CHECK_NEQ"): print(f"same but unequal {(same_but_neq/total)*100:.2f}%")
|
||||
|
||||
@@ -85,10 +85,11 @@ class View:
|
||||
contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
|
||||
# if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
|
||||
# then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
|
||||
#if mask and any(elim := [isinstance(b, int) and isinstance(e, int) and b+1 >= e for b,e in mask]):
|
||||
# if any(b >= e for b,e in mask): strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
|
||||
# offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
|
||||
# strides = tuple(0 if e else st for st,e in zip(strides, elim))
|
||||
if mask and any(elim := [not (b+1 < e) for b,e in mask]):
|
||||
if any(not (b < e) for b,e in mask):
|
||||
strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
|
||||
offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
|
||||
strides = tuple(0 if e else st for st,e in zip(strides, elim))
|
||||
return View(shape, strides, offset, mask, contiguous)
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
|
||||
Reference in New Issue
Block a user