add contiguous to EmbeddingBert output (#9829)

for some reason with random dropout it creates different ast on each device. And search embedding is slow. This workaround saved 6 minutes setup time on mi300x (25->19) and resulted in similar speed
This commit is contained in:
chenyu
2025-04-10 04:31:19 -04:00
committed by GitHub
parent fd4f06e623
commit 817746b30e

View File

@@ -53,7 +53,9 @@ class EmbeddingBert(nn.Embedding):
arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
return (arange == idx).mul(vals).sum(2, dtype=vals.dtype)
# TODO: contiguous() here because the embedding dropout creates different asts on each device, and search becomes very slow.
# Should fix with fixing random ast on multi device, and fuse arange to make embedding fast.
return (arange == idx).mul(vals).sum(2, dtype=vals.dtype).contiguous()
class LayerNormBert:
def __init__(self, normalized_shape:Union[int, tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):