mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
remove duplicated canonicalize mask in __unsafe_resize [pr] (#8296)
Also adds more comments and type annotations.
This commit is contained in:
@@ -135,7 +135,7 @@ class View:
|
||||
# simplify as we go
|
||||
if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
|
||||
shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
|
||||
# TODO: enabling stride simplification breaks it
|
||||
# TODO: enabling stride simplification breaks symbolic jit
|
||||
"""
|
||||
strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
|
||||
if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
|
||||
@@ -152,7 +152,7 @@ class View:
|
||||
def unbind(self) -> Tuple[View, Dict[Variable, int]]:
|
||||
var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
|
||||
unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
|
||||
def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
|
||||
def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
|
||||
new_shape = tuple(map(substitute, self.shape))
|
||||
new_strides = tuple(map(substitute, self.strides))
|
||||
new_offset = substitute(self.offset)
|
||||
@@ -168,7 +168,9 @@ class View:
|
||||
if vm1.mask:
|
||||
for b,e in vm1.mask:
|
||||
if not resolve(b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
|
||||
return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
|
||||
if (merged := vm2 + vm1.shrink(vm1.mask)) is None: return None
|
||||
return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
|
||||
if not all_int(vm1.shape): return None
|
||||
|
||||
# Project vm1's offset and strides on to vm2.
|
||||
origin = un1d(vm2.shape, vm1.offset)
|
||||
@@ -183,7 +185,6 @@ class View:
|
||||
|
||||
# Merge dimensions in vm2 if required.
|
||||
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
|
||||
if not all_int(vm1.shape): return None
|
||||
idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
|
||||
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
|
||||
extents: List[Tuple[sint, UOp]] = []
|
||||
@@ -196,6 +197,7 @@ class View:
|
||||
if resolve(merged_term != 0): return None
|
||||
if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
|
||||
if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None
|
||||
# NOTE: this != to prevent infinite loop
|
||||
if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1
|
||||
|
||||
if vm2.mask:
|
||||
@@ -246,7 +248,6 @@ class View:
|
||||
# merge the masks if we have two
|
||||
mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
|
||||
shape = [y-x for x,y in arg]
|
||||
if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
|
||||
return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
|
||||
Reference in New Issue
Block a user