From 151a62ad3256ede38b23415a528af7bf68fb323c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 17 Aug 2024 18:44:53 +0800 Subject: [PATCH] hotfix: store dtype for ImageDType (#6133) --- tinygrad/codegen/kernel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index fccde8f248..3ebb4914fa 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -242,8 +242,8 @@ class Kernel: shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts] # if it's an image, insert fake strides such that this fusion doesn't happen across image axes - if isinstance(self.bufs[0].src[2].dtype, ImageDType): - base_shape = self.bufs[0].src[2].dtype.shape + if isinstance(self.bufs[0].src[0].dtype, ImageDType): + base_shape = self.bufs[0].src[0].dtype.shape if shape_idx_groups := get_contraction(self.output_shape, base_shape): special_strides: Tuple[sint, ...] = tuple() for i,g in enumerate(shape_idx_groups): @@ -442,7 +442,7 @@ class Kernel: self.shift_to(axis, amt, insert_before=None) self.upcast() elif opt.op is OptOps.UPCASTMID: # white - check(cast(DType, self.bufs[0].src[2].dtype).name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501 + check(cast(DType, self.bufs[0].src[0].dtype).name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501 axes = self.sts[0].unit_stride_axes() check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}") check(axes[0] == axis, "wrong axis")