diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index d23a775860..ab7fbbae41 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -10,7 +10,7 @@ from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, ProgramSpec from tinygrad.dtype import ImageDType -from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap +from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import strides_for_shape @@ -638,7 +638,7 @@ class Kernel: srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src) for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]): - if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern)) + if tc_pattern: srcs[i] = srcs[i].view(fix_st(srcs[i].st_arg if srcs[i].op is Ops.LOAD else srcs[i].src[0].st_arg, *tc_pattern)) if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))