failing test for slow embedding kernel with FUSE_ARANGE=1 [pr] (#10330)

This commit is contained in:
qazal
2025-05-15 14:58:11 +03:00
committed by GitHub
parent 5f03688280
commit 7cfe367c07

View File

@@ -4,7 +4,7 @@ import numpy as np
import torch
from tinygrad import Tensor, Device, TinyJit
from tinygrad.ops import Ops
from tinygrad.helpers import CI, Context, OSX
from tinygrad.helpers import GlobalCounters, CI, Context, OSX
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
from tinygrad.nn.state import load_state_dict
@@ -513,14 +513,15 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
def test_embedding_one_kernel(self):
def test_embedding_one_kernel(self, ops=41410, kcount=3):
GlobalCounters.reset()
layer = Embedding(20, 30)
layer.weight = Tensor.zeros_like(layer.weight).contiguous()
a = Tensor([[1, 5, 9, 11],
[12, 19, 8, 1]])
result = layer(a)
schedule = result.schedule()
self.assertEqual(3, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes arange, weight, and embedding")
self.assertEqual(kcount, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes weight and embedding")
run_schedule(schedule)
b = Tensor([[1, 2, 3],
@@ -530,6 +531,17 @@ class TestNN(unittest.TestCase):
schedule = result.schedule()
self.assertEqual(1, len([item for item in schedule if item.ast.op is Ops.SINK]), "second run realizes embedding only")
run_schedule(schedule)
print(f"Embedding used {GlobalCounters.global_ops} ops")
self.assertLessEqual(GlobalCounters.global_ops, ops)
# TODO: with FUSE_ARANGE=1 and opts enabled (NOOPT=0) the embedding test needs a much higher ops bound (612_000 vs. the default 41410)
# Re-runs test_embedding_one_kernel inside Context(FUSE_ARANGE=1, NOOPT=0).
# kcount=2: the first schedule is asserted to contain exactly 2 SINK items.
# ops=612_000: upper bound checked against GlobalCounters.global_ops at the end of the run.
def test_embedding_one_kernel_fused(self):
with Context(FUSE_ARANGE=1, NOOPT=0):
self.test_embedding_one_kernel(ops=612_000, kcount=2)
# Re-runs test_embedding_one_kernel inside Context(FUSE_ARANGE=1, NOOPT=1).
# kcount=2: the first schedule is asserted to contain exactly 2 SINK items.
# ops=0: GlobalCounters.global_ops must be <= 0 after both runs.
# NOTE(review): presumably the fused, unoptimized embedding is expected to count no ops at all -- confirm.
def test_embedding_one_kernel_fused_noopt(self):
with Context(FUSE_ARANGE=1, NOOPT=1):
self.test_embedding_one_kernel(ops=0, kcount=2)
def test_embedding_shape(self):
vocab_size, embed_size = 10, 16