From 98b2854f14eafc458cc49b14f7bbba28c0cfe0b2 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 29 Dec 2024 11:10:07 +0200 Subject: [PATCH] refactor lower_schedule_item to pattern matchers [pr] (#8434) --- tinygrad/engine/realize.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index a8370fd103..c8847d3d75 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,8 +1,8 @@ from typing import Optional, cast, Generator import time, pprint from dataclasses import dataclass, replace -from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA -from tinygrad.ops import Ops, UOp, Variable, sym_infer +from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA +from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, ProgramSpec, Estimates from tinygrad.codegen.kernel import Kernel @@ -141,20 +141,16 @@ class ExecItem: self.prg.first_run = False return et -def lower_schedule_item(si:ScheduleItem) -> ExecItem: - assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY - if si.ast.op is Ops.SINK: - runner = get_runner(si.outputs[0].device, si.ast) - return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata) - out = si.outputs[0] - if si.ast.op is Ops.COPY: - kernel_type = BufferCopy - if hasattr(Device[out.device].allocator, '_transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: - kernel_type = BufferXfer - return ExecItem(kernel_type(out.nbytes, out.device, si.inputs[0].device), list(si.bufs)) - if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs)) - if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs)) - raise RuntimeError(f"don't know how to lower {si.ast}") +# NOTE: ctx is the buffers +si_lowerer = PatternMatcher([ + (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])), + (UPat(Ops.EMPTY), lambda ctx: (EmptyOp(ctx[0]), list(ctx))), + (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))), + (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(copy.size, ctx[0].device, ctx[1].device) \ + if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \ + else BufferCopy(copy.size, ctx[0].device, ctx[1].device)), list(ctx))), +]) +def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata) def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]: while len(schedule):