mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
hotfix: store dtype for ImageDType (#6133)
This commit is contained in:
@@ -242,8 +242,8 @@ class Kernel:
|
||||
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
|
||||
|
||||
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
|
||||
if isinstance(self.bufs[0].src[2].dtype, ImageDType):
|
||||
base_shape = self.bufs[0].src[2].dtype.shape
|
||||
if isinstance(self.bufs[0].src[0].dtype, ImageDType):
|
||||
base_shape = self.bufs[0].src[0].dtype.shape
|
||||
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
|
||||
special_strides: Tuple[sint, ...] = tuple()
|
||||
for i,g in enumerate(shape_idx_groups):
|
||||
@@ -442,7 +442,7 @@ class Kernel:
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op is OptOps.UPCASTMID: # white
|
||||
check(cast(DType, self.bufs[0].src[2].dtype).name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
|
||||
check(cast(DType, self.bufs[0].src[0].dtype).name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
|
||||
check(axes[0] == axis, "wrong axis")
|
||||
|
||||
Reference in New Issue
Block a user