From a6e32114443514bc67a0012dcd855b351ada4658 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 16 Aug 2024 20:17:23 -0700 Subject: [PATCH] use vmax for real_size [run_process_replay] (#6120) * use vmax for real_size [run_process_replay] * axis is masked --- test/unit/test_shapetracker.py | 7 +++++++ tinygrad/ops.py | 2 +- tinygrad/shape/shapetracker.py | 18 ++++++++---------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index ed0b6adee6..71a708f15b 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -733,6 +733,13 @@ class TestShapeTracker(unittest.TestCase): self.test_expand() self.test_permute() + def test_axis_is_masked(self): + st = ShapeTracker.from_shape((100, 100, 100, 100)).pad(((0,1),(0,0),(2,0), (0,0))) + assert st.axis_is_masked(0) + assert not st.axis_is_masked(1) + assert st.axis_is_masked(2) + assert not st.axis_is_masked(3) + class TestShapeTrackerSize(unittest.TestCase): def test_simple_size(self): st = ShapeTracker.from_shape((100, 100)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 613b8673c0..6301ed6b0b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -193,7 +193,7 @@ class UOp: def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]: # NOTE: returned UOp is assumed to be CONST if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None - if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None + if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax # TODO: UOps.SPECIAL is UOps.DEFINE_VAR if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None if self.op is UOps.CONST: return self, self diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 2dc65a04fd..69869baec8 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -66,7 +66,9 @@ class ShapeTracker: def to_uops(self) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), self), UOp(UOps.ST_VALID, dtypes.bool, (), self) - def to_indexed_uops(self, idxs:List[UOp]) -> Tuple[UOp, UOp]: + def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]: + idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \ + if _idxs is None else _idxs idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True)) for view in reversed(self.views[0:-1]): view = view.minify() @@ -80,13 +82,9 @@ class ShapeTracker: def real_size(self) -> int: if 0 in self.shape: return 0 - idx, valid = self.expr_idxs() - if not valid: return 0 - # TODO: it's possible that the real_size is smaller condition on valid being true - ret = idx.max - if not isinstance(ret, int): ret = ret.max # might be represent by symbolic shape, take one more max for int max - assert isinstance(ret, int), f"ret must be integer, {ret=} isn't" - return ret+1 + idx, valid = self.to_indexed_uops() + if not valid.vmax.arg: return 0 + return idx.vmax.arg+1 def vars(self) -> Set[Variable]: return set().union(*[v.vars() for v in self.views]) @@ -132,8 +130,8 @@ class ShapeTracker: return idx, valid def axis_is_masked(self, axis:int) -> bool: - _, valid = self.expr_idxs() - return f'idx{axis}' in [v.expr for v in valid.vars()] + _, valid = self.to_indexed_uops() + return axis in [x.arg for x in valid.sparents if x.op is UOps.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: