tc passes

This commit is contained in:
George Hotz
2025-10-02 14:27:52 +08:00
parent f32a497f08
commit 2fbd7d21f9
3 changed files with 19 additions and 4 deletions

View File

@@ -19,7 +19,7 @@ from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, blo
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
from tinygrad.codegen.opt.postrange import pm_postrange_opt
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen, pm_fix_bufferize
@dataclass
class RewriteStep:
@@ -76,6 +76,9 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
# ** expander (expand_rewrite) **
ret.append(RewriteStep(sym+migrate_indexing+pm_group_for_reduce, name="postopt symbolic"))
# pass LOCAL/WARP ranges through INDEX-of-BUFFERIZE (pm_fix_bufferize) before buffers are added
ret.append(RewriteStep(pm_fix_bufferize, name="fix bufferize"))
# add locals
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))

View File

@@ -257,12 +257,12 @@ class Scheduler:
except KernelOptError: continue
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
warp = UOp.range(tc.threads, -1, AxisType.WARP)
#warp = UOp.range(tc.threads, -1, AxisType.WARP)
ne: list[UOp] = []
for opt in tc.opts:
if opt[0] == "l":
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
warp //= 2
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.WARP) #, input_new_rng=warp%2)
#warp //= 2
elif opt[0] == "u":
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")

View File

@@ -576,6 +576,18 @@ pm_add_buffers = PatternMatcher([
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
])+_pm_add_buffers
def fix_bufferize(x:UOp, y:UOp):
  """Append the LOCAL/WARP ranges of `x` to both `x` and the INDEX `y` that reads it.

  Used as the rewrite callback of pm_fix_bufferize, where `x` is bound to a
  BUFFERIZE and `y` to the INDEX whose first source is that BUFFERIZE.

  Args:
    x: the UOp whose `ranges` are inspected and whose `src` is extended.
    y: the indexing UOp; its dtype, arg and sources after the first are reused.

  Returns:
    None when `x.ranges` holds no range whose `arg[-1]` is AxisType.LOCAL or
    AxisType.WARP (presumably "no rewrite" for the PatternMatcher -- confirm).
    Otherwise a new Ops.INDEX UOp whose first source is `x` with the selected
    ranges appended to its sources, followed by `y.src[1:]` and then the same
    selected ranges again.
  """
  # select ranges by the last element of their arg, preserving the order of x.ranges
  extra = tuple(rng for rng in x.ranges if rng.arg[-1] in {AxisType.LOCAL, AxisType.WARP})
  if not extra: return None
  # the same ranges are added on both sides: as extra sources of x and as extra sources of the INDEX
  extended = x.replace(src=x.src+extra)
  return UOp(Ops.INDEX, y.dtype, (extended,)+y.src[1:]+extra, arg=y.arg)
# Single-rule matcher: an INDEX whose first source is a BUFFERIZE is handed to fix_bufferize,
# with the BUFFERIZE bound as "x" and the INDEX itself bound as "y".
# NOTE(review): allow_any_len=True presumably lets the INDEX carry further sources after the
# BUFFERIZE (fix_bufferize forwards y.src[1:]) -- confirm against UPat.
pm_fix_bufferize = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFERIZE, name="x"),), name="y", allow_any_len=True), fix_bufferize),
])
# *****************
# 5. split into kernels