refactor lower_schedule_item to pattern matchers [pr] (#8434)

This commit is contained in:
qazal
2024-12-29 11:10:07 +02:00
committed by GitHub
parent 0fd6d7482b
commit 98b2854f14

View File

@@ -1,8 +1,8 @@
from typing import Optional, cast, Generator
import time, pprint
from dataclasses import dataclass, replace
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
from tinygrad.ops import Ops, UOp, Variable, sym_infer
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.codegen.kernel import Kernel
@@ -141,20 +141,16 @@ class ExecItem:
self.prg.first_run = False
return et
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY
if si.ast.op is Ops.SINK:
runner = get_runner(si.outputs[0].device, si.ast)
return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
out = si.outputs[0]
if si.ast.op is Ops.COPY:
kernel_type = BufferCopy
if hasattr(Device[out.device].allocator, '_transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
kernel_type = BufferXfer
return ExecItem(kernel_type(out.nbytes, out.device, si.inputs[0].device), list(si.bufs))
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
raise RuntimeError(f"don't know how to lower {si.ast}")
# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
(UPat(Ops.EMPTY), lambda ctx: (EmptyOp(ctx[0]), list(ctx))),
(UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(copy.size, ctx[0].device, ctx[1].device) \
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
else BufferCopy(copy.size, ctx[0].device, ctx[1].device)), list(ctx))),
])
def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
while len(schedule):