diff --git a/examples/mlperf/initializers.py b/examples/mlperf/initializers.py index 6c1c8e874f..d10792d917 100644 --- a/examples/mlperf/initializers.py +++ b/examples/mlperf/initializers.py @@ -59,9 +59,7 @@ class EmbeddingBert(nn.Embedding): arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,) if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp) arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp) - # TODO: contiguous() here because the embedding dropout creates different asts on each device, and search becomes very slow. - # Should fix with fixing random ast on multi device, and fuse arange to make embedding fast. - return (arange == idx).mul(vals).sum(2, dtype=vals.dtype).contiguous() + return (arange == idx).where(vals, 0).sum(2, dtype=vals.dtype) class LayerNormBert: def __init__(self, normalized_shape:Union[int, tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):