start graph rewrite in the scheduler (#6248)

* start graph rewrite in the scheduler

* test: enable it

* test timings

* only fails in multi reduce

* more isolated tests

commit 78d6bd8b41 (parent 75700edf73)
Author: qazal
Date: 2024-08-23 18:15:55 +08:00
Committed by: GitHub

2 changed files with 98 additions and 20 deletions

diff --git a/test/test_schedule.py b/test/test_schedule.py

@@ -3,15 +3,17 @@
 # NOTE: this has overlap with external_test_opt.py
 import unittest
+import time
 import numpy as np
-from typing import List, Optional, Union, cast
+from typing import Dict, List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps
+from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel, verify_ast
-from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.schedule import create_schedule, get_output_st, reshape_uop
 from tinygrad.engine.realize import run_schedule
 from test.helpers import is_dtype_supported, Context
 from tinygrad.lazy import LazyBuffer, view_supported_devices
@@ -1604,5 +1606,41 @@ class TestIndexing(unittest.TestCase):
     ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
     np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)
 
+class TestScheduleRewrite(unittest.TestCase):
+  def test_recursive_get_output_st(self):
+    start = time.perf_counter()
+    a = Tensor([1,2,3,4]).realize()
+    for _ in range(24): a = a + a
+    ast = a.schedule()[0].ast
+    st = get_output_st(ast.src[0].src[2], {})
+    self.assertEqual(st, ShapeTracker.from_shape((4,)))
+    self.assertLess(time.perf_counter()-start, 1.0)
+
+  def test_recursive_reshape(self):
+    start = time.perf_counter()
+    a = Tensor([1,2,3,4]).realize()
+    for _ in range(24): a = a + a
+    ast = a.schedule()[0].ast
+    new_uop = reshape_uop(ast.src[0].src[2], (4, 1), {}, {})
+    self.assertEqual(get_output_st(new_uop, {}), ShapeTracker.from_shape((4,)).reshape((4, 1)))
+    self.assertLess(time.perf_counter()-start, 1.0)
+
+  def test_uop_sts_reshape(self):
+    uop_sts: Dict[UOp, ShapeTracker] = {}
+    a = Tensor([1,2,3,4]).realize()+2
+    ast = a.schedule()[0].ast
+    val = ast.src[0].src[2]
+    ret = get_output_st(val, uop_sts)
+    assert uop_sts[val] == ret == ShapeTracker.from_shape((4,))
+    new_val = reshape_uop(val, (4, 1), uop_sts, {})
+    self.assertNotIn(new_val, uop_sts)
+
+  def test_reshape_noop(self):
+    a = Tensor([1,2,3,4]).realize()+2
+    ast = a.schedule()[0].ast
+    val = ast.src[0].src[2]
+    new_val = reshape_uop(val, (4,), {}, {})
+    self.assertIs(new_val, val)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
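
For orientation, a minimal sketch (not part of the diff) of the AST layout these tests index into: a lowered schedule item's ast is a SINK whose sources are STOREs of the form (DEFINE_GLOBAL, output view, value), so ast.src[0].src[2] above is the value UOp handed to get_output_st and reshape_uop.

# illustration only, mirroring the tests above
from tinygrad import Tensor
from tinygrad.ops import UOps

a = Tensor([1, 2, 3, 4]).realize() + 2
ast = a.schedule()[0].ast   # UOps.SINK
store = ast.src[0]          # UOps.STORE: (DEFINE_GLOBAL, output view, value)
val = store.src[2]          # the elementwise ADD the tests pass to get_output_st/reshape_uop
assert ast.op is UOps.SINK and store.op is UOps.STORE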

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py

@@ -1,11 +1,11 @@
 import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
-from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
+from tinygrad.ops import BUFFER_UOPS, MetaOps, PatternMatcher, ReduceOps, UNSAFE_PAD_OPS, UPat, UnaryOps, UOp, UOps, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
-  GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
+  GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
 from tinygrad.lazy import LazyBuffer
@@ -18,6 +18,8 @@ sys.setrecursionlimit(10000)
 
 # optionally log the ops to disk
 logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
+# use graph rewrite for reduceop fusion
+AST_REWRITE = getenv("AST_REWRITE", 0)
 
 # *** ScheduleItem return type ***
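
AST_REWRITE is a plain getenv read when tinygrad.engine.schedule is imported (one branch further down re-reads it via getenv), so the rewrite path is opted into through the environment. A minimal sketch of exercising it, assuming nothing beyond the API already used in this diff:

import os
os.environ["AST_REWRITE"] = "1"   # must be set before tinygrad.engine.schedule is imported

from tinygrad import Tensor

out = (Tensor([1, 2, 3, 4]) + 2).sum()
sched = out.schedule()            # _lower_lazybuffer now runs graph_rewrite(sink, reduceop_fusor)
print(sched[-1].ast)
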
@@ -81,7 +83,8 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
-    rinfo = reduce_info.get((buf, st))
+    rinfo: Optional[Tuple[ShapeTracker, Tuple[int, ...]]] = (ShapeTracker.from_shape(buf.srcs[0].shape), buf.arg) \
+      if AST_REWRITE else reduce_info.get((buf, st))
     rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
     # if we are merging the reduce, skip it
     if rinfo is None:
@@ -143,6 +146,38 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
     return (buf, st)
   return cache.setdefault((buf, st), top_reduce)
 
+# ***** helpers for doing movementops on uops *****
+
+def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> ShapeTracker:
+  if (st:=uop_sts.get(uop)): return st
+  if uop.op in BUFFER_UOPS: return uop.st_arg
+  src_sts = [get_output_st(x, uop_sts) for x in uop.src]
+  assert all_same([x.shape for x in src_sts]), f"inhomogeneous shape from\n{uop}\n{[x.shape for x in src_sts]}"
+  uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
+  return st
+
+def reshape_uop(u:UOp, new_shape:Tuple[sint, ...], uop_sts:Dict[UOp, ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
+  if (reshaped:=cache.get(u)): return reshaped
+  if (st:=uop_sts.get(u)) and st.shape == new_shape: return u
+  if u.op is UOps.SHAPETRACKER: return u if u.arg.shape == new_shape else replace(u, arg=u.arg.reshape(new_shape))
+  new_srcs = tuple(reshape_uop(x, new_shape, uop_sts, cache) for x in u.src)
+  cache[u] = reshaped = u if new_srcs == u.src else replace(u, src=new_srcs)
+  return reshaped
+
+# ***** reduceop fusor *****
+
+def push_reduceop_shape(root:UOp) -> Optional[UOp]:
+  reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
+  if len(reduceops) == 0: return None
+  uop_sts: Dict[UOp, ShapeTracker] = {}
+  rshape = get_output_st(reduceops[0], uop_sts).shape
+  if rshape == root.st_arg.shape: return None
+  return reshape_uop(root, rshape, uop_sts, {})
+
+reduceop_fusor = PatternMatcher([
+  (UPat(UOps.STORE, name="root"), push_reduceop_shape),
+])
+
 def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
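
To make the intent concrete: under AST_REWRITE, _recursive_uop no longer pre-pushes movementops through reduces, so a STORE can be left at a shape that disagrees with a REDUCE_AXIS upstream; reduceop_fusor fires on every STORE and, when the shapes differ, reshape_uop pushes the reduce's output shape back through the whole tree under that STORE. A minimal sketch applying the matcher by hand (illustration only; on the default path the shapes already agree, so the rewrite is a no-op):

from tinygrad import Tensor
from tinygrad.ops import graph_rewrite
from tinygrad.engine.schedule import reduceop_fusor

sink = Tensor([1, 2, 3, 4]).sum().schedule()[-1].ast   # SINK over a single STORE
rewritten = graph_rewrite(sink, reduceop_fusor)
# the default lowering already aligned the STORE to the reduce's output shape,
# so push_reduceop_shape returns None and the sink is left unchanged
assert rewritten.src[0].st_arg.shape == sink.src[0].st_arg.shape
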
@@ -152,18 +187,20 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
     return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
     return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
-  # push through all movementops between reduceops
   reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
-  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
-  for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
-  # pad all reduceops to the max of each dimension
-  shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
-  for i,dims in enumerate(shape_dims):
-    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
-    for (r,view),(input_st,axis) in reduce_info.items():
-      if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
-        input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
-        reduce_info[(r, view)] = (input_st, axis)
+  if not AST_REWRITE:
+    # push through all movementops between reduceops
+    # NOTE: AST_REWRITE does this with graph rewrite
+    seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
+    for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
+    # pad all reduceops to the max of each dimension
+    shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
+    for i,dims in enumerate(shape_dims):
+      if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
+      for (r,view),(input_st,axis) in reduce_info.items():
+        if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
+          input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
+          reduce_info[(r, view)] = (input_st, axis)
   # create the stores
   var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
@@ -171,16 +208,19 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
   ast: List[UOp] = []
   inputs: Dict[LazyBuffer, int] = {}
   for i, out in enumerate(outs):
-    output_st = ShapeTracker.from_shape(ShapeTracker.reduce(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
+    output_shape = ShapeTracker.reduce(*deque(reduce_info.values(), 1).pop()) if reduce_info and not getenv("AST_REWRITE") else out.shape
+    output_st = ShapeTracker.from_shape(output_shape)
     src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, reduce_info, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
-      output_st = out.arg[0].reshape(output_st.shape)
+      output_st = out.arg[0].reshape(output_shape)
     output_st, vv = output_st.simplify().unbind()
     if vv: var_vals.update(vv)
     ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
     ast.append(UOp(UOps.STORE, None, (ubuf, output_st.to_uop(), src)))
   sink = UOp(UOps.SINK, None, tuple(ast))
+  if AST_REWRITE:
+    sink = graph_rewrite(sink, reduceop_fusor)
   return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
 
 # *** DAG creation: decide which LazyBuffers should realize ***