mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
rewrite to program (#13808)
This commit is contained in:
@@ -123,6 +123,25 @@ def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
|
||||
newlst.extend(ret[1])
|
||||
return newlst
|
||||
|
||||
def do_linearize(prg:UOp, sink:UOp) -> UOp:
|
||||
lst = line_rewrite(linearize(sink), pm_linearize_cleanups)
|
||||
if SPEC: type_verify(lst, program_spec)
|
||||
return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),))
|
||||
|
||||
def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
|
||||
src = ctx.render(list(lin.src))
|
||||
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
|
||||
|
||||
pm_to_program = PatternMatcher([
|
||||
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize),
|
||||
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
|
||||
])
|
||||
|
||||
def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
|
||||
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
|
||||
sink = UOp(Ops.PROGRAM, src=(full_sink,))
|
||||
return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
|
||||
|
||||
def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
|
||||
"""
|
||||
Function to transform the Kernel UOp graph into a linearized program.
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.helpers import unwrap
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.codegen import full_rewrite_to_program
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
# **************** Program Creation ****************
|
||||
@@ -32,20 +32,16 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
|
||||
if opts is not None:
|
||||
assert ast.arg is None, "can't apply opts if sink has an arg"
|
||||
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
|
||||
try:
|
||||
uops = full_rewrite(ast, renderer)
|
||||
except RuntimeError as e:
|
||||
print("***** LINEARIZE FAILURE *****")
|
||||
print(e)
|
||||
print(pyrender(ast))
|
||||
raise
|
||||
assert uops[-1].op is Ops.SINK, "last uop must be sink"
|
||||
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
|
||||
|
||||
# print and render
|
||||
if DEBUG >= 6: print_uops(uops)
|
||||
src = renderer.render(uops)
|
||||
prg = full_rewrite_to_program(ast, renderer)
|
||||
# SINK/LINEAR/SOURCE
|
||||
sink, linear, source = prg.src
|
||||
|
||||
return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
|
||||
# print
|
||||
if DEBUG >= 6: print_uops(list(linear.src))
|
||||
|
||||
return ProgramSpec(sink.arg.name, source.arg, renderer.device, sink, list(linear.src),
|
||||
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
|
||||
local_size=[1,1,1] if renderer.has_local else None)
|
||||
|
||||
|
||||
@@ -27,6 +27,10 @@ class Ops(FastEnum):
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); REWRITE_ERROR = auto()
|
||||
|
||||
# renderer
|
||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable
|
||||
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto()
|
||||
|
||||
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
|
||||
# GROUP is a NOOP that just merges things together
|
||||
SINK = auto(); AFTER = auto(); GROUP = auto()
|
||||
|
||||
@@ -218,7 +218,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL:
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE:
|
||||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
|
||||
@@ -249,6 +249,14 @@ full_spec = PatternMatcher([
|
||||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
(UPat(Ops.INDEX), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user