mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
pipelining wip
This commit is contained in:
@@ -17,7 +17,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
||||
ReduceContext, correct_load_store, pm_render
|
||||
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt, pm_add_local_buffers
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt, pm_add_local_buffers, pm_pipeline
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
|
||||
|
||||
@@ -90,6 +90,10 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
# remove reduce
|
||||
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
|
||||
|
||||
# pipelining
|
||||
ret.append(RewriteStep(pm_pipeline, name="pipeline"))
|
||||
ret.append(RewriteStep(sym, name="pipeline sym"))
|
||||
|
||||
# add gpu dims (late). this works after devectorize, but it's faster here
|
||||
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
|
||||
pass
|
||||
if good_tc_opt:
|
||||
# skip hand-coded TC opts if AMX, upcasting will make kernel slower
|
||||
if rngs is not None and not AMX:
|
||||
if rngs is not None and not AMX and False:
|
||||
for tc_dim in [1,0]: # attempt to upcast M and N
|
||||
szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]
|
||||
if szs:
|
||||
@@ -43,6 +43,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
|
||||
rngs[tc_dim] = tk.apply_opt(Opt(OptOps.UPCAST, tk.rngs.index(rngs[tc_dim]), szs[0]))[0]
|
||||
if (szs := [sz for sz in [4,2] if rngs[0].src[0].divides(sz) is not None]): # attempt to local N
|
||||
tk.apply_opt(Opt(OptOps.LOCAL, tk.rngs.index(rngs[0]), szs[0]))
|
||||
#tk.apply_opt(Opt(OptOps.UNROLL, 0, 2))
|
||||
return tk
|
||||
|
||||
# make a copy so it does not mutate the input
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import math, itertools
|
||||
import math, itertools, functools, operator
|
||||
from collections import defaultdict
|
||||
from typing import cast, Final
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
|
||||
@@ -359,4 +359,53 @@ def add_local_buffer(x:UOp):
|
||||
|
||||
# Rewrite pass: stage every LOAD through a local (shared-memory) buffer via
# add_local_buffer (defined above this hunk — body not visible here).
pm_add_local_buffers = PatternMatcher([
  (UPat(Ops.LOAD, name="x"), add_local_buffer),
])
|
||||
|
||||
def add_pipeline(x:UOp):
  """Split a REDUCE-axis RANGE into pipeline stages for software pipelining.

  Returns a SPLIT UOp whose sources are the stage index expressions, or None
  when the range is already tagged (tag == 1) or is not a reduce axis.
  NOTE(review): WIP — stage semantics (prologue/steady-state/epilogue) are
  presumed from the constants below; confirm against the SPLIT consumer.
  """
  # tag==1 marks ranges this pass already produced; skip to avoid re-splitting
  if x.tag == 1: return None
  if x.arg[-1] == AxisType.REDUCE:
    # 3 splits
    #srcs = (x.const_like(0), x.replace(src=(x.src[0]-2,), tag=1)+1, x.src[0]-1)

    # 4 split
    # rng iterates half of the interior (extent (n-2)//2); the four stages are
    # iteration 0, the two interleaved interior streams (2i+1, 2i+2), and n-1.
    rng = x.replace(src=((x.src[0]-2)//2,), tag=1)
    srcs = (x.const_like(0), rng*2+1, rng*2+2, x.src[0]-1)

    return UOp(Ops.SPLIT, x.dtype, src=srcs, arg=1).simplify()
    #vec = UOp(Ops.VECTORIZE, x.dtype.vec(3), src=(x.const_like(0), x.replace(src=(x.src[0]-2,), tag=1)+1, x.src[0]-1)).simplify()
    #return UOp(Ops.UNROLL, x.dtype, src=(vec,), arg=())
|
||||
|
||||
def do_split(x:UOp):
  """Propagate SPLIT upward: clone x once per SPLIT branch, substituting the
  branch's source in place of each SPLIT input.

  Returns None when no source is a SPLIT. A SINK simply absorbs its split's
  branches. A STORE fed by exactly two SPLITs gets its DEFINE_LOCAL
  double-buffered (alternating bank per branch) and is wrapped in a NOOP.
  Otherwise the clones are re-wrapped in a SPLIT so it keeps floating up.
  """
  split_srcs = [s for s in x.src if s.op is Ops.SPLIT]
  if not split_srcs: return None
  if x.op is Ops.SINK: return x.replace(src=x.src[0].src)
  #if x.op is Ops.REDUCE:
    #assert x.src[0].op is Ops.SPLIT
    #rr = [y for y in x.src[1].toposort() if y.op is Ops.RANGE][0]
    #return x.replace(src=(functools.reduce(operator.add, x.src[0].src), rr))
  branch_count = len(split_srcs[0].src)
  # one clone of x per branch, with every SPLIT source replaced by its i-th branch
  clones = [UOp(x.op, x.dtype, tuple(s.src[i] if s.op is Ops.SPLIT else s for s in x.src), x.arg, x.tag)
            for i in range(branch_count)]
  if x.op is Ops.STORE and len(split_srcs) == 2:
    local_buf = [u for u in clones[0].toposort() if u.op is Ops.DEFINE_LOCAL][0]
    # alternate the local buffer between two banks (i%2) across the branches
    banked = [c.substitute({local_buf:local_buf.replace(arg=(local_buf.arg, i%2))}) for i,c in enumerate(clones)]
    # NOTE: here we have to order the STORES and fix the ranges
    return UOp(Ops.NOOP, x.dtype, src=tuple(banked))
  return UOp(Ops.SPLIT, x.dtype, src=tuple(clones))
|
||||
|
||||
# Software-pipelining pass: split reduce RANGEs into stages, then push the
# resulting SPLITs up through the graph.
pm_pipeline = PatternMatcher([
  # turn a reduce-axis RANGE into a staged SPLIT (no-op for other ranges)
  (UPat(Ops.RANGE, name="x"), add_pipeline),
  # do expansion
  (UPat(GroupOp.All, name="x", custom_early_reject=set([Ops.SPLIT])), do_split),
  #(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(Ops.CONST)), name="x"), lambda x: x.replace(src=x.src[:2])),
  # replace a STORE's third source with the RANGEs found inside it, keeping the
  # stores ordered by the ranges they depend on
  (UPat(Ops.STORE, src=(UPat.var('x'), UPat.var('z'), UPat.var('rr')), name='r'),
   lambda x,rr,r,z: r.replace(src=(x,z)+tuple([y for y in rr.toposort() if y.op is Ops.RANGE]))),
])
|
||||
@@ -21,6 +21,7 @@ class Ops(FastEnum):
|
||||
# create buffer
|
||||
BUFFERIZE = auto()
|
||||
SUBSTITUTE = auto()
|
||||
SPLIT = auto()
|
||||
|
||||
# ops that adjust the behavior of the scheduler
|
||||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
|
||||
|
||||
Reference in New Issue
Block a user