early uop globals with Buffer (#6753)

This commit is contained in:
qazal
2024-09-26 15:34:21 +08:00
committed by GitHub
parent e999281502
commit 197f8fd986
2 changed files with 21 additions and 10 deletions

View File

@@ -41,6 +41,10 @@ class LBScheduleItem:
@property
def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
@dataclass(frozen=True)
class ScheduleItemContext:
bufs: Tuple[Buffer, ...]
# *** UOp with SWIZZLE (movementops) rewriting to UOp we can index ***
# ** helpers for doing movementops on uops
@@ -112,9 +116,14 @@ reduceop_fusor = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
def full_ast_rewrite(sink:UOp) -> UOp:
enumerate_bufs = PatternMatcher([
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda ctx,x: x.replace(arg=ctx.bufs.index(x.arg)) if isinstance(x.arg, Buffer) else None),
])
def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
if not AST_REWRITE: return sink
return graph_rewrite(sink, reduceop_fusor)
sink = graph_rewrite(sink, reduceop_fusor)
return graph_rewrite(sink, enumerate_bufs, ctx)
# *** List[LazyBuffer] lowering to ScheduleItem ***
@@ -145,8 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
if buf not in assign_targets and buf not in inputs: inputs.append(buf)
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.index(buf))
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), buf.buffer)
return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
# reduce ops change ShapeTracker
@@ -174,16 +182,16 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
ast: List[UOp] = []
inputs: List[LazyBuffer] = []
for i, out in enumerate(outs):
for out in outs:
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
if out.op is MetaOps.ASSIGN and out.arg:
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
output_st = out.arg[0]
output_st, vv = output_st.simplify().unbind()
var_vals.update(vv)
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), out.buffer)
ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
# *** DAG creation: decide which LazyBuffers should realize ***

View File

@@ -10,12 +10,14 @@ from tinygrad.helpers import Context, getenv, to_function_name
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.engine.graph import uops_colors, word_wrap
from tinygrad.engine.realize import get_runner
from tinygrad.engine.schedule import full_ast_rewrite
from tinygrad.engine.schedule import ScheduleItemContext, full_ast_rewrite
# **** /graph - detailed UOp + rewrites
# NOTE: UPats in ops.py are spec
def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
# TODO: fix key for uop with buffer
def graph_rewrites(ctx:TrackedRewriteContext):
return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py" and not ("schedule" in ctx.loc[0] and "DEFINE_GLOBAL" in str(x[2]))]
@dataclass(frozen=True)
class RewriteLocation:
@@ -95,7 +97,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
code = ""
for ctx in contexts:
if ctx.loc[0].split("/")[-1] == "schedule.py":
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
si_ctx = ScheduleItemContext(bufs=tuple(x.arg for x in ctx.sink.sparents if x.op is UOps.DEFINE_GLOBAL))
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink, si_ctx)).p).name, prg.src
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx