mirror of https://github.com/tinygrad/tinygrad.git
remove contiguous and use where in EmbeddingBert (#13632)
@@ -59,9 +59,7 @@ class EmbeddingBert(nn.Embedding):
     arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
-    # TODO: contiguous() here because the embedding dropout creates different asts on each device, and search becomes very slow.
-    # Should fix with fixing random ast on multi device, and fuse arange to make embedding fast.
-    return (arange == idx).mul(vals).sum(2, dtype=vals.dtype).contiguous()
+    return (arange == idx).where(vals, 0).sum(2, dtype=vals.dtype)
 
 class LayerNormBert:
   def __init__(self, normalized_shape:Union[int, tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
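For context, a minimal self-contained sketch (not part of the commit; vocab_sz, embed_sz, and the token ids are made-up illustration values) of why the new `.where(vals, 0)` form returns the same result as the old `.mul(vals)` form: `(arange == idx)` is a one-hot mask along the vocabulary axis, so masking the broadcast weight rows and summing over that axis is an embedding lookup either way.

# Minimal sketch (assumes tinygrad's public Tensor API; values are illustrative only)
from tinygrad import Tensor, dtypes

vocab_sz, embed_sz = 5, 3
weight = Tensor.rand(vocab_sz, embed_sz)
idx = Tensor([[1, 4, 2]])  # (batch=1, seq_len=3) token ids

arange_shp, weight_shp, big_shp = (1, 1, vocab_sz, 1), (1, 1, vocab_sz, embed_sz), idx.shape+(vocab_sz, embed_sz)
arange = Tensor.arange(vocab_sz).reshape(arange_shp).expand(big_shp)           # vocab positions, broadcast
idx_b = idx.reshape(idx.shape+(1, 1)).expand(big_shp)                          # token ids, broadcast
vals = weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)   # weight rows, broadcast

old = (arange == idx_b).mul(vals).sum(2, dtype=vals.dtype)        # pre-commit form (then .contiguous())
new = (arange == idx_b).where(vals, 0).sum(2, dtype=vals.dtype)   # post-commit form
print((old - new).abs().max().item())  # ~0.0; both give (1, 3, 3) embeddings

The dropped `.contiguous()` and its TODO comments are the other half of the change; `.contiguous()` only forces realization of the result, so removing it does not affect the computed values.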