mirror of https://github.com/tinygrad/tinygrad.git
start graph rewrite in the scheduler (#6248)
* start graph rewrite in the scheduler
* test: enable it
* test timings
* only fails in multi reduce
* more isolated tests
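The diff below adds a PatternMatcher-driven graph rewrite to the scheduler, gated behind a new AST_REWRITE environment variable. As orientation, here is a minimal sketch of the rewrite idiom the patch relies on, assuming this revision of tinygrad is importable; the rule here is a hypothetical no-op, while the real rule (reduceop_fusor) appears in the scheduler hunk further down:

from typing import Optional
from tinygrad import Tensor
from tinygrad.ops import PatternMatcher, UPat, UOp, UOps, graph_rewrite

def report_store(root:UOp) -> Optional[UOp]:
  # returning None leaves the matched node alone; returning a UOp replaces it
  print("matched a STORE with", len(root.src), "sources")
  return None

demo_matcher = PatternMatcher([
  (UPat(UOps.STORE, name="root"), report_store),
])

# a kernel AST is a SINK of STOREs, so the rule gets a chance to run on each output STORE
sink = (Tensor([1, 2, 3, 4]).realize() + 2).schedule()[0].ast
sink = graph_rewrite(sink, demo_matcher)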
@@ -3,15 +3,17 @@
# NOTE: this has overlap with external_test_opt.py
import unittest
import time
import numpy as np
from typing import List, Optional, Union, cast
from typing import Dict, List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.schedule import create_schedule, get_output_st, reshape_uop
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported, Context
from tinygrad.lazy import LazyBuffer, view_supported_devices
@@ -1604,5 +1606,41 @@ class TestIndexing(unittest.TestCase):
ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)

class TestScheduleRewrite(unittest.TestCase):
def test_recursive_get_output_st(self):
start = time.perf_counter()
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
ast = a.schedule()[0].ast
st = get_output_st(ast.src[0].src[2], {})
self.assertEqual(st, ShapeTracker.from_shape((4,)))
self.assertLess(time.perf_counter()-start, 1.0)

def test_recursive_reshape(self):
start = time.perf_counter()
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
ast = a.schedule()[0].ast
new_uop = reshape_uop(ast.src[0].src[2], (4, 1), {}, {})
self.assertEqual(get_output_st(new_uop, {}), ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertLess(time.perf_counter()-start, 1.0)

def test_uop_sts_reshape(self):
uop_sts: Dict[UOp, ShapeTracker] = {}
a = Tensor([1,2,3,4]).realize()+2
ast = a.schedule()[0].ast
val = ast.src[0].src[2]
ret = get_output_st(val, uop_sts)
assert uop_sts[val] == ret == ShapeTracker.from_shape((4,))
new_val = reshape_uop(val, (4, 1), uop_sts, {})
self.assertNotIn(new_val, uop_sts)

def test_reshape_noop(self):
a = Tensor([1,2,3,4]).realize()+2
ast = a.schedule()[0].ast
val = ast.src[0].src[2]
new_val = reshape_uop(val, (4,), {}, {})
self.assertIs(new_val, val)

if __name__ == '__main__':
unittest.main(verbosity=2)
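The ast.src[0].src[2] indexing used throughout TestScheduleRewrite follows from how the scheduler builds its ASTs (see the last hunk below, where each STORE is created as UOp(UOps.STORE, None, (ubuf, output_st.to_uop(), src))). A short sketch of that navigation, under the same assumptions as the tests above:

from tinygrad import Tensor
from tinygrad.ops import UOps

a = Tensor([1, 2, 3, 4]).realize() + 2
ast = a.schedule()[0].ast   # the kernel AST is a SINK of STOREs
store = ast.src[0]          # one STORE per kernel output
assert ast.op is UOps.SINK and store.op is UOps.STORE
val = store.src[2]          # STORE srcs are (DEFINE_GLOBAL, SHAPETRACKER, value)
print(val.op, val.dtype)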
@@ -1,11 +1,11 @@
import sys, pickle, atexit, importlib, contextlib
from collections import defaultdict, deque
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
from tinygrad.ops import BUFFER_UOPS, MetaOps, PatternMatcher, ReduceOps, UNSAFE_PAD_OPS, UPat, UnaryOps, UOp, UOps, graph_rewrite
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
from tinygrad.lazy import LazyBuffer
@@ -18,6 +18,8 @@ sys.setrecursionlimit(10000)

# optionally log the ops to disk
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
# use graph rewrite for reduceop fusion
AST_REWRITE = getenv("AST_REWRITE", 0)
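AST_REWRITE is read once with getenv() when this module is imported, so it has to be present in the environment before the first tinygrad import; the usual way is setting AST_REWRITE=1 on the command line. A small sketch of toggling it from Python instead (the assert is only illustrative):

import os
os.environ["AST_REWRITE"] = "1"  # must be set before tinygrad is imported

from tinygrad import Tensor       # pulls in tinygrad.engine.schedule with the flag on
from tinygrad.engine.schedule import AST_REWRITE
assert AST_REWRITE == 1           # the scheduler now uses graph_rewrite for reduceop fusion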

# *** ScheduleItem return type ***
@@ -81,7 +83,8 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..

# reduce ops change ShapeTracker
if buf.op in ReduceOps:
rinfo = reduce_info.get((buf, st))
rinfo: Optional[Tuple[ShapeTracker, Tuple[int, ...]]] = (ShapeTracker.from_shape(buf.srcs[0].shape), buf.arg) \
if AST_REWRITE else reduce_info.get((buf, st))
rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
# if we are merging the reduce, skip it
if rinfo is None:
@@ -143,6 +146,38 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
return (buf, st)
return cache.setdefault((buf, st), top_reduce)

# ***** helpers for doing movementops on uops *****

def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> ShapeTracker:
if (st:=uop_sts.get(uop)): return st
if uop.op in BUFFER_UOPS: return uop.st_arg
src_sts = [get_output_st(x, uop_sts) for x in uop.src]
assert all_same([x.shape for x in src_sts]), f"inhomogeneous shape from\n{uop}\n{[x.shape for x in src_sts]}"
uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
return st

def reshape_uop(u:UOp, new_shape:Tuple[sint, ...], uop_sts:Dict[UOp, ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
if (reshaped:=cache.get(u)): return reshaped
if (st:=uop_sts.get(u)) and st.shape == new_shape: return u
if u.op is UOps.SHAPETRACKER: return u if u.arg.shape == new_shape else replace(u, arg=u.arg.reshape(new_shape))
new_srcs = tuple(reshape_uop(x, new_shape, uop_sts, cache) for x in u.src)
cache[u] = reshaped = u if new_srcs == u.src else replace(u, src=new_srcs)
return reshaped

# ***** reduceop fusor *****

def push_reduceop_shape(root:UOp) -> Optional[UOp]:
reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
if len(reduceops) == 0: return None
uop_sts: Dict[UOp, ShapeTracker] = {}
rshape = get_output_st(reduceops[0], uop_sts).shape
if rshape == root.st_arg.shape: return None
return reshape_uop(root, rshape, uop_sts, {})

reduceop_fusor = PatternMatcher([
(UPat(UOps.STORE, name="root"), push_reduceop_shape),
])
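For orientation, a small sketch of how these helpers compose on a real scheduled AST, mirroring the new tests in test_schedule.py above; the (2, 2) target shape is just an illustrative choice. reshape_uop rebuilds the tree with the reshape pushed down into the SHAPETRACKER leaves, and get_output_st then reports the new shape:

from tinygrad import Tensor
from tinygrad.engine.schedule import get_output_st, reshape_uop

val = (Tensor([1, 2, 3, 4]).realize() + 2).schedule()[0].ast.src[0].src[2]
print(get_output_st(val, {}).shape)          # (4,)
reshaped = reshape_uop(val, (2, 2), {}, {})  # pushes the reshape to the SHAPETRACKER leaves
print(get_output_st(reshaped, {}).shape)     # (2, 2)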

def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
@@ -152,18 +187,20 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
# push through all movementops between reduceops
reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
# pad all reduceops to the max of each dimension
shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
for i,dims in enumerate(shape_dims):
if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
for (r,view),(input_st,axis) in reduce_info.items():
if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
reduce_info[(r, view)] = (input_st, axis)
if not AST_REWRITE:
# push through all movementops between reduceops
# NOTE: AST_REWRITE does this with graph rewrite
seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
# pad all reduceops to the max of each dimension
shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
for i,dims in enumerate(shape_dims):
if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
for (r,view),(input_st,axis) in reduce_info.items():
if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
reduce_info[(r, view)] = (input_st, axis)
# create the stores
var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
@@ -171,16 +208,19 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
ast: List[UOp] = []
inputs: Dict[LazyBuffer, int] = {}
for i, out in enumerate(outs):
output_st = ShapeTracker.from_shape(ShapeTracker.reduce(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
output_shape = ShapeTracker.reduce(*deque(reduce_info.values(), 1).pop()) if reduce_info and not getenv("AST_REWRITE") else out.shape
output_st = ShapeTracker.from_shape(output_shape)
src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, reduce_info, cache=cache)
if out.op is MetaOps.ASSIGN and out.arg:
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
output_st = out.arg[0].reshape(output_st.shape)
output_st = out.arg[0].reshape(output_shape)
output_st, vv = output_st.simplify().unbind()
if vv: var_vals.update(vv)
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
ast.append(UOp(UOps.STORE, None, (ubuf, output_st.to_uop(), src)))
sink = UOp(UOps.SINK, None, tuple(ast))
if AST_REWRITE:
sink = graph_rewrite(sink, reduceop_fusor)
return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))

# *** DAG creation: decide which LazyBuffers should realize ***
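Finally, a hedged end-to-end sketch of what the gated path is intended to guarantee: with AST_REWRITE=1 the STORE tree is first built at the reduce input shape and reduceop_fusor then reshapes it back, so the resulting kernel AST should come out the same as on the default reduce_info path:

import os
os.environ["AST_REWRITE"] = "1"  # opt in to the graph-rewrite path before importing tinygrad

from tinygrad import Tensor

out = Tensor.ones(4, 4).contiguous().realize().sum(axis=1)
store = out.schedule()[0].ast.src[0]
# the STORE's ShapeTracker keeps the reduced axis as a 1, same as without the flag
print(store.src[1].arg.shape)    # expected: (4, 1)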