remove kernelize to prepare for refactors (#13463)
* remove kernelize to prepare for refactors
* less kernelize
* last test
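The user-facing migration implied by this diff: call sites that used Tensor.kernelize() to build kernels ahead of time now go through Tensor.schedule() (or simply realize()), as the add.kernelize() -> add.schedule() and alu.kernelize() -> alu.schedule() hunks below show. A minimal sketch of the pattern, with illustrative tensor names not taken from the diff:

from tinygrad import Tensor

a, b = Tensor.empty(10), Tensor.empty(10)
c = a + b
# before this commit: c.kernelize() created the kernels and buffers without running them
# after: ask for the schedule (one item per kernel), or just realize the Tensor;
# per the schedule() docstring, a Tensor can only be scheduled once
sched = c.schedule()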
@@ -672,33 +672,6 @@ class TestSchedule(unittest.TestCase):
     c = (a.sum(2).contiguous() + b).contiguous()
     check_schedule(c, 2)

-  def test_kernelize(self):
-    a = Tensor.empty(10)
-    b = Tensor.empty(10)
-    c = (a+b).kernelize()
-    d = c+2
-    check_schedule(d, 2)
-
-  def test_kernelize_view(self):
-    a = Tensor.empty(4,1)
-    b = a*2
-    c = b.kernelize()+Tensor.empty(4,4)
-    check_schedule(c, 2)
-
-  def test_kernelize_diamond(self):
-    a = Tensor([0]).realize()
-    prev_a = (a+1).contiguous()
-    a.assign(Tensor([2]))
-    a.kernelize(prev_a)
-    self.assertEqual((prev_a+a*3).item(), 1+2*3)
-
-  def test_kernelize_sym(self):
-    a = Tensor([1])+Tensor([2])
-    a.kernelize()
-    b = a/a
-    check_schedule(b, 0)
-    self.assertEqual(b.item(), 1)
-
   # TODO: this requires supporting multiple stores in the AST
   @unittest.expectedFailure
   def test_multioutput_ast(self):
@@ -710,35 +683,6 @@ class TestSchedule(unittest.TestCase):
     self.assertEqual(a.buffer.numpy(), [7])
     self.assertEqual(b.buffer.numpy(), [12])

-  # unlike schedule, kernelize can be called multiple times on a Tensor
-  def test_double_kernelize(self):
-    a = Tensor.empty(10)
-    b = Tensor.empty(10)
-    c = (a+b)
-    d = c.kernelize()+2
-    e = c.kernelize()+d.kernelize()
-    check_schedule(e, 3)
-
-  def test_kernelize_bw(self):
-    a = Tensor.full((3,), 2.0, requires_grad=True).contiguous()
-    b = Tensor.full((3,), 3.0, requires_grad=True).contiguous()
-    x = (a*b).kernelize()
-    y = Tensor.eye(3, requires_grad=True)
-    z = y.matmul(x).sum()
-    z.backward()
-    self.assertEqual(z.item(), 18.0)
-    self.assertEqual(z.grad.item(), 1.0)
-
-  def test_kernelize_bw_view(self):
-    a = Tensor.full((3,1), 2.0, requires_grad=True).contiguous()
-    b = Tensor.full((3,1), 3.0, requires_grad=True).contiguous()
-    x = (a*b).kernelize()
-    y = Tensor.eye(6, requires_grad=True)
-    z = y.matmul(x.expand(3,2).reshape(6)).sum()
-    z.backward()
-    self.assertEqual(z.item(), 36.0)
-    self.assertEqual(z.grad.item(), 1.0)
-
   @unittest.skip("no longer supported")
   def test_double_from(self):
     x = Tensor([1,2,3,4])
@@ -1915,18 +1859,6 @@ class TestSchedule(unittest.TestCase):
     for X in range(1,N): root = root + bufs[X][vi] + bufs[X][vj]
     self.assertEqual(root.item(), N * 2)

-  def test_limit_bufs_kernelize(self):
-    N = 31
-    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
-      bufs = [Tensor(i).contiguous().realize() for i in range(N)]
-    x = bufs[0]
-    for y in bufs[1:]: x = x+y
-    x.kernelize()
-    kcount = len([s for s in x.uop.toposort() if s.op is Ops.KERNEL])
-    z = x+Tensor.empty(1) # z only loads 2 buffers
-    sched = z.schedule()
-    self.assertEqual(len(sched), kcount+1)
-
 class TestSwizzle(unittest.TestCase):
   def test_swizzle_simple(self):
     Tensor.manual_seed(0)
@@ -2118,7 +2050,7 @@ class TestCopyFolding(unittest.TestCase):
     b = Tensor.empty(4, device="CPU")
     add = a+b
     assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.src]}"
-    add.kernelize()
+    add.schedule()

   def test_alu_before_copy(self):
     buf = Tensor.ones(1).contiguous().realize()
@@ -2438,5 +2370,12 @@ class TestUOpBecome(unittest.TestCase):
     b.shrink(((0,4),)).assign(a_view).realize()
     self.assertListEqual(b.tolist(), [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

+class TestSimpleSchedule(unittest.TestCase):
+  def test_reduce_doesnt_split(self):
+    a = Tensor.empty(16,16).sum(axis=1)
+    a1 = a.reshape(4,4)
+    a2 = a.reshape(16,1,1)
+    self.assertEqual(len(Tensor.schedule(a1, a2)), 1)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
@@ -1,37 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from tinygrad import Tensor
|
|
||||||
from tinygrad.uop import Ops
|
|
||||||
|
|
||||||
class TestKernelize(unittest.TestCase):
|
|
||||||
def test_add_reshaped(self):
|
|
||||||
a = Tensor.ones(16,16).contiguous()
|
|
||||||
b = Tensor.zeros(16,16).contiguous()
|
|
||||||
ret = (a+b).sum(axis=1)
|
|
||||||
ret_reshaped_1 = ret.reshape(4,4)
|
|
||||||
ret_reshaped_2 = ret.reshape(2,8)
|
|
||||||
ret.kernelize()
|
|
||||||
self.assertIs(ret_reshaped_1.uop.src[0], ret_reshaped_2.uop.src[0])
|
|
||||||
|
|
||||||
def test_two_reduce(self):
|
|
||||||
a = Tensor.ones(16,16).contiguous()
|
|
||||||
a1 = a.sum(axis=1)
|
|
||||||
a0 = a1.sum(axis=0)
|
|
||||||
a0.kernelize()
|
|
||||||
self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2)
|
|
||||||
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
|
|
||||||
# input Tensor and user contiguous kernelize
|
|
||||||
self.assertIs(a0.uop.base.op, Ops.AFTER)
|
|
||||||
self.assertIs(a.uop.base.op, Ops.AFTER)
|
|
||||||
|
|
||||||
def test_two_reduce_w_add(self):
|
|
||||||
a = Tensor.ones(16,16).contiguous()
|
|
||||||
a1 = a.sum(axis=1)
|
|
||||||
a0 = (a1+1).sum(axis=0)
|
|
||||||
a0.kernelize()
|
|
||||||
# NOTE: the +1 is fused with a1, so a1 is not kernelized
|
|
||||||
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
|
|
||||||
# the input to the REDUCE_AXIS is an ASSIGN though
|
|
||||||
self.assertIs(a1.uop.base.src[0].base.op, Ops.AFTER)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from tinygrad import Tensor
|
|
||||||
from tinygrad.uop.ops import Ops
|
|
||||||
|
|
||||||
class TestSimpleSchedule(unittest.TestCase):
|
|
||||||
def test_reduce_doesnt_split(self):
|
|
||||||
a = Tensor.empty(16,16).sum(axis=1)
|
|
||||||
a1 = a.reshape(4,4)
|
|
||||||
a2 = a.reshape(16,1,1)
|
|
||||||
Tensor.kernelize(a1, a2)
|
|
||||||
kernels = [x for x in a1.uop.sink(a2.uop).toposort() if x.op is Ops.KERNEL]
|
|
||||||
self.assertEqual(len(kernels), 1)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
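The deleted file above counted kernels by calling Tensor.kernelize and walking the UOp graph for Ops.KERNEL nodes; its replacement (the TestSimpleSchedule hunk added to TestUOpBecome's file earlier in this diff) reads the same count off the schedule itself. A rough sketch of the new idiom, with the tensors taken from that test:

from tinygrad import Tensor

a = Tensor.empty(16,16).sum(axis=1)
a1, a2 = a.reshape(4,4), a.reshape(16,1,1)
# previously: Tensor.kernelize(a1, a2), then count Ops.KERNEL nodes in the toposort
# now: the schedule list itself has one entry per kernel
assert len(Tensor.schedule(a1, a2)) == 1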
@@ -3,19 +3,19 @@ from tinygrad import Tensor

 class TestLoadStore(unittest.TestCase):
   def test_load_shape(self):
-    t = Tensor(bytes(16)).fs_load(1024).kernelize()
+    t = Tensor(bytes(16)).fs_load(1024)
     assert t.shape == (1024,), t.shape

   def test_store_shape(self):
-    t = Tensor.zeros(1024).fs_store().kernelize()
+    t = Tensor.zeros(1024).fs_store()
     assert t.shape == (16,), t.shape

   def test_load_large_shape(self):
-    t = Tensor(bytes(16)).fs_load(10_000_000).kernelize()
+    t = Tensor(bytes(16)).fs_load(10_000_000)
     assert t.shape == (10_000_000,), t.shape

   def test_store_large_shape(self):
-    t = Tensor.zeros(10_000_000).fs_store().kernelize()
+    t = Tensor.zeros(10_000_000).fs_store()
     assert t.shape == (16,), t.shape

 if __name__ == "__main__":
@@ -262,14 +262,6 @@ from tinygrad import Tensor, Device
 from tinygrad.engine.realize import get_program

 class TestVizIntegration(BaseTestViz):
-  # kernelize has a custom name function in VIZ
-  def test_kernelize_tracing(self):
-    a = Tensor.empty(4, 4)
-    Tensor.kernelize(a+1, a+2)
-    lst = get_viz_list()
-    self.assertEqual(len(lst), 1)
-    self.assertEqual(lst[0]["name"], "Schedule 2 Kernels n1")
-
   # codegen supports rendering of code blocks
   def test_codegen_tracing(self):
     ast = Tensor.schedule(Tensor.empty(4)+Tensor.empty(4))[0].ast
@@ -284,7 +276,7 @@ class TestVizIntegration(BaseTestViz):
     a = Tensor.empty(1)
     b = Tensor.empty(1)
     metadata = (alu:=a+b).uop.metadata
-    alu.kernelize()
+    alu.schedule()
     graph = next(get_viz_details(0, 0))["graph"]
     self.assertEqual(len([n for n in graph.values() if repr(metadata) in n["label"]]), 1)

@@ -220,25 +220,6 @@ class Tensor(OpMixin):

   # ***** data handlers ****

-  def kernelize(self, *lst:Tensor) -> Tensor:
-    """
-    Creates the kernels and buffers needed to realize these Tensor(s).
-
-    NOTE: Kernelize can be called multiple times on a Tensor
-    """
-    big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
-
-    # verify Tensors match the spec
-    if SPEC: type_verify(big_sink, tensor_spec)
-
-    if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
-      _apply_map_to_tensors(get_multi_map(big_sink), name="Apply Multi Map")
-      big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
-
-    becomes_map = get_rangeify_map(big_sink)
-    _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
-    return self
-
   def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
     """
     Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.
@@ -254,7 +235,19 @@ class Tensor(OpMixin):
     NOTE: A Tensor can only be scheduled once.
     """
     st = time.perf_counter()
-    self.kernelize(*lst)
+    big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
+
+    # verify Tensors match the spec
+    if SPEC: type_verify(big_sink, tensor_spec)
+
+    if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
+      _apply_map_to_tensors(get_multi_map(big_sink), name="Apply Multi Map")
+      big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
+
+    becomes_map = get_rangeify_map(big_sink)
+    _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
+
+    # get new sink
     sink = UOp.sink(*[x.uop for x in (self,)+lst])

     # remove all AFTERs, after scheduling, the tensors are just buffers
@@ -1853,8 +1846,7 @@ class Tensor(OpMixin):
       # χ and ι step
       state = state.bitwise_xor(~state.roll(shifts=-1, dims=2) & state.roll(shifts=-2, dims=2))
       state = state.flatten(1) ^ rnd_const_masks[i]
-      # NOTE: kernelize here to prevent internal stack from growing propotional to data size
-      state = state.kernelize()
+      # NOTE: there was a kernelize here to prevent internal stack from growing propotional to data size, do we need something else?
     return state.bitcast(dtypes.uint8)[:,:(obytes:=(200 - rate) // 2)].reshape(*self.shape[:-1], obytes)

   def _hash_1mb(self) -> Tensor: