mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
pm_add_buffers_local
This commit is contained in:
@@ -19,7 +19,7 @@ from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, blo
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
|
||||
|
||||
@dataclass
|
||||
class RewriteStep:
|
||||
@@ -81,7 +81,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
ret.append(RewriteStep(sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander"))
|
||||
|
||||
# add locals
|
||||
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
|
||||
ret.append(RewriteStep(pm_add_buffers_local+rangeify_codegen, name="add local buffers"))
|
||||
|
||||
# ** devectorizer (full_graph_rewrite) **
|
||||
# remove reduce
|
||||
|
||||
@@ -519,7 +519,7 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
|
||||
# BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
|
||||
# NOTE: this has been fixed up a bit
|
||||
|
||||
def bufferize_to_store(x:UOp):
|
||||
def bufferize_to_store(x:UOp, allow_locals=True):
|
||||
rngs = x.src[1:]
|
||||
shape = tuple([int(r.vmax+1) for r in rngs])
|
||||
size = prod(shape)
|
||||
@@ -552,6 +552,7 @@ def bufferize_to_store(x:UOp):
|
||||
return ret.replace(tag=x.tag)
|
||||
|
||||
# handle locals
|
||||
if not allow_locals: return None
|
||||
tag = x.arg.device
|
||||
if tag is None: tag = UOp.unique().arg # TODO: hack
|
||||
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag)
|
||||
@@ -560,13 +561,17 @@ def bufferize_to_store(x:UOp):
|
||||
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
|
||||
|
||||
pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
|
||||
(UPat(Ops.BUFFERIZE, name="x"), lambda x: bufferize_to_store(x, allow_locals=False)),
|
||||
|
||||
# move RESHAPEs through MSELECT/MSTACK
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
lambda m: m.replace(src=tuple([x.src[0] for x in m.src]), tag=None).reshape(m.src[0].arg).rtag(m.tag)),
|
||||
])
|
||||
|
||||
pm_add_buffers_local = pm_add_buffers+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
|
||||
])
|
||||
|
||||
# *****************
|
||||
# 5. split into kernels
|
||||
|
||||
|
||||
Reference in New Issue
Block a user