cleanup Embedding call [pr] (#7869)

reshape on self.weight is noop, and don't need special case for numel 0.
This commit is contained in:
chenyu
2024-11-24 07:32:26 -05:00
committed by GitHub
parent ad9df26fba
commit 31337b49e3

View File

@@ -323,10 +323,9 @@ class Embedding:
self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
def __call__(self, idx:Tensor) -> Tensor:
if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
arange_shp, weight_shp, big_shp = (self.vocab_sz, 1), (self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
class LSTMCell: