From d2c1e8409abb281493e7aa5f7a8efb9055f6a1cf Mon Sep 17 00:00:00 2001 From: madt2709 <55849102+madt2709@users.noreply.github.com> Date: Thu, 20 Jul 2023 21:27:23 -0700 Subject: [PATCH] Update arange to be (start, stop, step) (#1308) --- examples/yolov8.py | 4 ++-- extra/onnx_ops.py | 2 +- test/test_ops.py | 6 +++--- tinygrad/tensor.py | 6 ++++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/yolov8.py b/examples/yolov8.py index cad1338c5d..956884a592 100644 --- a/examples/yolov8.py +++ b/examples/yolov8.py @@ -175,8 +175,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): assert feats is not None for i, stride in enumerate(strides): _, _, h, w = feats[i].shape - sx = Tensor.arange(stop=w) + grid_cell_offset - sy = Tensor.arange(stop=h) + grid_cell_offset + sx = Tensor.arange(w) + grid_cell_offset + sy = Tensor.arange(h) + grid_cell_offset # this is np.meshgrid but in tinygrad sx = sx.reshape(1, -1).repeat([h, 1]).reshape(-1) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 868f28bc26..377f340f65 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -201,7 +201,7 @@ def Tile(input, repeats): final_shape = [r*s for r,s in zip(repeats_, input.shape)] return input.reshape(new_shape).expand(expand_shape).reshape(final_shape) -def Range(start, limit, delta): return Tensor.arange(safe_numpy(limit)[0], safe_numpy(start)[0], safe_numpy(delta)[0]) +def Range(start, limit, delta): return Tensor.arange(safe_numpy(start)[0], safe_numpy(limit)[0], step=safe_numpy(delta)[0]) def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype) def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool) diff --git a/test/test_ops.py b/test/test_ops.py index 10fbd9dc6d..efb04e6dd2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -125,9 +125,9 @@ class TestOps(unittest.TestCase): def test_arange(self): helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True) - helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(10, 5, 3), forward_only=True) - helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(5, 10, -3), forward_only=True) - helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(5, 11, -3), forward_only=True) + helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(5, 10, 3), forward_only=True) + helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(10, 5, -3), forward_only=True) + helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(11, 5, -3), forward_only=True) def test_where(self): helper_test_op( [(100,)], diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cb3a1df7d2..e5c358ca73 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -152,7 +152,9 @@ class Tensor: def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs) @staticmethod - def arange(stop, start=0, step=1, **kwargs): return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step) + def arange(start, stop=None, step=1, **kwargs): + if stop is None: stop, start = start, 0 + return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step) @staticmethod def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs): @@ -499,7 +501,7 @@ class Tensor: def tan(self): return self.sin() / self.cos() @staticmethod - def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c) + def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c) def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self)) def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)