mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
tc passes
This commit is contained in:
@@ -19,7 +19,7 @@ from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, blo
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen, pm_fix_bufferize
|
||||
|
||||
@dataclass
|
||||
class RewriteStep:
|
||||
@@ -76,6 +76,9 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
# ** expander (expand_rewrite) **
|
||||
ret.append(RewriteStep(sym+migrate_indexing+pm_group_for_reduce, name="postopt symbolic"))
|
||||
|
||||
# add locals
|
||||
ret.append(RewriteStep(pm_fix_bufferize, name="fix bufferize"))
|
||||
|
||||
# add locals
|
||||
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
|
||||
|
||||
|
||||
@@ -257,12 +257,12 @@ class Scheduler:
|
||||
except KernelOptError: continue
|
||||
|
||||
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
|
||||
warp = UOp.range(tc.threads, -1, AxisType.WARP)
|
||||
#warp = UOp.range(tc.threads, -1, AxisType.WARP)
|
||||
ne: list[UOp] = []
|
||||
for opt in tc.opts:
|
||||
if opt[0] == "l":
|
||||
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
|
||||
warp //= 2
|
||||
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.WARP) #, input_new_rng=warp%2)
|
||||
#warp //= 2
|
||||
elif opt[0] == "u":
|
||||
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
|
||||
else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
|
||||
|
||||
@@ -576,6 +576,18 @@ pm_add_buffers = PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
|
||||
])+_pm_add_buffers
|
||||
|
||||
def fix_bufferize(x:UOp, y:UOp):
  # Rewrite callback used by pm_fix_bufferize: x is the matched BUFFERIZE, y is the INDEX reading from it.
  # Gather every range in x.ranges whose axis type (last element of the range's arg) is LOCAL or WARP.
  local_warp_ranges = tuple(r for r in x.ranges if r.arg[-1] in {AxisType.LOCAL, AxisType.WARP})
  # Nothing to add: returning None signals "no rewrite" for this match.
  if not local_warp_ranges: return None
  # Append those ranges to the BUFFERIZE's own sources ...
  extended_buf = x.replace(src=x.src+local_warp_ranges)
  # ... and rebuild the INDEX on top of it, keeping y's dtype, arg and remaining sources,
  # with the same ranges appended as additional INDEX sources.
  return UOp(Ops.INDEX, y.dtype, (extended_buf,)+y.src[1:]+local_warp_ranges, arg=y.arg)
|
||||
|
||||
# Single-rule matcher: an INDEX (bound as "y") whose first source is a BUFFERIZE (bound as "x")
# is handed to fix_bufferize. allow_any_len=True presumably lets the INDEX carry further sources
# after the BUFFERIZE -- confirm against UPat.
pm_fix_bufferize = PatternMatcher([
  (
    UPat(Ops.INDEX, name="y", src=(UPat(Ops.BUFFERIZE, name="x"),), allow_any_len=True),
    fix_bufferize,
  ),
])
|
||||
|
||||
# *****************
|
||||
# 5. split into kernels
|
||||
|
||||
|
||||
Reference in New Issue
Block a user