mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
always swizzle load st in wmma [pr] (#7908)
This commit is contained in:
@@ -10,7 +10,7 @@ from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad,
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
|
||||
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
@@ -638,7 +638,7 @@ class Kernel:
|
||||
|
||||
srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
|
||||
for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
|
||||
if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern))
|
||||
if tc_pattern: srcs[i] = srcs[i].view(fix_st(srcs[i].st_arg if srcs[i].op is Ops.LOAD else srcs[i].src[0].st_arg, *tc_pattern))
|
||||
|
||||
if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
|
||||
local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
|
||||
|
||||
Reference in New Issue
Block a user