postrange boilerplate work (#11881)

This commit is contained in:
George Hotz
2025-08-27 15:22:59 -07:00
committed by GitHub
parent fd579433bc
commit cb5295168d
4 changed files with 30 additions and 13 deletions

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, RANGEIFY, POSTOPT
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp
from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer
@@ -18,6 +18,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
from tinygrad.codegen.opt.postrange import pm_postrange_opt
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
@dataclass
@@ -45,10 +46,10 @@ rewrites_for_linearizer = [
def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
# cache with the values of the context vars
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value, RANGEIFY.value, POSTOPT.value)
@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL, _RANGEIFY, _POSTOPT) -> list[RewriteStep]:
# ** lowerer (rewrite_shapetracker_with_index) **
ret: list[RewriteStep] = []
@@ -56,12 +57,14 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
ret.extend(rewrites_for_views)
# this is kernel.py
ret.append(RewriteStep(pm_get_optimization, ctx=lambda _: opts, name="get optimization"))
ret.append(RewriteStep(pm_do_optimize, ctx=lambda _: opts, name="optimize ast"))
if not _RANGEIFY: ret.append(RewriteStep(pm_get_optimization, ctx=lambda _: opts, name="get optimization"))
if not _POSTOPT and not _RANGEIFY: ret.append(RewriteStep(pm_do_optimize, ctx=lambda _: opts, name="optimize ast"))
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
if _POSTOPT or _RANGEIFY: ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
# ** expander (expand_rewrite) **
ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))

View File

@@ -0,0 +1,18 @@
from dataclasses import replace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo
from tinygrad.helpers import colored
from tinygrad.codegen.opt.kernel import axis_colors
def rename_sink(s:UOp):
if s.arg is not None and s.arg.name != "test": return None
# get all ranges (sorted)
rngs = sorted([u for u in s.parents if u.op is Ops.RANGE], key=lambda x: x.arg[0:-1])
# add name to kernel
name = "k" + colored('_', 'BLACK').join(['']+[colored(x.src[0].render(), axis_colors[x.arg[1]]) for x in rngs])
return s.replace(arg=KernelInfo(name=name) if s.arg is None else replace(s.arg, name=name))
pm_postrange_opt = PatternMatcher([
(UPat(Ops.SINK, name="s"), rename_sink),
])

View File

@@ -140,7 +140,7 @@ DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0),
QUANTIZE, VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)
RANGEIFY = ContextVar("RANGEIFY", 0)
RANGEIFY, POSTOPT = ContextVar("RANGEIFY", 0), ContextVar("POSTOPT", 0)
@dataclass(frozen=True)
class Metadata:

View File

@@ -2,11 +2,11 @@ from typing import Any
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, RANGEIFY
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, KernelInfo, identity_element, sint, AxisType
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, identity_element, sint, AxisType
# 0. do some cleanup rewrites, mostly copied from the old stuff
@@ -415,12 +415,8 @@ def split_store(x:UOp):
ctx = LocalAddBufferContext()
ret = graph_rewrite(x, to_define_global+rangeify_codegen, ctx=ctx, name="kernel split", bottom_up=True)
# get name
rng = sorted([u for u in ret.toposort() if u.op is Ops.RANGE], key=lambda x: x.arg)
name = "k"+colored('_', 'BLACK').join(['']+[colored(s.src[0].render(), "WHITE" if s in ret.src[2:] else "red") for s in rng])
# NOTE: the hack for COPY is here
ret = ret.sink(arg=KernelInfo(name=name)) if ret.src[1].op is not Ops.COPY else ret.src[1]
ret = ret.sink() if ret.src[1].op is not Ops.COPY else ret.src[1]
kernel = UOp(Ops.KERNEL, src=tuple(ctx.map.values())+tuple(ctx.vars.keys()), arg=Kernel(ret,()))
return x.as_buf().assign(kernel)