mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
cleanup Embedding call [pr] (#7869)
reshape on self.weight is noop, and don't need special case for numel 0.
This commit is contained in:
@@ -323,10 +323,9 @@ class Embedding:
|
||||
self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
|
||||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
|
||||
arange_shp, weight_shp, big_shp = (self.vocab_sz, 1), (self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
|
||||
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
|
||||
return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
|
||||
|
||||
class LSTMCell:
|
||||
|
||||
Reference in New Issue
Block a user