fix: cast on transpose (#13653)

This commit is contained in:
wozeparrot
2025-12-11 21:03:49 -08:00
committed by GitHub
parent 950d8de00e
commit 8f60b8dd1e

View File

@@ -66,7 +66,10 @@ class Group:
for height in self.ker.range(src.shape[-3], track=False):
for width in self.ker.range(src.shape[-2], track=False):
for inner in self.ker.range(src.shape[-1], track=False):
-dst_store = dst[width, height, inner].store(src[height, width, inner]).end(height, width, inner)
+src_load = src[height, width, inner]
+if src.dtype.base != dst.dtype.base:
+src_load = src_load.cast(dst.dtype.base)
+dst_store = dst[width, height, inner].store(src_load).end(height, width, inner)
self.ker.push_store(dst_store, dst)
return dst.after(dst_store).reshape(dst.shape)