mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remove dtype from range, it will be dtypes.index soon [pr] (#11914)
* remove dtype from range, it will be dtypes.index soon [pr] * a few more
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
N = 256
|
||||
|
||||
@@ -141,9 +142,6 @@ class TestRangeify(unittest.TestCase):
|
||||
print(f"mse: {mse}")
|
||||
self.assertLessEqual(mse, 1e-6)
|
||||
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
# contiguous + reduce can support ranges?
|
||||
|
||||
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
|
||||
@@ -152,7 +150,7 @@ class TestOuterworld(unittest.TestCase):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(dtypes.int, 10, -1)
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[a]
|
||||
cpy = sel.contiguous(a).realize()
|
||||
|
||||
@@ -162,7 +160,7 @@ class TestOuterworld(unittest.TestCase):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(dtypes.int, 10, -1)
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[9-a]
|
||||
cpy = sel.contiguous(a).realize()
|
||||
|
||||
@@ -174,7 +172,7 @@ class TestOuterworld(unittest.TestCase):
|
||||
x = Tensor.ones(3, 10, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(dtypes.int, 3, -1)
|
||||
a = UOp.range(3, -1)
|
||||
out = f(x[a])
|
||||
out = out.contiguous(a)
|
||||
|
||||
@@ -188,7 +186,7 @@ class TestOuterworld(unittest.TestCase):
|
||||
|
||||
manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
|
||||
|
||||
a = UOp.range(dtypes.int, 3, -1)
|
||||
a = UOp.range(3, -1)
|
||||
x = x.assign(x @ W[a])
|
||||
out = x.contiguous(a)[-1].contiguous().realize()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user