remove contiguous and use where in EmbeddingBert (#13632)

This commit is contained in:
chenyu
2025-12-09 15:49:21 -05:00
committed by GitHub
parent ddecba300f
commit 016a59cafa

View File

@@ -59,9 +59,7 @@ class EmbeddingBert(nn.Embedding):
arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
# TODO: contiguous() here because the embedding dropout creates different asts on each device, and search becomes very slow.
# Should fix with fixing random ast on multi device, and fuse arange to make embedding fast.
return (arange == idx).mul(vals).sum(2, dtype=vals.dtype).contiguous()
return (arange == idx).where(vals, 0).sum(2, dtype=vals.dtype)
class LayerNormBert:
def __init__(self, normalized_shape:Union[int, tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):