Fix: Ensure DTW cost tensor is on the same device as input tensor (#2561)

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
Nathan Harmon
2025-06-25 18:42:09 -06:00
committed by GitHub
parent f50c4f264e
commit 679ae1d141

View File

@@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.cuda()
cost = cost.to(x.device)
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](