mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 00:08:16 -05:00
refactor lower_schedule_item to pattern matchers [pr] (#8434)
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import Optional, cast, Generator
|
||||
import time, pprint
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
|
||||
from tinygrad.ops import Ops, UOp, Variable, sym_infer
|
||||
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
|
||||
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
@@ -141,20 +141,16 @@ class ExecItem:
|
||||
self.prg.first_run = False
|
||||
return et
|
||||
|
||||
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
|
||||
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY
|
||||
if si.ast.op is Ops.SINK:
|
||||
runner = get_runner(si.outputs[0].device, si.ast)
|
||||
return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
|
||||
out = si.outputs[0]
|
||||
if si.ast.op is Ops.COPY:
|
||||
kernel_type = BufferCopy
|
||||
if hasattr(Device[out.device].allocator, '_transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
|
||||
kernel_type = BufferXfer
|
||||
return ExecItem(kernel_type(out.nbytes, out.device, si.inputs[0].device), list(si.bufs))
|
||||
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
|
||||
if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
|
||||
raise RuntimeError(f"don't know how to lower {si.ast}")
|
||||
# NOTE: ctx is the buffers
|
||||
si_lowerer = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
|
||||
(UPat(Ops.EMPTY), lambda ctx: (EmptyOp(ctx[0]), list(ctx))),
|
||||
(UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
|
||||
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(copy.size, ctx[0].device, ctx[1].device) \
|
||||
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
else BufferCopy(copy.size, ctx[0].device, ctx[1].device)), list(ctx))),
|
||||
])
|
||||
def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
|
||||
|
||||
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
|
||||
while len(schedule):
|
||||
|
||||
Reference in New Issue
Block a user