Handling Multiple UnaryOps.BITCAST in Function for Proper Kernel Fusion [run_process_replay] (#5172)

* [Patch] added an option not to ignore view replacing when doing bitcast

* added the testcase

* [Add] reproduced bitcast cannot be fused into a single kernel in the unittest

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
hikettei
2024-07-06 01:16:44 +09:00
committed by GitHub
parent 43c3f73fbc
commit 1ab7a4cff0
3 changed files with 22 additions and 4 deletions

View File

@@ -7,13 +7,15 @@ import numpy as np
from typing import List, Optional, Union
from tinygrad import nn, dtypes
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps, UnaryOps
from tinygrad.helpers import DEBUG, flatten, getenv
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.engine.graph import print_tree
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported
from tinygrad.function import Function
from tinygrad.lazy import LazyBuffer
class KernelCountException(Exception): pass
def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
@@ -40,6 +42,22 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
l.linearize()
return sched
class CycleBitcast(Function):
def bitwise_cast(self, x: LazyBuffer):
return x.cast(dtypes.int32, True, True)
def forward(self, x: LazyBuffer):
x = x.cast(dtypes.float32)
a = self.bitwise_cast(x)
b = self.bitwise_cast(x.e(UnaryOps.NEG))
return a.e(BinaryOps.ADD, b)
class TestUOpSchedule(unittest.TestCase):
def test_multiple_bitcast_in_function(self):
a = Tensor.empty()
b = CycleBitcast.apply(a)
check_schedule(b, 1)
class TestSchedule(unittest.TestCase):
def test_basic_binop_fusion(self):
a = Tensor.empty(10)