mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
split ranges (#12411)
* split ranges * simpler * split ranges * range str * fix test * oops * faster * no group 2 * tests * dont_sub_ranges_for_image * revert that
This commit is contained in:
24
test/test_linearizer_failures.py
Normal file
24
test/test_linearizer_failures.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# ruff: noqa: E501
|
||||
import unittest
|
||||
from tinygrad.uop.ops import UOp, Ops, AxisType
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import get_program
|
||||
|
||||
class TestLinearizerFailures(unittest.TestCase):
|
||||
def test_fail_1(self):
|
||||
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0, src=())
|
||||
c1 = UOp.range(UOp.const(dtypes.index, 2), 1, AxisType.LOOP)
|
||||
c2 = UOp.range(UOp.const(dtypes.index, 32), 2, AxisType.LOOP)
|
||||
c3 = ((c1*UOp.const(dtypes.index, 32))+c2)
|
||||
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(163840), arg=1, src=())
|
||||
c5 = UOp.range(UOp.const(dtypes.index, 2560), 0, AxisType.REDUCE)
|
||||
c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920)))).load()
|
||||
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2, src=())
|
||||
c8 = c7.index(c3).load()
|
||||
c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
|
||||
c10 = c0.index(c3).store(c9, c1, c2)
|
||||
ast = c10.sink()
|
||||
get_program(ast)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -45,6 +45,16 @@ class TestRangeifyOpt(unittest.TestCase):
|
||||
x = conv1(x).pad([1,1,1,1])+1
|
||||
x.realize()
|
||||
|
||||
# CPU=1 NOOPT=1 DEBUG=4 RANGEIFY=1 python3 test/test_rangeify.py TestRangeifyOpt.test_matmul_reshaped
|
||||
def test_matmul_reshaped(self):
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
(A@B).reshape(N*N).contiguous().realize()
|
||||
|
||||
def test_reduce_reshapes(self):
|
||||
A = Tensor.empty(8,8,8,8).permute(1,0,3,2).flatten()
|
||||
A.sum().realize()
|
||||
|
||||
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
|
||||
class TestRangeify(unittest.TestCase):
|
||||
def test_groupnorm(self):
|
||||
|
||||
@@ -198,11 +198,11 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
"((((r0+r1)<1)!=True)&(((r2+r3)<1)!=True))")
|
||||
|
||||
def test_valid_with_non_const_rhs(self):
|
||||
ridx0 = Range(0, 2**16)
|
||||
ridx0 = Range(0, 1024)
|
||||
ridx1 = Range(1, 4)
|
||||
ridx2 = Range(2, 4)
|
||||
valid = (ridx0<(ridx1*4 + ridx2))&(ridx0<-1).ne(True)
|
||||
idx = ridx0%1024
|
||||
idx = ridx0
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"r0",
|
||||
|
||||
@@ -18,7 +18,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
||||
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
|
||||
|
||||
@dataclass
|
||||
@@ -62,6 +62,9 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
|
||||
# split ranges
|
||||
ret.append(RewriteStep(pm_split_ranges+pm_flatten_range, ctx=lambda _: {}, name="split ranges"))
|
||||
|
||||
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
|
||||
ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic"))
|
||||
|
||||
|
||||
@@ -78,13 +78,13 @@ class Scheduler:
|
||||
for ls in reduce_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
|
||||
"""
|
||||
|
||||
return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.LOOP] if store_rngs else []
|
||||
return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[-1] == AxisType.LOOP] if store_rngs else []
|
||||
|
||||
def convert_loop_to_global(self):
|
||||
if not self.opts.has_local: return None
|
||||
|
||||
globalizible_rngs = self._globalizable_rngs()
|
||||
rng = [x.replace(arg=(x.arg[0], AxisType.GLOBAL)) if x in globalizible_rngs else x for x in self.rngs]
|
||||
rng = [x.replace(arg=x.arg[0:-1]+(AxisType.GLOBAL,)) if x in globalizible_rngs else x for x in self.rngs]
|
||||
|
||||
self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType
|
||||
from tinygrad.uop.symbolic import symbolic_flat, sym, invalid_pat
|
||||
from tinygrad.helpers import partition
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -42,6 +42,29 @@ pm_simplify_ranges = PatternMatcher([
|
||||
(UPat((Ops.STORE, Ops.REDUCE), name="u"), simplify_merge_adjacent),
|
||||
])
|
||||
|
||||
def mark_range_mod(ctx, r:UOp, c:UOp):
|
||||
if r not in ctx and r.src[0].op is Ops.CONST and r.src[0].divides(c.arg) is not None: ctx[r] = c
|
||||
|
||||
def do_substitute(ctx, x: UOp):
|
||||
subs = {}
|
||||
for k,v in ctx.items():
|
||||
if v is not None:
|
||||
subs[k] = k.replace(src=(k.src[0]//v,), arg=k.arg[0:-1]+(0,k.arg[-1]))*v + k.replace(src=(v,), arg=k.arg[0:-1]+(1,k.arg[-1]))
|
||||
if not len(subs): return None
|
||||
ret = x.substitute(subs).simplify()
|
||||
ctx.clear()
|
||||
return ret
|
||||
|
||||
def dont_sub_ranges_for_image(ctx, x:UOp):
|
||||
if isinstance(x.src[0].dtype, ImageDType):
|
||||
for s in x.src[1:]: ctx[s] = None
|
||||
|
||||
pm_split_ranges = PatternMatcher([
|
||||
(UPat(Ops.RANGE, name="r")%UPat.cvar("c"), mark_range_mod),
|
||||
(UPat(Ops.STORE, name="x"), dont_sub_ranges_for_image),
|
||||
(UPat(Ops.SINK, name="x"), do_substitute),
|
||||
])
|
||||
|
||||
# **** reduce simplification ****
|
||||
|
||||
def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Literal, Callable, cast
|
||||
import os, math, sys
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.opt import tc
|
||||
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, sint_to_uop
|
||||
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, sint_to_uop, range_str
|
||||
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -158,7 +158,7 @@ class CStyleLanguage(Renderer):
|
||||
# naming
|
||||
prefix = None
|
||||
if u.op is Ops.SPECIAL: r[u] = u.arg
|
||||
elif u.op is Ops.RANGE: r[u] = "ridx"+'_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
|
||||
elif u.op is Ops.RANGE: r[u] = "ridx"+range_str(u)
|
||||
else:
|
||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.codegen.opt import tc
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import AMDRenderer
|
||||
from tinygrad.uop.decompositions import xexp2, xlog2
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, sint_to_uop
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, sint_to_uop, range_str
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
|
||||
from tinygrad.helpers import prod, AMX
|
||||
|
||||
@@ -102,14 +102,14 @@ base_rewrite = PatternMatcher([
|
||||
|
||||
# range
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x:
|
||||
f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
|
||||
f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{x.arg[0]} ], [ {ctx[x]}phi, %loop_latch_{x.arg[0]} ]"),
|
||||
f" br label %loop_entry_{range_str(x)}\nloop_entry_{range_str(x)}:\n"
|
||||
f" br label %loop_body_{range_str(x)}\nloop_body_{range_str(x)}:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{range_str(x)} ], [ {ctx[x]}phi, %loop_latch_{range_str(x)} ]"),
|
||||
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
|
||||
f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
|
||||
f" br label %loop_latch_{range_str(x.src[0])}\nloop_latch_{range_str(x.src[0])}:\n"
|
||||
f" {ctx[x.src[0]]}phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, 1\n"
|
||||
f" {ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
|
||||
f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
|
||||
f" br i1 {ctx[x]}, label %loop_body_{range_str(x.src[0])}, label %loop_exit_{range_str(x.src[0])}\nloop_exit_{range_str(x.src[0])}:"),
|
||||
|
||||
# if
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
|
||||
|
||||
@@ -44,6 +44,8 @@ def srender(x) -> str: return x.render() if isinstance(x, UOp) else str(x)
|
||||
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
||||
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
||||
|
||||
def range_str(u:UOp) -> str: return '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
|
||||
|
||||
# used for UOp and UPat
|
||||
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
|
||||
def dfs(x:Any, cache:dict):
|
||||
@@ -1111,7 +1113,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<"
|
||||
renderer = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
|
||||
(UPat((Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg)),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"r{x.arg[0]}" if x.arg[0] >= 0 else f"rm{-x.arg[0]}")),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"r{range_str(x)}")),
|
||||
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
|
||||
(UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),
|
||||
|
||||
@@ -165,8 +165,8 @@ spec = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_REG, src=()), lambda: True),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
|
||||
(UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) == 2 and \
|
||||
isinstance(rng.arg[0], int) and isinstance(rng.arg[1], AxisType)),
|
||||
(UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
|
||||
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
|
||||
|
||||
(UPat(Ops.VIEW, dtypes.void, src=(), name="x"), lambda x: isinstance(x.arg, ShapeTracker)),
|
||||
|
||||
@@ -7,7 +7,7 @@ from http.server import BaseHTTPRequestHandler
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Any, TypedDict, Generator
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -80,7 +80,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
|
||||
try:
|
||||
if len(rngs:=u.ranges):
|
||||
label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
|
||||
label += f"\n({','.join([colored(range_str(x), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
|
||||
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
|
||||
label += f"\n{shape_to_str(u.shape)}"
|
||||
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
|
||||
|
||||
Reference in New Issue
Block a user