always swizzle load st in wmma [pr] (#7908)

This commit is contained in:
qazal
2024-11-26 07:00:58 -05:00
committed by GitHub
parent 4e5bf9dc7a
commit ceda43ce75

View File

@@ -10,7 +10,7 @@ from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad,
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape
@@ -638,7 +638,7 @@ class Kernel:
srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern))
if tc_pattern: srcs[i] = srcs[i].view(fix_st(srcs[i].st_arg if srcs[i].op is Ops.LOAD else srcs[i].src[0].st_arg, *tc_pattern))
if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))