mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
postrange boilerplate work (#11881)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Callable
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, RANGEIFY, POSTOPT
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp
|
||||
from tinygrad.uop.spec import type_verify
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -18,6 +18,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
||||
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
|
||||
|
||||
@dataclass
|
||||
@@ -45,10 +46,10 @@ rewrites_for_linearizer = [
|
||||
|
||||
def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]:
|
||||
# cache with the values of the context vars
|
||||
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
|
||||
return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value, RANGEIFY.value, POSTOPT.value)
|
||||
|
||||
@functools.cache
|
||||
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
|
||||
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL, _RANGEIFY, _POSTOPT) -> list[RewriteStep]:
|
||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
ret: list[RewriteStep] = []
|
||||
|
||||
@@ -56,12 +57,14 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
||||
ret.extend(rewrites_for_views)
|
||||
|
||||
# this is kernel.py
|
||||
ret.append(RewriteStep(pm_get_optimization, ctx=lambda _: opts, name="get optimization"))
|
||||
ret.append(RewriteStep(pm_do_optimize, ctx=lambda _: opts, name="optimize ast"))
|
||||
if not _RANGEIFY: ret.append(RewriteStep(pm_get_optimization, ctx=lambda _: opts, name="get optimization"))
|
||||
if not _POSTOPT and not _RANGEIFY: ret.append(RewriteStep(pm_do_optimize, ctx=lambda _: opts, name="optimize ast"))
|
||||
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
|
||||
if _POSTOPT or _RANGEIFY: ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
|
||||
|
||||
# ** expander (expand_rewrite) **
|
||||
ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
|
||||
|
||||
|
||||
18
tinygrad/codegen/opt/postrange.py
Normal file
18
tinygrad/codegen/opt/postrange.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from dataclasses import replace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo
|
||||
from tinygrad.helpers import colored
|
||||
from tinygrad.codegen.opt.kernel import axis_colors
|
||||
|
||||
def rename_sink(s:UOp):
  """Give a SINK a name built from the RANGE uops found in its parents.

  Returns None (presumably "no rewrite" for the PatternMatcher -- confirm) when the sink
  already has an arg whose name is something other than "test". Otherwise returns a copy of
  `s` whose arg carries the generated name: a new KernelInfo if `s.arg` was None, else the
  existing arg with only its name replaced.
  """
  info = s.arg
  # leave sinks alone once they carry a name other than "test"
  if info is not None and info.name != "test": return None

  # every RANGE among the parents, ordered by its arg with the last element dropped
  ranges = [u for u in s.parents if u.op is Ops.RANGE]
  ranges.sort(key=lambda r: r.arg[0:-1])

  # "k" followed by one "_<rendered src[0]>" piece per range, colored via axis_colors[arg[1]]
  sep = colored('_', 'BLACK')
  pieces = [colored(r.src[0].render(), axis_colors[r.arg[1]]) for r in ranges]
  name = "k" + sep.join([''] + pieces)

  if info is None: return s.replace(arg=KernelInfo(name=name))
  return s.replace(arg=replace(info, name=name))
|
||||
|
||||
# rules for the post-range optimization step; the only rule so far names the kernel on SINK
pm_postrange_opt = PatternMatcher([(UPat(Ops.SINK, name="s"), rename_sink)])
|
||||
@@ -140,7 +140,7 @@ DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0),
|
||||
QUANTIZE, VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
|
||||
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
|
||||
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)
|
||||
RANGEIFY = ContextVar("RANGEIFY", 0)
|
||||
RANGEIFY, POSTOPT = ContextVar("RANGEIFY", 0), ContextVar("POSTOPT", 0)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Metadata:
|
||||
|
||||
@@ -2,11 +2,11 @@ from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, RANGEIFY
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
|
||||
from tinygrad.schedule.kernelize import Kernel
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, KernelInfo, identity_element, sint, AxisType
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, identity_element, sint, AxisType
|
||||
|
||||
# 0. do some cleanup rewrites, mostly copied from the old stuff
|
||||
|
||||
@@ -415,12 +415,8 @@ def split_store(x:UOp):
|
||||
ctx = LocalAddBufferContext()
|
||||
ret = graph_rewrite(x, to_define_global+rangeify_codegen, ctx=ctx, name="kernel split", bottom_up=True)
|
||||
|
||||
# get name
|
||||
rng = sorted([u for u in ret.toposort() if u.op is Ops.RANGE], key=lambda x: x.arg)
|
||||
name = "k"+colored('_', 'BLACK').join(['']+[colored(s.src[0].render(), "WHITE" if s in ret.src[2:] else "red") for s in rng])
|
||||
|
||||
# NOTE: the hack for COPY is here
|
||||
ret = ret.sink(arg=KernelInfo(name=name)) if ret.src[1].op is not Ops.COPY else ret.src[1]
|
||||
ret = ret.sink() if ret.src[1].op is not Ops.COPY else ret.src[1]
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(ctx.map.values())+tuple(ctx.vars.keys()), arg=Kernel(ret,()))
|
||||
return x.as_buf().assign(kernel)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user