hotfix: store dtype for ImageDType (#6133)

This commit is contained in:
qazal
2024-08-17 18:44:53 +08:00
committed by GitHub
parent d0513087e1
commit 151a62ad32

View File

@@ -242,8 +242,8 @@ class Kernel:
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
if isinstance(self.bufs[0].src[2].dtype, ImageDType):
base_shape = self.bufs[0].src[2].dtype.shape
if isinstance(self.bufs[0].src[0].dtype, ImageDType):
base_shape = self.bufs[0].src[0].dtype.shape
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
special_strides: Tuple[sint, ...] = tuple()
for i,g in enumerate(shape_idx_groups):
@@ -442,7 +442,7 @@ class Kernel:
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op is OptOps.UPCASTMID: # white
check(cast(DType, self.bufs[0].src[2].dtype).name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
check(cast(DType, self.bufs[0].src[0].dtype).name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
axes = self.sts[0].unit_stride_axes()
check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
check(axes[0] == axis, "wrong axis")