test schedule of LazyBuffers [run_process_replay] (#5859)

This commit is contained in:
qazal
2024-08-02 00:06:29 +08:00
committed by GitHub
parent 0e34d83777
commit 26d0265d66
2 changed files with 16 additions and 21 deletions

View File

@@ -4,7 +4,7 @@
import unittest
import numpy as np
from typing import List, Optional, Union
from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.tensor import Tensor
@@ -14,19 +14,16 @@ from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported, Context
from tinygrad.function import Function
from tinygrad.lazy import LazyBuffer, view_supported_devices
class KernelCountException(Exception): pass
def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
if isinstance(t, Tensor): t = [t]
seen = set()
def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
if isinstance(t, Tensor): outs = t.lazydata.lbs
elif isinstance(t, List): outs = flatten([r.lazydata.lbs for r in t])
else: outs = [t]
if to_prerealize:
for pre in to_prerealize:
for s in pre.schedule(seen=seen.copy()):
for i,out in enumerate(s.outputs):
seen.add(out)
sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
for pre in to_prerealize: pre.schedule()
sched = create_schedule(outs)
if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.KERNEL]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
@@ -1250,13 +1247,17 @@ class TestSchedule(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
def test_bitcast_subbufer(self):
a = Tensor.empty(1, dtype=dtypes.float32).realize()
b = CycleBitcast.apply(a)
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.e(UnaryOps.NEG).cast(dtypes.int32, True, allow_buffer_view=True)
b = x.cast(dtypes.int32, True, allow_buffer_view=True)
b = a.e(BinaryOps.ADD, b)
check_schedule(b, 2) # this should fuse when it makes sense
def test_bitcast_disable_subbufer(self):
a = Tensor.empty(1, dtype=dtypes.float32).realize()
b = CycleBitcast.apply(a, allow_buffer_view=False)
x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.e(UnaryOps.NEG).cast(dtypes.int32, True, allow_buffer_view=False)
b = x.cast(dtypes.int32, True, allow_buffer_view=False)
b = a.e(BinaryOps.ADD, b)
check_schedule(b, 1)
def test_reduceop_reshape_dont_push(self):
@@ -1265,11 +1266,5 @@ class TestSchedule(unittest.TestCase):
out = x.argmax(1)
run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape
class CycleBitcast(Function):
def forward(self, x: LazyBuffer, allow_buffer_view=True):
a = x.e(UnaryOps.NEG).cast(dtypes.int32, True, allow_buffer_view)
b = x.cast(dtypes.int32, True, allow_buffer_view)
return a.e(BinaryOps.ADD, b)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -91,7 +91,7 @@ class LazyBuffer:
self.base.forced_realize = True
return self
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
if self.dtype == dtype: return self
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
if self.is_unrealized_unmasked_const() and not bitcast: