mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
@@ -5,8 +5,8 @@ from collections import defaultdict
|
||||
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
|
||||
graph_rewrite, track_rewrites, UPat
|
||||
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, print_uops, type_verify, resolve, Variable, sint, \
|
||||
graph_rewrite, track_rewrites, view_left
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
|
||||
from tinygrad.dtype import ImageDType
|
||||
@@ -624,6 +624,7 @@ class Kernel:
|
||||
|
||||
if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
|
||||
def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
|
||||
st = ShapeTracker.from_shape(st.shape) # st needs to be contiguous
|
||||
wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
|
||||
tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
|
||||
|
||||
@@ -681,9 +682,7 @@ class Kernel:
|
||||
|
||||
return ret
|
||||
|
||||
return graph_rewrite(fixup_ast(self.ast), PatternMatcher([
|
||||
(UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
|
||||
(UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
|
||||
return graph_rewrite(fixup_ast(self.ast), view_left)
|
||||
|
||||
# **** this is the lowerer ****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user