no full_rewrite [pr] (#13809)

* no full_rewrite [pr]

* fix

* fix docs

George Hotz (committed by GitHub)
2025-12-22 23:20:01 -05:00
parent edce2303f4
commit 8dcba2e2cc
10 changed files with 51 additions and 76 deletions

View File

@@ -26,9 +26,9 @@ Transforms the ast into an optimized ast. This is where BEAM search and heuristi
 ## tinygrad/codegen
-Transform the optimized ast into a linearized list of UOps.
+Transform the optimized ast into a linearized and rendered program.
-::: tinygrad.codegen.full_rewrite
+::: tinygrad.codegen.full_rewrite_to_program
     options:
         members: false
         show_labels: false
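The doc change tracks the API: codegen no longer stops at a list of UOps, it produces a rendered program. A minimal migration sketch, assuming a tinygrad tree at this commit (the tiny Tensor computation is illustrative, not from the PR):

```python
from tinygrad import Tensor, Device
from tinygrad.uop.ops import Ops
from tinygrad.engine.realize import get_program

# get a kernel AST (an Ops.SINK) from the schedule of a trivial computation
out = Tensor.ones(4).contiguous() + 1
ast = [si for si in out.schedule() if si.ast.op is Ops.SINK][-1].ast

# old: uops = full_rewrite(ast, renderer); src = renderer.render(uops)
# new: one call returns a ProgramSpec carrying both the UOps and the rendered source
ps = get_program(ast, Device.default.renderer)
print(ps.name, len(ps.uops), len(ps.src))
```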

View File

@@ -1 +1 @@
{"$schema": "https://opencode.ai/config.json", "formatter": false}
{"$schema": "https://opencode.ai/config.json", "formatter": false, "lsp": false}

View File

@@ -1,10 +1,9 @@
 # ruff: noqa: E501 E712 F401
+from dataclasses import replace
 from tinygrad import dtypes, Device
 from tinygrad.uop.ops import UOp, AxisType, Ops, KernelInfo
-from tinygrad.codegen import full_rewrite
 from tinygrad.codegen.opt import Opt, OptOps # pylint: disable=unused-import
-from tinygrad.renderer import ProgramSpec
-from tinygrad.engine.realize import CompiledRunner
+from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.helpers import dedup, getenv
 from tinygrad.device import Buffer
 from tinygrad.dtype import ImageDType, Invalid
@@ -86,16 +85,11 @@ def dm_conv_172():
 ast = {143: vision_conv_143, 153: vision_conv_153, 172: dm_conv_172}[getenv("NUM", 143)]()
-compiler = Device.default.compiler
 renderer = Device.default.renderer
 allocator = Device.default.allocator
-uops = full_rewrite(ast, renderer)
-src = renderer.render(uops)
-lib = compiler.compile(src)
-ps = ProgramSpec("conv", src, Device.DEFAULT, ast, uops)
-cr = CompiledRunner(ps, precompiled=lib)
+ps = get_program(ast, renderer)
+cr = CompiledRunner(replace(ps, device=Device.DEFAULT))
 gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.DEFINE_GLOBAL]), key=lambda u: u.arg)
 # print(len(gs))

View File

@@ -2,18 +2,27 @@ import os, time, struct, functools, unittest
 from typing import Any, Callable
 import numpy as np
 from tinygrad import Tensor, dtypes, Device
-from tinygrad.uop.ops import UOp, Ops
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
 from tinygrad.tensor import _to_np_dtype
-from tinygrad.engine.realize import Runner
+from tinygrad.engine.realize import Runner, get_program
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import T, CI
-from tinygrad.codegen import full_rewrite
+from tinygrad.renderer import Renderer
+from tinygrad.codegen import full_rewrite_to_sink, line_rewrite, pm_linearize_cleanups
+from tinygrad.codegen.late.linearizer import linearize
 # decorator to skip slow tests by default, run with RUN_SLOW=1 to include them
 slow = unittest.skipUnless(os.getenv("RUN_SLOW"), "slow test, set RUN_SLOW=1 to run")
 from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler
+def get_uops(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
+  """Extract linearized UOps from a sink. Test helper that only does linearization (no render)."""
+  if ren is None: ren = Renderer()
+  if sink.arg is None: sink = sink.replace(arg=KernelInfo())
+  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
+  return line_rewrite(linearize(full_sink), pm_linearize_cleanups)
 def derandomize_model(model):
   for p in get_parameters(model):
     p.replace(Tensor.empty(p.shape, device=p.device, dtype=p.dtype))
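The new `get_uops` helper preserves the old linearize-only path for tests that inspect UOps without rendering. A usage sketch (the const-store kernel below is illustrative, not from the PR):

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops
from test.helpers import get_uops

# build the smallest legal kernel: store a constant to a global buffer
g = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
sink = UOp.store(g.index(UOp.const(dtypes.int, 0)), UOp.const(dtypes.int, 42)).sink()

uops = get_uops(sink)            # linearized list, no Renderer.render involved
assert uops[-1].op is Ops.SINK   # callers like to_uops_list strip this trailing SINK
```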
@@ -51,13 +60,12 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
   bufs = []
   for buf_dt, data in inputs or []:
     bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
-    allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + buf_dt.fmt, *data)))
+    allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
   g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
-  opts = PythonRenderer()
-  lst = full_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), opts)
-  prog = PythonProgram("run", PythonCompiler().compile(opts.render(lst)))
+  prg = get_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer())
+  prog = PythonProgram("run", PythonCompiler().compile(prg.src))
   prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
-  return out_buf.cast(uop.dtype.fmt).tolist()[0]
+  return out_buf.cast(uop.dtype.fmt or "").tolist()[0]
 def not_support_multi_device():
   # CL and CUDA don't support multi device if in CI
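The two `fmt or ""` changes guard dtypes whose struct format is `None`: both string concatenation and `memoryview.cast` want a `str`, and failing inside `struct`/`cast` gives a clearer error than a `TypeError` on concatenating `None`. A small sketch of the failure mode, assuming the current dtype table where e.g. bfloat16 has no struct format code:

```python
import struct
from tinygrad import dtypes

print(dtypes.int.fmt)       # 'i' -> struct can pack it
print(dtypes.bfloat16.fmt)  # None -> "2" + None would raise TypeError on concat
# with the guard, an unpackable dtype fails inside struct with a real message
struct.pack("2" + (dtypes.int.fmt or ""), 1, 2)
```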

View File

@@ -1,30 +1,26 @@
 import unittest
 import numpy as np
+from dataclasses import replace
 from tinygrad.device import Buffer, Device, is_dtype_supported
 from tinygrad.dtype import dtypes, ConstType
-from tinygrad.engine.realize import CompiledRunner
-from tinygrad.helpers import dedup, flatten, prod
+from tinygrad.engine.realize import CompiledRunner, get_program
+from tinygrad.helpers import prod
 from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.renderer.ptx import PTXRenderer
 from tinygrad.renderer.wgsl import WGSLRenderer
 from tinygrad.runtime.ops_python import PythonRenderer
 from tinygrad.uop.ops import UOp, Ops, python_alu
-from tinygrad.renderer import ProgramSpec
 from tinygrad.tensor import Tensor, _to_np_dtype
-from tinygrad.codegen import full_rewrite
-def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
+def _test_uop_result(inputs:list[Tensor], prg, local_size=None):
   for x in inputs: x.realize()
-  # NOTE: we only toposort the stores
-  uops: list[UOp] = []
-  def _recursive_add(uop:UOp) -> list[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
-  uops = dedup(flatten(_recursive_add(st) for st in stores))
+  uops = prg.uops
   outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
     initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
   inbufs = [x.uop.base.buffer for x in inputs]
-  src = Device[Device.DEFAULT].renderer.render(uops)
-  ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
-    src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
+  prg = replace(prg, device=Device.DEFAULT)
+  if local_size is not None: prg = replace(prg, local_size=local_size)
+  ei = CompiledRunner(prg)
   ei.exec(outbufs+inbufs)
   return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
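`_test_uop_result` now receives a finished ProgramSpec, so overrides go through `dataclasses.replace`, which returns a copy of a dataclass with selected fields swapped rather than mutating it. A toy stand-in for the pattern (`Spec` is hypothetical; the real ProgramSpec carries many more fields):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Spec:  # hypothetical stand-in for ProgramSpec
  name: str
  device: str
  local_size: list | None = None

s = Spec("test", "CPU")
s = replace(s, device="NULL", local_size=[4, 1, 1])  # new object, fields swapped
print(s)  # Spec(name='test', device='NULL', local_size=[4, 1, 1])
```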
@@ -37,8 +33,8 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
   alu = ld.alu(alu_op, *alu_src_uops)
   store = UOp.store(a.index(idx), alu)
   sink = UOp(Ops.SINK, dtypes.void, (store,))
-  uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-  return _test_uop_result([Tensor([input_val])], uops)[0]
+  prg = get_program(sink, Device[Device.DEFAULT].renderer)
+  return _test_uop_result([Tensor([input_val])], prg)[0]
 class TestRendererFailures(unittest.TestCase):
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
@@ -47,8 +43,8 @@ class TestRendererFailures(unittest.TestCase):
     gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
@@ -58,8 +54,8 @@ class TestRendererFailures(unittest.TestCase):
     gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 2, 1])[0]
     np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
@@ -104,8 +100,8 @@ class TestPTXFailures(unittest.TestCase):
     if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")

View File

@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.renderer.ptx import PTXRenderer
 from tinygrad.renderer.nir import NIRRenderer
-from tinygrad.codegen import full_rewrite
+from tinygrad.engine.realize import get_program
 from tinygrad.dtype import DType
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -869,9 +869,8 @@ class TestIdxUpcast(unittest.TestCase):
     for s in schedule:
       if s.ast.op is Ops.SINK:
         renderer = Device[s.bufs[0].device].renderer
-        uops = full_rewrite(s.ast, renderer)
-        renderer.render(uops)
-        return uops
+        prg = get_program(s.ast, renderer)
+        return prg.uops
   def _assert(self, dtype: DType, a: Tensor):
     uops = self._schedule_render(a)

View File

@@ -7,30 +7,27 @@ from tinygrad.dtype import dtypes, DType, AddrSpace
 from tinygrad.device import Buffer, Device
 from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
 from tinygrad.uop.spec import shared_spec
-from tinygrad.renderer import ProgramSpec
 from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
 from tinygrad.engine.schedule import ExecItem
-from tinygrad.codegen import full_rewrite
 from tinygrad.uop.symbolic import sym
 from tinygrad.device import is_dtype_supported
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.renderer.ptx import PTXRenderer
+from test.helpers import get_uops
+from dataclasses import replace
 def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
   sink = UOp.group(*u)
   for r in sink.ranges: sink = sink.end(r)
   # we strip the SINK here for legacy reasons
-  ret = full_rewrite(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
+  ret = get_uops(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
   assert ret[-1].op is Ops.SINK
   return ret[:-1]
 def _uops_to_prg(uops_list):
-  uops = full_rewrite(ast:=UOp.sink(*uops_list), ren=Device[Device.DEFAULT].renderer)
-  src = Device[Device.DEFAULT].renderer.render(uops)
-  has_local = Device[Device.DEFAULT].renderer.has_local
-  return CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, ast, uops=uops,
-    global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
+  prg = get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer)
+  return CompiledRunner(replace(prg, device=Device.DEFAULT))
 def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(src), arg))

View File

@@ -4,7 +4,6 @@ from tinygrad.helpers import getenv, GlobalCounters, EMULATE
 from tinygrad.engine.realize import get_program
 from tinygrad.renderer import ProgramSpec
 from tinygrad.renderer import Estimates
-from tinygrad.codegen import full_rewrite
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
@@ -146,7 +145,7 @@ class TestUOpsStats(unittest.TestCase):
     u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
     u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
     u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
-    uops = full_rewrite(u5.sink())
+    uops = list(u5.toposort())
     globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
     o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
@@ -155,7 +154,7 @@ class TestUOpsStats(unittest.TestCase):
     u2 = globl.index(o2)
     u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
     u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
-    uops_fma = full_rewrite(u4.sink())
+    uops_fma = list(u4.toposort())
     self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
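The stats test now feeds `flops_mem` a plain toposort of the graph instead of a fully rewritten program. That works because the estimate only tallies per-op costs over a linear list, so any dependency-ordered traversal suffices for comparing MUL+ADD against MULACC. A sketch of what `toposort` yields (the relative order of independent CONSTs may vary):

```python
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops

u1 = UOp(Ops.CONST, dtypes.int, (), 1)
u2 = UOp(Ops.CONST, dtypes.int, (), 2)
mul = UOp(Ops.MUL, dtypes.int, (u1, u2))
add = UOp(Ops.ADD, dtypes.int, (mul, UOp(Ops.CONST, dtypes.int, (), 3)))

# every node appears after its sources
print([u.op for u in add.toposort()])  # e.g. [CONST, CONST, MUL, CONST, ADD]
```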

View File

@@ -3,8 +3,8 @@ import unittest, pickle, functools, math
 import z3
 from tinygrad.dtype import dtypes, ConstType, DType, Invalid
-from tinygrad.codegen import full_rewrite
 from tinygrad.helpers import Context
+from test.helpers import get_uops
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
 from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid
 from tinygrad.uop.validate import uops_to_z3
@@ -747,7 +747,7 @@ class TestSymbolic(unittest.TestCase):
     # TODO: copied from render, render does not support cast
     glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
-    uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
+    uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
     rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
     self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))

View File

@@ -141,21 +141,3 @@ def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
   full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
   sink = UOp(Ops.PROGRAM, src=(full_sink,))
   return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
-
-def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
-  """
-  Function to transform the Kernel UOp graph into a linearized program.
-
-  Args:
-    sink: The Ops.SINK rooting the Kernel graph.
-    ren: The Renderer (can change how things are processed, fix this).
-
-  Returns:
-    Linear program in UOps.
-  """
-  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
-  assert len(full_sink.ranges) == 0, f"all ranges must end by the sink, {full_sink.ranges}"
-  lst = line_rewrite(linearize(full_sink), pm_linearize_cleanups)
-  if SPEC: type_verify(lst, program_spec)
-  return lst
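With `full_rewrite` deleted, its two jobs live in different places: `get_program` in `engine.realize` when you want the rendered ProgramSpec, and the new `test.helpers.get_uops` when you only want the linearized list. If out-of-tree code still calls it, a one-line shim recovers the old behavior (a sketch, assuming a tree at this commit; `full_rewrite_compat` is a name invented here):

```python
from tinygrad import Device
from tinygrad.uop.ops import UOp
from tinygrad.renderer import Renderer
from tinygrad.engine.realize import get_program

def full_rewrite_compat(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
  # get_program runs the same full_rewrite_to_sink + linearize pipeline and
  # keeps the linearized list on the ProgramSpec it returns
  return get_program(sink, ren if ren is not None else Device.default.renderer).uops
```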