do fusion locally (#10095)
* do fusion locally
* oops, that's the right way
* explicit delete closure
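For context, a minimal standalone sketch of what the updated _test_fuse helper in the diff below verifies: .fuse() should collapse the whole expression into a single scheduled kernel, and the fused result should numerically match the normally scheduled one. The import paths (tinygrad.engine.realize.lower_schedule_item, tinygrad.helpers.get_single_element) are assumptions about tinygrad's module layout at the time of this commit, not part of the change itself.

    import numpy as np
    from tinygrad import Tensor
    from tinygrad.engine.realize import lower_schedule_item  # assumed module path
    from tinygrad.helpers import get_single_element          # assumed module path

    a = Tensor.rand(50, 50).realize()

    # Fused path: .fuse() keeps the expression local, so it should schedule
    # as exactly one kernel; get_single_element raises if it does not.
    out_single = a.softmax(axis=-1).fuse()
    lower_schedule_item(get_single_element(out_single.schedule())).run()

    # Reference path: the same expression scheduled normally, possibly as
    # several kernels, then compared numerically against the fused result.
    np_multi = a.softmax(axis=-1).numpy()
    np.testing.assert_allclose(out_single.numpy(), np_multi, atol=1e-7)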
@@ -30,10 +30,10 @@ def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Te
 def run_one_schedule_item(out): lower_schedule_item(get_single_element(out.schedule())).run()

 class TestFuse(unittest.TestCase):
-  def _test_fuse(self, fxn, *args, atol=1e-7, **kwargs):
+  def _test_fuse(self, fxn, *args, atol=1e-7, allow_multiple=False, **kwargs):
     GlobalCounters.reset()
     out_single = fxn(*args, **kwargs).fuse()
-    run_one_schedule_item(out_single)
+    if not allow_multiple: run_one_schedule_item(out_single)
     np_single = out_single.numpy()
     GlobalCounters.reset()
     np_multi = fxn(*args, **kwargs).numpy()
@@ -51,6 +51,11 @@ class TestFuse(unittest.TestCase):
     a = Tensor.rand(50,50).realize()
     self._test_fuse(lambda a: a.softmax(axis=-1), a)

+  def test_fuse_gemm_softmax(self):
+    a = Tensor.rand(50,50).realize()
+    b = Tensor.rand(50,50).realize()
+    self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
   def test_fuse_softmax_dtype(self):
     a = Tensor.rand(50,50).realize()
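The new test_fuse_gemm_softmax case passes allow_multiple=True because the .contiguous() call inserts a realization boundary: even under .fuse(), the expression is expected to schedule as more than one kernel, so the helper skips the single-schedule-item check and only compares results. A hedged illustration, under the same import assumptions as the sketch above:

    from tinygrad import Tensor

    a = Tensor.rand(50, 50).realize()
    b = Tensor.rand(50, 50).realize()

    # .contiguous() forces an intermediate buffer, so the fused expression
    # can legitimately produce a schedule with more than one item.
    out = ((a @ b).relu() + a).contiguous().softmax(axis=-1).fuse()
    print(len(out.schedule()))  # expected > 1 here, hence allow_multiple=True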