hot fix explicitly set arange dtype to float (#2772)

This commit is contained in:
chenyu
2023-12-14 23:14:38 -05:00
committed by GitHub
parent c0f76ed4ea
commit 9afa8009c1

View File

@@ -250,7 +250,8 @@ class Upsample:
def timestep_embedding(timesteps, dim, max_period=10000):
half = dim // 2
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
# TODO: remove explicit dtypes after broadcast fix
freqs = (-math.log(max_period) * Tensor.arange(half, dtype=dtypes.float32) / half).exp()
args = timesteps * freqs
return Tensor.cat(args.cos(), args.sin()).reshape(1, -1)