no full_rewrite [pr] (#13809)

* no full_rewrite [pr]

* fix

* fix docs
George Hotz
2025-12-22 23:20:01 -05:00
committed by GitHub
parent edce2303f4
commit 8dcba2e2cc
10 changed files with 51 additions and 76 deletions

View File

@@ -26,9 +26,9 @@ Transforms the ast into an optimized ast. This is where BEAM search and heuristi
 ## tinygrad/codegen
-Transform the optimized ast into a linearized list of UOps.
+Transform the optimized ast into a linearized and rendered program.
-::: tinygrad.codegen.full_rewrite
+::: tinygrad.codegen.full_rewrite_to_program
     options:
       members: false
       show_labels: false
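For orientation, here is a minimal sketch of the flow these docs now describe, pieced together from calls that appear later in this PR; the one-element store kernel is invented for illustration. `get_program` hands back a ProgramSpec whose `.src` is the rendered code and whose `.uops` is the linearized list that `full_rewrite` used to return.

```python
# Sketch only: build a trivial kernel AST and ask the new entry point for the rendered program.
from tinygrad import Device, dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import get_program

g = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=())
sink = UOp.store(g.index(UOp.const(dtypes.int, 0)), UOp.const(dtypes.int, 42)).sink()
prg = get_program(sink, Device[Device.DEFAULT].renderer)
print(prg.src)        # rendered source for the default backend
print(len(prg.uops))  # the linearized UOps are still available on the ProgramSpec
```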

View File

@@ -1 +1 @@
-{"$schema": "https://opencode.ai/config.json", "formatter": false}
+{"$schema": "https://opencode.ai/config.json", "formatter": false, "lsp": false}

View File

@@ -1,10 +1,9 @@
 # ruff: noqa: E501 E712 F401
+from dataclasses import replace
 from tinygrad import dtypes, Device
 from tinygrad.uop.ops import UOp, AxisType, Ops, KernelInfo
-from tinygrad.codegen import full_rewrite
 from tinygrad.codegen.opt import Opt, OptOps # pylint: disable=unused-import
-from tinygrad.renderer import ProgramSpec
-from tinygrad.engine.realize import CompiledRunner
+from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.helpers import dedup, getenv
 from tinygrad.device import Buffer
 from tinygrad.dtype import ImageDType, Invalid
@@ -86,16 +85,11 @@ def dm_conv_172():
 ast = {143: vision_conv_143, 153: vision_conv_153, 172: dm_conv_172}[getenv("NUM", 143)]()
-compiler = Device.default.compiler
 renderer = Device.default.renderer
 allocator = Device.default.allocator
-uops = full_rewrite(ast, renderer)
-src = renderer.render(uops)
-lib = compiler.compile(src)
-ps = ProgramSpec("conv", src, Device.DEFAULT, ast, uops)
-cr = CompiledRunner(ps, precompiled=lib)
+ps = get_program(ast, renderer)
+cr = CompiledRunner(replace(ps, device=Device.DEFAULT))
 gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.DEFINE_GLOBAL]), key=lambda u: u.arg)
 # print(len(gs))

View File

@@ -2,18 +2,27 @@ import os, time, struct, functools, unittest
 from typing import Any, Callable
 import numpy as np
 from tinygrad import Tensor, dtypes, Device
-from tinygrad.uop.ops import UOp, Ops
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
 from tinygrad.tensor import _to_np_dtype
-from tinygrad.engine.realize import Runner
+from tinygrad.engine.realize import Runner, get_program
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import T, CI
-from tinygrad.codegen import full_rewrite
+from tinygrad.renderer import Renderer
+from tinygrad.codegen import full_rewrite_to_sink, line_rewrite, pm_linearize_cleanups
+from tinygrad.codegen.late.linearizer import linearize
 # decorator to skip slow tests by default, run with RUN_SLOW=1 to include them
 slow = unittest.skipUnless(os.getenv("RUN_SLOW"), "slow test, set RUN_SLOW=1 to run")
 from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler
+def get_uops(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
+  """Extract linearized UOps from a sink. Test helper that only does linearization (no render)."""
+  if ren is None: ren = Renderer()
+  if sink.arg is None: sink = sink.replace(arg=KernelInfo())
+  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
+  return line_rewrite(linearize(full_sink), pm_linearize_cleanups)
 def derandomize_model(model):
   for p in get_parameters(model):
     p.replace(Tensor.empty(p.shape, device=p.device, dtype=p.dtype))
@@ -51,13 +60,12 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
   bufs = []
   for buf_dt, data in inputs or []:
     bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
-    allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + buf_dt.fmt, *data)))
+    allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
   g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
-  opts = PythonRenderer()
-  lst = full_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), opts)
-  prog = PythonProgram("run", PythonCompiler().compile(opts.render(lst)))
+  prg = get_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer())
+  prog = PythonProgram("run", PythonCompiler().compile(prg.src))
   prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
-  return out_buf.cast(uop.dtype.fmt).tolist()[0]
+  return out_buf.cast(uop.dtype.fmt or "").tolist()[0]
 def not_support_multi_device():
   # CL and CUDA don't support multi device if in CI
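The `get_uops` helper added above is what the uop-level tests switch to when they only need the linearized list and not a rendered program. A rough usage sketch follows; the one-store sink is invented for illustration, and the calls mirror how test_uops.py and test_symbolic.py use the helper later in this diff.

```python
# Hypothetical usage of the get_uops test helper: linearize a tiny kernel without rendering it.
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops
from test.helpers import get_uops

glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
store = UOp.store(glbl.index(UOp.const(dtypes.int, 0)), UOp.const(dtypes.int, 1))
uops = get_uops(store.sink())     # defaults to the base Renderer, per the helper above
assert uops[-1].op is Ops.SINK    # the linearized list still ends with the SINK
```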

View File

@@ -1,30 +1,26 @@
 import unittest
 import numpy as np
+from dataclasses import replace
 from tinygrad.device import Buffer, Device, is_dtype_supported
 from tinygrad.dtype import dtypes, ConstType
-from tinygrad.engine.realize import CompiledRunner
+from tinygrad.engine.realize import CompiledRunner, get_program
-from tinygrad.helpers import dedup, flatten, prod
+from tinygrad.helpers import prod
 from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.renderer.ptx import PTXRenderer
 from tinygrad.renderer.wgsl import WGSLRenderer
 from tinygrad.runtime.ops_python import PythonRenderer
 from tinygrad.uop.ops import UOp, Ops, python_alu
-from tinygrad.renderer import ProgramSpec
 from tinygrad.tensor import Tensor, _to_np_dtype
-from tinygrad.codegen import full_rewrite
-def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
+def _test_uop_result(inputs:list[Tensor], prg, local_size=None):
   for x in inputs: x.realize()
-  # NOTE: we only toposort the stores
-  uops: list[UOp] = []
-  def _recursive_add(uop:UOp) -> list[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
-  uops = dedup(flatten(_recursive_add(st) for st in stores))
+  uops = prg.uops
   outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
     initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
   inbufs = [x.uop.base.buffer for x in inputs]
-  src = Device[Device.DEFAULT].renderer.render(uops)
-  ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
-    src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
+  prg = replace(prg, device=Device.DEFAULT)
+  if local_size is not None: prg = replace(prg, local_size=local_size)
+  ei = CompiledRunner(prg)
   ei.exec(outbufs+inbufs)
   return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
@@ -37,8 +33,8 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
   alu = ld.alu(alu_op, *alu_src_uops)
   store = UOp.store(a.index(idx), alu)
   sink = UOp(Ops.SINK, dtypes.void, (store,))
-  uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-  return _test_uop_result([Tensor([input_val])], uops)[0]
+  prg = get_program(sink, Device[Device.DEFAULT].renderer)
+  return _test_uop_result([Tensor([input_val])], prg)[0]
 class TestRendererFailures(unittest.TestCase):
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
@@ -47,8 +43,8 @@ class TestRendererFailures(unittest.TestCase):
     gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
@@ -58,8 +54,8 @@ class TestRendererFailures(unittest.TestCase):
     gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 2, 1])[0]
     np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
@@ -104,8 +100,8 @@ class TestPTXFailures(unittest.TestCase):
     if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")

View File

@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.renderer.ptx import PTXRenderer
 from tinygrad.renderer.nir import NIRRenderer
-from tinygrad.codegen import full_rewrite
+from tinygrad.engine.realize import get_program
 from tinygrad.dtype import DType
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -869,9 +869,8 @@ class TestIdxUpcast(unittest.TestCase):
     for s in schedule:
       if s.ast.op is Ops.SINK:
         renderer = Device[s.bufs[0].device].renderer
-        uops = full_rewrite(s.ast, renderer)
-        renderer.render(uops)
-        return uops
+        prg = get_program(s.ast, renderer)
+        return prg.uops
   def _assert(self, dtype: DType, a: Tensor):
     uops = self._schedule_render(a)
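As a rough sketch of the pattern `_schedule_render` now follows (the Tensor expression below is invented; the `get_program` and `prg.uops` calls mirror the hunk above):

```python
# Sketch: from a Tensor's schedule to the rendered program, mirroring _schedule_render above.
from tinygrad import Tensor, Device
from tinygrad.uop.ops import Ops
from tinygrad.engine.realize import get_program

a = Tensor.empty(16) + 1                      # any small expression; made up for illustration
for s in a.schedule():
  if s.ast.op is Ops.SINK:
    prg = get_program(s.ast, Device[s.bufs[0].device].renderer)
    print(prg.src)                            # rendered kernel source
    print(len(prg.uops))                      # and its linearized UOps
```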

View File

@@ -7,30 +7,27 @@ from tinygrad.dtype import dtypes, DType, AddrSpace
 from tinygrad.device import Buffer, Device
 from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
 from tinygrad.uop.spec import shared_spec
-from tinygrad.renderer import ProgramSpec
 from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
 from tinygrad.engine.schedule import ExecItem
-from tinygrad.codegen import full_rewrite
 from tinygrad.uop.symbolic import sym
 from tinygrad.device import is_dtype_supported
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.renderer.ptx import PTXRenderer
+from test.helpers import get_uops
+from dataclasses import replace
 def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
   sink = UOp.group(*u)
   for r in sink.ranges: sink = sink.end(r)
   # we strip the SINK here for legacy reasons
-  ret = full_rewrite(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
+  ret = get_uops(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
   assert ret[-1].op is Ops.SINK
   return ret[:-1]
 def _uops_to_prg(uops_list):
-  uops = full_rewrite(ast:=UOp.sink(*uops_list), ren=Device[Device.DEFAULT].renderer)
-  src = Device[Device.DEFAULT].renderer.render(uops)
-  has_local = Device[Device.DEFAULT].renderer.has_local
-  return CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, ast, uops=uops,
-    global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
+  prg = get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer)
+  return CompiledRunner(replace(prg, device=Device.DEFAULT))
 def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(src), arg))

View File

@@ -4,7 +4,6 @@ from tinygrad.helpers import getenv, GlobalCounters, EMULATE
 from tinygrad.engine.realize import get_program
 from tinygrad.renderer import ProgramSpec
 from tinygrad.renderer import Estimates
-from tinygrad.codegen import full_rewrite
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
@@ -146,7 +145,7 @@ class TestUOpsStats(unittest.TestCase):
     u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
     u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
     u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
-    uops = full_rewrite(u5.sink())
+    uops = list(u5.toposort())
     globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
     o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
@@ -155,7 +154,7 @@ class TestUOpsStats(unittest.TestCase):
     u2 = globl.index(o2)
     u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
     u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
-    uops_fma = full_rewrite(u4.sink())
+    uops_fma = list(u4.toposort())
     self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
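A side note on this change: these stats tests no longer go through codegen at all; a plain topological sort of the expression graph is enough input for the flop/memory counters. A tiny sketch of what `toposort()` yields (the two-constant multiply is made up for illustration):

```python
# Sketch: toposort() walks an expression graph in dependency order, no codegen involved.
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops

u1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
u2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u4 = UOp(Ops.MUL, dtypes.int, (u1, u2))
uops = list(u4.toposort())   # the constants appear before the MUL that consumes them
assert uops.index(u1) < uops.index(u4) and uops.index(u2) < uops.index(u4)
```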

View File

@@ -3,8 +3,8 @@ import unittest, pickle, functools, math
 import z3
 from tinygrad.dtype import dtypes, ConstType, DType, Invalid
-from tinygrad.codegen import full_rewrite
 from tinygrad.helpers import Context
+from test.helpers import get_uops
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
 from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid
 from tinygrad.uop.validate import uops_to_z3
@@ -747,7 +747,7 @@ class TestSymbolic(unittest.TestCase):
     # TODO: copied from render, render does not support cast
     glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
-    uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
+    uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
     rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
     self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))

View File

@@ -141,21 +141,3 @@ def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
   full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
   sink = UOp(Ops.PROGRAM, src=(full_sink,))
   return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
-
-def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
-  """
-  Function to transform the Kernel UOp graph into a linearized program.
-
-  Args:
-    sink: The Ops.SINK rooting the Kernel graph.
-    ren: The Renderer (can change how things are processed, fix this).
-
-  Returns:
-    Linear program in UOps.
-  """
-  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
-  assert len(full_sink.ranges) == 0, f"all ranges must end by the sink, {full_sink.ranges}"
-  lst = line_rewrite(linearize(full_sink), pm_linearize_cleanups)
-  if SPEC: type_verify(lst, program_spec)
-  return lst
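With `full_rewrite` gone, callers that want a runnable kernel follow the pattern this PR applies everywhere: `get_program` for the rendered ProgramSpec plus `CompiledRunner`, and `test.helpers.get_uops` when only the linearized UOps matter. A minimal end-to-end sketch; the store-a-constant kernel is invented, and each call mirrors one used in the diffs above.

```python
# Migration sketch, assuming the tiny kernel below stands in for a real AST sink.
from dataclasses import replace
from tinygrad import Device, dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import CompiledRunner, get_program

g = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=())
sink = UOp.store(g.index(UOp.const(dtypes.int, 0)), UOp.const(dtypes.int, 7)).sink()

# before: uops = full_rewrite(sink, ren); src = ren.render(uops); CompiledRunner(ProgramSpec(...))
# after: get_program renders for you; ProgramSpec is a dataclass, so replace() pins the device
prg = get_program(sink, Device[Device.DEFAULT].renderer)
runner = CompiledRunner(replace(prg, device=Device.DEFAULT))
```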