mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
failing test for slow embedding kernel with FUSE_ARANGE=1 [pr] (#10330)
This commit is contained in:
@@ -4,7 +4,7 @@ import numpy as np
|
||||
import torch
|
||||
from tinygrad import Tensor, Device, TinyJit
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.helpers import CI, Context, OSX
|
||||
from tinygrad.helpers import GlobalCounters, CI, Context, OSX
|
||||
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
|
||||
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
|
||||
from tinygrad.nn.state import load_state_dict
|
||||
@@ -513,14 +513,15 @@ class TestNN(unittest.TestCase):
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
|
||||
|
||||
def test_embedding_one_kernel(self):
|
||||
def test_embedding_one_kernel(self, ops=41410, kcount=3):
|
||||
GlobalCounters.reset()
|
||||
layer = Embedding(20, 30)
|
||||
layer.weight = Tensor.zeros_like(layer.weight).contiguous()
|
||||
a = Tensor([[1, 5, 9, 11],
|
||||
[12, 19, 8, 1]])
|
||||
result = layer(a)
|
||||
schedule = result.schedule()
|
||||
self.assertEqual(3, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes arange, weight, and embedding")
|
||||
self.assertEqual(kcount, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes weight and embedding")
|
||||
run_schedule(schedule)
|
||||
|
||||
b = Tensor([[1, 2, 3],
|
||||
@@ -530,6 +531,17 @@ class TestNN(unittest.TestCase):
|
||||
schedule = result.schedule()
|
||||
self.assertEqual(1, len([item for item in schedule if item.ast.op is Ops.SINK]), "second run realizes embedding only")
|
||||
run_schedule(schedule)
|
||||
print(f"Embedding used {GlobalCounters.global_ops} ops")
|
||||
self.assertLessEqual(GlobalCounters.global_ops, ops)
|
||||
|
||||
# TODO: fused with opts uses more ops
|
||||
def test_embedding_one_kernel_fused(self):
|
||||
with Context(FUSE_ARANGE=1, NOOPT=0):
|
||||
self.test_embedding_one_kernel(ops=612_000, kcount=2)
|
||||
|
||||
def test_embedding_one_kernel_fused_noopt(self):
|
||||
with Context(FUSE_ARANGE=1, NOOPT=1):
|
||||
self.test_embedding_one_kernel(ops=0, kcount=2)
|
||||
|
||||
def test_embedding_shape(self):
|
||||
vocab_size, embed_size = 10, 16
|
||||
|
||||
Reference in New Issue
Block a user