mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
test schedule of LazyBuffers [run_process_replay] (#5859)
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, cast
|
||||
from tinygrad import nn, dtypes
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
@@ -14,19 +14,16 @@ from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from test.helpers import is_dtype_supported, Context
|
||||
from tinygrad.function import Function
|
||||
from tinygrad.lazy import LazyBuffer, view_supported_devices
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
|
||||
if isinstance(t, Tensor): t = [t]
|
||||
seen = set()
|
||||
def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
|
||||
if isinstance(t, Tensor): outs = t.lazydata.lbs
|
||||
elif isinstance(t, List): outs = flatten([r.lazydata.lbs for r in t])
|
||||
else: outs = [t]
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize:
|
||||
for s in pre.schedule(seen=seen.copy()):
|
||||
for i,out in enumerate(s.outputs):
|
||||
seen.add(out)
|
||||
sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
|
||||
for pre in to_prerealize: pre.schedule()
|
||||
sched = create_schedule(outs)
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.KERNEL]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
@@ -1250,13 +1247,17 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
|
||||
def test_bitcast_subbufer(self):
|
||||
a = Tensor.empty(1, dtype=dtypes.float32).realize()
|
||||
b = CycleBitcast.apply(a)
|
||||
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
|
||||
a = x.e(UnaryOps.NEG).cast(dtypes.int32, True, allow_buffer_view=True)
|
||||
b = x.cast(dtypes.int32, True, allow_buffer_view=True)
|
||||
b = a.e(BinaryOps.ADD, b)
|
||||
check_schedule(b, 2) # this should fuse when it makes sense
|
||||
|
||||
def test_bitcast_disable_subbufer(self):
|
||||
a = Tensor.empty(1, dtype=dtypes.float32).realize()
|
||||
b = CycleBitcast.apply(a, allow_buffer_view=False)
|
||||
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
|
||||
a = x.e(UnaryOps.NEG).cast(dtypes.int32, True, allow_buffer_view=False)
|
||||
b = x.cast(dtypes.int32, True, allow_buffer_view=False)
|
||||
b = a.e(BinaryOps.ADD, b)
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_reduceop_reshape_dont_push(self):
|
||||
@@ -1265,11 +1266,5 @@ class TestSchedule(unittest.TestCase):
|
||||
out = x.argmax(1)
|
||||
run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape
|
||||
|
||||
class CycleBitcast(Function):
|
||||
def forward(self, x: LazyBuffer, allow_buffer_view=True):
|
||||
a = x.e(UnaryOps.NEG).cast(dtypes.int32, True, allow_buffer_view)
|
||||
b = x.cast(dtypes.int32, True, allow_buffer_view)
|
||||
return a.e(BinaryOps.ADD, b)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -91,7 +91,7 @@ class LazyBuffer:
|
||||
self.base.forced_realize = True
|
||||
return self
|
||||
|
||||
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
|
||||
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
|
||||
if self.dtype == dtype: return self
|
||||
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
|
||||
if self.is_unrealized_unmasked_const() and not bitcast:
|
||||
|
||||
Reference in New Issue
Block a user