mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Handling Multiple UnaryOps.BITCAST in Function for Proper Kernel Fusion [run_process_replay] (#5172)
* [Patch] added an option not to ignore view replacing when doing bitcast * added the testcase * [Add] reproduced bitcast cannot be fused into a single kernel in the unittest --------- Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
@@ -7,13 +7,15 @@ import numpy as np
|
||||
from typing import List, Optional, Union
|
||||
from tinygrad import nn, dtypes
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
|
||||
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps, UnaryOps
|
||||
from tinygrad.helpers import DEBUG, flatten, getenv
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.engine.graph import print_tree
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from test.helpers import is_dtype_supported
|
||||
from tinygrad.function import Function
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
|
||||
@@ -40,6 +42,22 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
|
||||
l.linearize()
|
||||
return sched
|
||||
|
||||
class CycleBitcast(Function):
|
||||
def bitwise_cast(self, x: LazyBuffer):
|
||||
return x.cast(dtypes.int32, True, True)
|
||||
|
||||
def forward(self, x: LazyBuffer):
|
||||
x = x.cast(dtypes.float32)
|
||||
a = self.bitwise_cast(x)
|
||||
b = self.bitwise_cast(x.e(UnaryOps.NEG))
|
||||
return a.e(BinaryOps.ADD, b)
|
||||
|
||||
class TestUOpSchedule(unittest.TestCase):
|
||||
def test_multiple_bitcast_in_function(self):
|
||||
a = Tensor.empty()
|
||||
b = CycleBitcast.apply(a)
|
||||
check_schedule(b, 1)
|
||||
|
||||
class TestSchedule(unittest.TestCase):
|
||||
def test_basic_binop_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
|
||||
Reference in New Issue
Block a user