true tc swizzle [pr] (#7951)

* true tc swizzle

* cleanup

* fix linter
This commit is contained in:
ignaciosica
2024-11-29 03:39:46 -03:00
committed by GitHub
parent 765096fe7d
commit 6e47dc8921

View File

@@ -5,8 +5,8 @@ from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
from enum import Enum, auto
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
graph_rewrite, track_rewrites, UPat
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, print_uops, type_verify, resolve, Variable, sint, \
graph_rewrite, track_rewrites, view_left
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
from tinygrad.dtype import ImageDType
@@ -624,6 +624,7 @@ class Kernel:
if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
st = ShapeTracker.from_shape(st.shape) # st needs to be contiguous
wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
@@ -681,9 +682,7 @@ class Kernel:
return ret
return graph_rewrite(fixup_ast(self.ast), PatternMatcher([
(UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
(UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
return graph_rewrite(fixup_ast(self.ast), view_left)
# **** this is the lowerer ****