diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index cc5c3e8e8b..7eacd108cd 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -323,10 +323,9 @@ class Embedding: self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size) def __call__(self, idx:Tensor) -> Tensor: - if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device) - arange_shp, weight_shp, big_shp = (self.vocab_sz, 1), (self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,) - if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp) - arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp) + if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1) + big_shp = idx.shape+(self.vocab_sz, self.embed_sz) + arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp) return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype) class LSTMCell: