mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
delete more unused ShapeTracker stuff (#12536)
This commit is contained in:
@@ -3,7 +3,7 @@ from tinygrad import Variable
|
||||
from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, CI, mv_address
|
||||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
|
||||
from tinygrad.tensor import Tensor, get_shape
|
||||
from tinygrad.shape.view import get_contraction, get_contraction_with_reduce
|
||||
from tinygrad.shape.view import get_contraction
|
||||
import numpy as np
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
@@ -219,20 +219,6 @@ class TestMemoryview(unittest.TestCase):
|
||||
print(f"from_mv vs mv_address: {fmv_us:8.3f} µs vs {mva_us:8.3f} µs")
|
||||
|
||||
class TestGetContraction(unittest.TestCase):
|
||||
def test_contraction_with_reduce(self):
|
||||
r = get_contraction((16, 1, 1, 1), (16, 1, 1))
|
||||
self.assertEqual(r, [[0], [], [1, 2, 3]])
|
||||
r = get_contraction_with_reduce((16, 1, 1, 1), (16, 1, 1), (1,))
|
||||
self.assertEqual(r, [[0], [1, 2], [3]])
|
||||
|
||||
r = get_contraction((16, 1, 1, 1, 1), (16, 1, 1, 1))
|
||||
self.assertEqual(r, [[0], [], [], [1, 2, 3, 4]])
|
||||
r = get_contraction_with_reduce((16, 1, 1, 1, 1), (16, 1, 1, 1), (1,))
|
||||
self.assertEqual(r, [[0], [1, 2], [3], [4]])
|
||||
|
||||
r = get_contraction_with_reduce((2, 512, 1, 1), (2, 1, 512), (1,))
|
||||
self.assertIsNone(r)
|
||||
|
||||
def test_contraction(self):
|
||||
r = get_contraction((1,2,3,4), (2,3,4))
|
||||
self.assertEqual(r, [[0, 1], [2], [3]])
|
||||
|
||||
@@ -72,7 +72,6 @@ class ShapeTracker:
|
||||
|
||||
def real_strides(self, ignore_valid=False) -> tuple[sint|None, ...]:
|
||||
with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
|
||||
def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
||||
|
||||
def simplify(self) -> ShapeTracker:
|
||||
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import cast, Sequence
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import resolve, UOp, Variable, sint, smax, smin, sint_to_uop, Ops, ssimplify
|
||||
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
|
||||
from tinygrad.helpers import prod, all_int, flatten, ceildiv
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
|
||||
@@ -13,24 +13,6 @@ def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> l
|
||||
except ValueError: return None
|
||||
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
|
||||
|
||||
def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...], reduce_axis:tuple[int, ...]) -> list[list[int]]|None:
|
||||
if (contraction:=get_contraction(old_shape, new_shape)) is None: return None
|
||||
# contraction returns the 1s as right justified as possible
|
||||
# normally this contraction is good, but sometimes the reduce dim is empty. borrow from the next one, leaving one
|
||||
# this ensures there's always ones available in the reduce dimension. this is also a valid contraction
|
||||
for i in range(len(contraction)):
|
||||
if i in reduce_axis and len(contraction[i]) == 0:
|
||||
take_from = i+1
|
||||
while take_from < len(contraction) and len(contraction[take_from]) == 0:
|
||||
assert new_shape[take_from] == 1
|
||||
take_from += 1
|
||||
if take_from == len(contraction) or new_shape[take_from] != 1: return None # nothing to take
|
||||
for j in range(take_from, i, -1):
|
||||
assert len(contraction[j]) > 0
|
||||
contraction[j-1] = contraction[j][:-1]
|
||||
contraction[j] = contraction[j][-1:]
|
||||
return contraction
|
||||
|
||||
@functools.cache
|
||||
def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
|
||||
return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
|
||||
@@ -244,13 +226,6 @@ class View:
|
||||
|
||||
return View.create(vm1.shape, tuple(strides), ssimplify(sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset))
|
||||
|
||||
@functools.cache # pylint: disable=method-cache-max-size-none
|
||||
def invert(self, out_shape:tuple[sint, ...]) -> View|None:
|
||||
ret = View.create(self.shape)
|
||||
if self.mask: ret = ret.shrink(self.mask)
|
||||
ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
|
||||
return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
|
||||
|
||||
@functools.cache # pylint: disable=method-cache-max-size-none
|
||||
def minify(self):
|
||||
min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
|
||||
|
||||
Reference in New Issue
Block a user