diff --git a/test/test_nn.py b/test/test_nn.py
index e46b34a780..96ef993d43 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 import torch
-from tinygrad import Tensor, Device, TinyJit
+from tinygrad import Tensor, Device, TinyJit, dtypes
 from tinygrad.uop.ops import Ops
 from tinygrad.helpers import GlobalCounters, CI, Context
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
@@ -465,7 +465,7 @@ class TestNN(unittest.TestCase):
     # used to fail bounds check
     with Context(FUSE_ARANGE=1):
       embedding = Embedding(100, 1024)
-      input_ids = Tensor.empty(16, 16)
+      input_ids = Tensor.empty(16, 16, dtype=dtypes.int)
       embedding(input_ids).realize()
 
   def test_load_state_dict(self):
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index bf6dca75a0..d32a3d5e2f 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -320,6 +320,7 @@ class Embedding:
 
   def __call__(self, idx:Tensor) -> Tensor:
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
+    if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
     big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
     return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)