This commit is contained in:
geohotstan
2023-07-01 16:29:35 +08:00
committed by GitHub
parent 574cbda979
commit 575f75f613
2 changed files with 4 additions and 1 deletions

View File

@@ -84,6 +84,9 @@ class TestOps(unittest.TestCase):
def test_arange(self):
helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(10, 5, 3), forward_only=True)
helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(5, 10, -3), forward_only=True)
helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(5, 11, -3), forward_only=True)
def test_where(self):
helper_test_op(
[(100,)],

View File

@@ -151,7 +151,7 @@ class Tensor:
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
@staticmethod
def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step, **kwargs).cumsum() + (start - step)
def arange(stop, start=0, step=1, **kwargs): return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
@staticmethod
def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):