From 6e47dc8921228627770e770d51db34535eb43b9b Mon Sep 17 00:00:00 2001 From: ignaciosica Date: Fri, 29 Nov 2024 03:39:46 -0300 Subject: [PATCH] true tc swizzle [pr] (#7951) * true tc swizzle * cleanup * fix linter --- tinygrad/codegen/kernel.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index ab7fbbae41..1d069792dd 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -5,8 +5,8 @@ from collections import defaultdict from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence from enum import Enum, auto -from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \ - graph_rewrite, track_rewrites, UPat +from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, print_uops, type_verify, resolve, Variable, sint, \ + graph_rewrite, track_rewrites, view_left from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, ProgramSpec from tinygrad.dtype import ImageDType @@ -624,6 +624,7 @@ class Kernel: if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3): def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern): + st = ShapeTracker.from_shape(st.shape) # st needs to be contiguous wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads) tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes) @@ -681,9 +682,7 @@ class Kernel: return ret - return graph_rewrite(fixup_ast(self.ast), PatternMatcher([ - (UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))), - (UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))])) + return graph_rewrite(fixup_ast(self.ast), view_left) # **** this is the lowerer ****