mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remove kernelize to prepare for refactors (#13463)
* remove kernelize to prepare for refactors * less kernelize * last test
This commit is contained in:
@@ -1,37 +0,0 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.uop import Ops
|
||||
|
||||
class TestKernelize(unittest.TestCase):
|
||||
def test_add_reshaped(self):
|
||||
a = Tensor.ones(16,16).contiguous()
|
||||
b = Tensor.zeros(16,16).contiguous()
|
||||
ret = (a+b).sum(axis=1)
|
||||
ret_reshaped_1 = ret.reshape(4,4)
|
||||
ret_reshaped_2 = ret.reshape(2,8)
|
||||
ret.kernelize()
|
||||
self.assertIs(ret_reshaped_1.uop.src[0], ret_reshaped_2.uop.src[0])
|
||||
|
||||
def test_two_reduce(self):
|
||||
a = Tensor.ones(16,16).contiguous()
|
||||
a1 = a.sum(axis=1)
|
||||
a0 = a1.sum(axis=0)
|
||||
a0.kernelize()
|
||||
self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2)
|
||||
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
|
||||
# input Tensor and user contiguous kernelize
|
||||
self.assertIs(a0.uop.base.op, Ops.AFTER)
|
||||
self.assertIs(a.uop.base.op, Ops.AFTER)
|
||||
|
||||
def test_two_reduce_w_add(self):
|
||||
a = Tensor.ones(16,16).contiguous()
|
||||
a1 = a.sum(axis=1)
|
||||
a0 = (a1+1).sum(axis=0)
|
||||
a0.kernelize()
|
||||
# NOTE: the +1 is fused with a1, so a1 is not kernelized
|
||||
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
|
||||
# the input to the REDUCE_AXIS is an ASSIGN though
|
||||
self.assertIs(a1.uop.base.src[0].base.op, Ops.AFTER)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,15 +0,0 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.uop.ops import Ops
|
||||
|
||||
class TestSimpleSchedule(unittest.TestCase):
|
||||
def test_reduce_doesnt_split(self):
|
||||
a = Tensor.empty(16,16).sum(axis=1)
|
||||
a1 = a.reshape(4,4)
|
||||
a2 = a.reshape(16,1,1)
|
||||
Tensor.kernelize(a1, a2)
|
||||
kernels = [x for x in a1.uop.sink(a2.uop).toposort() if x.op is Ops.KERNEL]
|
||||
self.assertEqual(len(kernels), 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -3,19 +3,19 @@ from tinygrad import Tensor
|
||||
|
||||
class TestLoadStore(unittest.TestCase):
|
||||
def test_load_shape(self):
|
||||
t = Tensor(bytes(16)).fs_load(1024).kernelize()
|
||||
t = Tensor(bytes(16)).fs_load(1024)
|
||||
assert t.shape == (1024,), t.shape
|
||||
|
||||
def test_store_shape(self):
|
||||
t = Tensor.zeros(1024).fs_store().kernelize()
|
||||
t = Tensor.zeros(1024).fs_store()
|
||||
assert t.shape == (16,), t.shape
|
||||
|
||||
def test_load_large_shape(self):
|
||||
t = Tensor(bytes(16)).fs_load(10_000_000).kernelize()
|
||||
t = Tensor(bytes(16)).fs_load(10_000_000)
|
||||
assert t.shape == (10_000_000,), t.shape
|
||||
|
||||
def test_store_large_shape(self):
|
||||
t = Tensor.zeros(10_000_000).fs_store().kernelize()
|
||||
t = Tensor.zeros(10_000_000).fs_store()
|
||||
assert t.shape == (16,), t.shape
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -262,14 +262,6 @@ from tinygrad import Tensor, Device
|
||||
from tinygrad.engine.realize import get_program
|
||||
|
||||
class TestVizIntegration(BaseTestViz):
|
||||
# kernelize has a custom name function in VIZ
|
||||
def test_kernelize_tracing(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
Tensor.kernelize(a+1, a+2)
|
||||
lst = get_viz_list()
|
||||
self.assertEqual(len(lst), 1)
|
||||
self.assertEqual(lst[0]["name"], "Schedule 2 Kernels n1")
|
||||
|
||||
# codegen supports rendering of code blocks
|
||||
def test_codegen_tracing(self):
|
||||
ast = Tensor.schedule(Tensor.empty(4)+Tensor.empty(4))[0].ast
|
||||
@@ -284,7 +276,7 @@ class TestVizIntegration(BaseTestViz):
|
||||
a = Tensor.empty(1)
|
||||
b = Tensor.empty(1)
|
||||
metadata = (alu:=a+b).uop.metadata
|
||||
alu.kernelize()
|
||||
alu.schedule()
|
||||
graph = next(get_viz_details(0, 0))["graph"]
|
||||
self.assertEqual(len([n for n in graph.values() if repr(metadata) in n["label"]]), 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user