mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 23:38:58 -05:00
early uop globals with Buffer (#6753)
This commit is contained in:
@@ -41,6 +41,10 @@ class LBScheduleItem:
|
||||
@property
|
||||
def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItemContext:
|
||||
bufs: Tuple[Buffer, ...]
|
||||
|
||||
# *** UOp with SWIZZLE (movementops) rewriting to UOp we can index ***
|
||||
|
||||
# ** helpers for doing movementops on uops
|
||||
@@ -112,9 +116,14 @@ reduceop_fusor = PatternMatcher([
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
def full_ast_rewrite(sink:UOp) -> UOp:
|
||||
enumerate_bufs = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda ctx,x: x.replace(arg=ctx.bufs.index(x.arg)) if isinstance(x.arg, Buffer) else None),
|
||||
])
|
||||
|
||||
def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
if not AST_REWRITE: return sink
|
||||
return graph_rewrite(sink, reduceop_fusor)
|
||||
sink = graph_rewrite(sink, reduceop_fusor)
|
||||
return graph_rewrite(sink, enumerate_bufs, ctx)
|
||||
|
||||
# *** List[LazyBuffer] lowering to ScheduleItem ***
|
||||
|
||||
@@ -145,8 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
if buf not in assign_targets and buf not in inputs: inputs.append(buf)
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
|
||||
outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.index(buf))
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), buf.buffer)
|
||||
return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
|
||||
|
||||
# reduce ops change ShapeTracker
|
||||
@@ -174,16 +182,16 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
|
||||
cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
|
||||
ast: List[UOp] = []
|
||||
inputs: List[LazyBuffer] = []
|
||||
for i, out in enumerate(outs):
|
||||
for out in outs:
|
||||
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
|
||||
if out.op is MetaOps.ASSIGN and out.arg:
|
||||
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
|
||||
output_st = out.arg[0]
|
||||
output_st, vv = output_st.simplify().unbind()
|
||||
var_vals.update(vv)
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), out.buffer)
|
||||
ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
|
||||
sink = full_ast_rewrite(ast[0].sink(*ast[1:]))
|
||||
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
|
||||
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
||||
@@ -10,12 +10,14 @@ from tinygrad.helpers import Context, getenv, to_function_name
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
|
||||
from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.schedule import full_ast_rewrite
|
||||
from tinygrad.engine.schedule import ScheduleItemContext, full_ast_rewrite
|
||||
|
||||
# **** /graph - detailed UOp + rewrites
|
||||
|
||||
# NOTE: UPats in ops.py are spec
|
||||
def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
|
||||
# TODO: fix key for uop with buffer
|
||||
def graph_rewrites(ctx:TrackedRewriteContext):
|
||||
return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py" and not ("schedule" in ctx.loc[0] and "DEFINE_GLOBAL" in str(x[2]))]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RewriteLocation:
|
||||
@@ -95,7 +97,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
|
||||
code = ""
|
||||
for ctx in contexts:
|
||||
if ctx.loc[0].split("/")[-1] == "schedule.py":
|
||||
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
|
||||
si_ctx = ScheduleItemContext(bufs=tuple(x.arg for x in ctx.sink.sparents if x.op is UOps.DEFINE_GLOBAL))
|
||||
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink, si_ctx)).p).name, prg.src
|
||||
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
|
||||
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
|
||||
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx
|
||||
|
||||
Reference in New Issue
Block a user