pipelining wip

This commit is contained in:
George Hotz
2025-10-06 12:31:20 +08:00
parent df1b379a36
commit fdc0489e18
4 changed files with 58 additions and 3 deletions

View File

@@ -17,7 +17,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
from tinygrad.codegen.opt.postrange import pm_postrange_opt, pm_add_local_buffers
from tinygrad.codegen.opt.postrange import pm_postrange_opt, pm_add_local_buffers, pm_pipeline
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
@@ -90,6 +90,10 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
# remove reduce
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
# pipelining
ret.append(RewriteStep(pm_pipeline, name="pipeline"))
ret.append(RewriteStep(sym, name="pipeline sym"))
# add gpu dims (late). this works after devectorize, but it's faster here
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))

View File

@@ -35,7 +35,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
pass
if good_tc_opt:
# skip hand-coded TC opts if AMX, upcasting will make kernel slower
if rngs is not None and not AMX:
if rngs is not None and not AMX and False:
for tc_dim in [1,0]: # attempt to upcast M and N
szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]
if szs:
@@ -43,6 +43,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
rngs[tc_dim] = tk.apply_opt(Opt(OptOps.UPCAST, tk.rngs.index(rngs[tc_dim]), szs[0]))[0]
if (szs := [sz for sz in [4,2] if rngs[0].src[0].divides(sz) is not None]): # attempt to local N
tk.apply_opt(Opt(OptOps.LOCAL, tk.rngs.index(rngs[0]), szs[0]))
#tk.apply_opt(Opt(OptOps.UNROLL, 0, 2))
return tk
# make a copy so it does not mutate the input

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import math, itertools
import math, itertools, functools, operator
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
@@ -359,4 +359,53 @@ def add_local_buffer(x:UOp):
# Rewrite every LOAD through add_local_buffer (defined above) — presumably to
# stage global loads through local/shared memory; see add_local_buffer for details.
pm_add_local_buffers = PatternMatcher([(UPat(Ops.LOAD, name="x"), add_local_buffer)])
def add_pipeline(x:UOp):
  """Rewrite a REDUCE-axis RANGE into a 4-way SPLIT for software pipelining.

  No rewrite (None) when the range is already tagged as split (tag == 1) or
  when the last axis type in arg is not REDUCE.
  """
  if x.tag == 1: return None
  if x.arg[-1] != AxisType.REDUCE: return None
  # 3-way split alternative, kept for reference:
  #   srcs = (x.const_like(0), x.replace(src=(x.src[0]-2,), tag=1)+1, x.src[0]-1)
  # 4-way split: peel iteration 0 and the final iteration, and cover the
  # interior with two interleaved copies driven by one tagged half-length range
  half = x.replace(src=((x.src[0]-2)//2,), tag=1)
  pieces = (x.const_like(0), half*2+1, half*2+2, x.src[0]-1)
  return UOp(Ops.SPLIT, x.dtype, src=pieces, arg=1).simplify()
  # earlier UNROLL-based formulation, kept for reference:
  #   vec = UOp(Ops.VECTORIZE, x.dtype.vec(3), src=(x.const_like(0), x.replace(src=(x.src[0]-2,), tag=1)+1, x.src[0]-1)).simplify()
  #   return UOp(Ops.UNROLL, x.dtype, src=(vec,), arg=())
def do_split(x:UOp):
  """Distribute SPLIT sources across an op: emit one copy of x per split branch.

  Returns None when x has no SPLIT sources (no rewrite). A SINK collapses its
  single SPLIT source into its own srcs; a STORE fed by exactly two SPLITs gets
  its DEFINE_LOCAL rewritten per-branch (arg gains i%2 — presumably
  double-buffering, with even/odd branches alternating between two copies of
  the local; confirm against the renderer) and is wrapped in a NOOP; every
  other op is re-wrapped in a new SPLIT of its per-branch copies.
  """
  # NOTE: comprehension variables renamed — the originals shadowed the parameter x
  splits = [s for s in x.src if s.op is Ops.SPLIT]
  if not splits: return None
  if x.op is Ops.SINK: return x.replace(src=x.src[0].src)
  # earlier REDUCE handling, kept for reference:
  #   assert x.src[0].op is Ops.SPLIT
  #   rr = [y for y in x.src[1].toposort() if y.op is Ops.RANGE][0]
  #   return x.replace(src=(functools.reduce(operator.add, x.src[0].src), rr))
  # build one copy of x per branch, selecting branch i from every SPLIT source
  # (assumes all SPLIT sources have the same number of branches — TODO confirm)
  uu = [UOp(x.op, x.dtype, tuple(s.src[i] if s.op is Ops.SPLIT else s for s in x.src), x.arg, x.tag)
        for i in range(len(splits[0].src))]
  if x.op is Ops.STORE and len(splits) == 2:
    dl = [u for u in uu[0].toposort() if u.op is Ops.DEFINE_LOCAL][0]
    uu2 = [u.substitute({dl:dl.replace(arg=(dl.arg, i%2))}) for i,u in enumerate(uu)]
    # NOTE: here we have to order the STORES and fix the ranges
    return UOp(Ops.NOOP, x.dtype, src=tuple(uu2))
  return UOp(Ops.SPLIT, x.dtype, src=tuple(uu))
def _store_flatten_ranges(x:UOp, z:UOp, rr:UOp, r:UOp):
  """Replace a 3-src STORE's third source with the bare RANGEs found inside it.

  Parameter names must match the UPat names below (PatternMatcher binds by name).
  """
  return r.replace(src=(x, z)+tuple(y for y in rr.toposort() if y.op is Ops.RANGE))

pm_pipeline = PatternMatcher([
  # turn REDUCE ranges into SPLITs
  (UPat(Ops.RANGE, name="x"), add_pipeline),
  # do expansion: push SPLITs up through every consumer (set literal, was set([...]))
  (UPat(GroupOp.All, name="x", custom_early_reject={Ops.SPLIT}), do_split),
  #(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(Ops.CONST)), name="x"), lambda x: x.replace(src=x.src[:2])),
  (UPat(Ops.STORE, src=(UPat.var('x'), UPat.var('z'), UPat.var('rr')), name='r'), _store_flatten_ranges),
])

View File

@@ -21,6 +21,7 @@ class Ops(FastEnum):
# create buffer
BUFFERIZE = auto()
SUBSTITUTE = auto()
SPLIT = auto()
# ops that adjust the behavior of the scheduler
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702