mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
ast_fixup in one graph_rewrite pass [pr] (#9444)
This commit is contained in:
@@ -274,7 +274,7 @@ DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK}
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
|
||||
# STORE (except for COPY/BUFFER_VIEW)
|
||||
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
|
||||
# partial assign can store to a non-contiguous ShapeTracker
|
||||
@@ -342,12 +342,12 @@ view_right = merge_views+PatternMatcher([
|
||||
|
||||
# ** unbind variables
|
||||
|
||||
def unbind_shapetracker(ctx:dict[Variable, int], x:UOp) -> UOp|None:
|
||||
def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp) -> UOp|None:
|
||||
st = unwrap(x.st).simplify()
|
||||
if any(x.op is Ops.BIND for x in st.vars()):
|
||||
st, var_vals = st.unbind()
|
||||
ctx.update(var_vals)
|
||||
return st.to_uop() if st != x.st else None
|
||||
ctx[0].update(var_vals)
|
||||
return x.replace(arg=st) if st != x.st else None
|
||||
|
||||
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
|
||||
ctx[var.replace(src=())] = val.arg
|
||||
@@ -387,13 +387,11 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
|
||||
ast = k.arg.ast.substitute(parents_rep)
|
||||
# unbind_vars + push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
|
||||
# add buffer ops
|
||||
ast = graph_rewrite(ast, view_left+add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
|
||||
# add buffer ops + fix_kernel_ops
|
||||
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
|
||||
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
|
||||
# create subbuffer (TODO: this does not belong here)
|
||||
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
# fix_kernel_ops
|
||||
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
|
||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
|
||||
Reference in New Issue
Block a user