mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
minor changes to views add [pr] (#8279)
naming / style / comments before logic change
This commit is contained in:
@@ -76,13 +76,13 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple
|
||||
|
||||
return tuple(reversed(new_mask))
|
||||
|
||||
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
|
||||
result = []
|
||||
def un1d(shape:Tuple[sint, ...], offset:sint) -> List[sint]:
|
||||
# find the position of offset on each dimension based on shape
|
||||
ret = []
|
||||
for stride in strides_for_shape(shape):
|
||||
here = offs // stride if stride != 0 else 0
|
||||
result.append(here)
|
||||
offs -= here * stride
|
||||
return result
|
||||
ret.append(offset // stride if stride != 0 else 0)
|
||||
offset -= ret[-1] * stride
|
||||
return ret
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class View:
|
||||
@@ -123,7 +123,7 @@ class View:
|
||||
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
|
||||
# canonicalize 0 in shape
|
||||
if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
|
||||
# canonicalize empty mask
|
||||
# canonicalize no-op mask
|
||||
if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
|
||||
# if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
|
||||
# then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
|
||||
@@ -167,12 +167,12 @@ class View:
|
||||
if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
|
||||
if vm1.mask:
|
||||
for b,e in vm1.mask:
|
||||
if resolve(b >= e, False): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
|
||||
if not resolve(b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
|
||||
return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
|
||||
|
||||
# Project vm1's offset and strides on to vm2.
|
||||
origin = un1d(vm2.shape, vm1.offset)
|
||||
terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
|
||||
terms: List[List[Tuple[int, sint]]] = [[] for _ in vm2.shape]
|
||||
strides: List[sint] = [0] * len(vm1.shape)
|
||||
for d1, st in enumerate(vm1.strides):
|
||||
if st == 0: continue
|
||||
@@ -195,8 +195,7 @@ class View:
|
||||
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
|
||||
if resolve(merged_term != 0): return None
|
||||
if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
|
||||
reshaped_vm2 = vm2.reshape(vm2_shape)
|
||||
if reshaped_vm2 is None: return None
|
||||
if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None
|
||||
if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1
|
||||
|
||||
if vm2.mask:
|
||||
@@ -212,7 +211,7 @@ class View:
|
||||
else: bad = True
|
||||
continue
|
||||
d1, s1 = term[0]
|
||||
if not isinstance(s1, int) or not isinstance(newe[d1], int):
|
||||
if not all_int([s1, newe[d1]]):
|
||||
bad = True
|
||||
continue
|
||||
newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
|
||||
|
||||
Reference in New Issue
Block a user