mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
hotfix: test tensor dims start at 1
This commit is contained in:
@@ -892,13 +892,13 @@ class TestIdxUpcast(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_overflow_sym(self):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
|
||||
def test_regular(self):
|
||||
self.do_op_then_assert(dtypes.int, 64, 64, 64)
|
||||
|
||||
def test_regular_sym(self):
|
||||
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 0, 64).bind(32))
|
||||
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32))
|
||||
|
||||
@unittest.skipIf(PTX, "PTX always convert Ops.INDEX to int64")
|
||||
def test_symfold(self):
|
||||
@@ -910,7 +910,7 @@ class TestIdxUpcast(unittest.TestCase):
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_int64_unsupported_overflow_sym(self):
|
||||
with self.assertRaises(KeyError):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_int64_unsupported_overflow(self):
|
||||
|
||||
Reference in New Issue
Block a user