more scheduler graph_rewrite cleanups [run_process_replay] (#6479)
@@ -6,7 +6,7 @@ from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOp
 from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
-  GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
+  GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
 from tinygrad.lazy import LazyBuffer
@@ -90,6 +90,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
   return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
 
+
+# ** AST graph rewrite: UOp with SWIZZLE (movementops) -> UOp we can index **
+
 # ***** helpers for doing movementops on uops *****
 
 def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
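Note: the new section header names the pass that the rules below implement. As a standalone sketch of the general idea it relies on (simplified stand-in classes, not tinygrad's actual UOp/PatternMatcher), a rewrite loop matches rules against each node bottom-up and re-applies until nothing changes, which is what `graph_rewrite(sink, reduceop_fusor)` does later in this file:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Node:
  op: str
  src: Tuple["Node", ...] = ()

def rewrite(n: Node, rules) -> Node:
  n = Node(n.op, tuple(rewrite(s, rules) for s in n.src))  # rewrite children first
  for match, fn in rules:
    if match(n): return rewrite(fn(n), rules)  # re-run until a fixed point
  return n

# one rule: collapse REDUCE(REDUCE(x)) into a single REDUCE(x)
rules = [(lambda n: n.op == "REDUCE" and n.src and n.src[0].op == "REDUCE",
          lambda n: Node("REDUCE", n.src[0].src))]

g = Node("REDUCE", (Node("REDUCE", (Node("LOAD"),)),))
assert rewrite(g, rules) == Node("REDUCE", (Node("LOAD"),))
```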
@@ -131,12 +134,6 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
   assert prod(swizzle.arg.shape) == prod(unwrap(swizzle.src[0].st).shape), "can't push expands down to STORE"
   return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg).swizzle(ShapeTracker.from_shape(unwrap(swizzle.st).reduce(root.arg[1])))
 
-def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
-  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
-  assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
-  new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1]
-  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis))
-
 def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
   swizzles = [x for x in root.src if x.op is UOps.SWIZZLE]
   if len(swizzles) == 0: return None
@@ -147,6 +144,12 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
   ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
   return ret if ret.op is UOps.STORE else ret.swizzle(ShapeTracker.from_shape(sw_shape))
 
+def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
+  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
+  assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
+  new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1]
+  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis))
+
 reduceop_fusor = PatternMatcher([
   # push a SWIZZLE up to LOAD, through a reduce (eg. expands)
   (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
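Note: merge_double_reduce is moved below push_swizzle_down_through_elementwise, not deleted; the body is unchanged. The rewrite it performs is sound because two nested reductions with the same ALU op equal one reduction over the concatenated axis tuple when rank is preserved, which is how REDUCE_AXIS behaves. A quick numpy check of that identity (numpy used only for illustration):

```python
import numpy as np

a = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

# chained rank-preserving reduces (REDUCE_AXIS keeps reduced dims as size 1)
two_step = a.sum(axis=2, keepdims=True).sum(axis=0, keepdims=True)
# one reduce over the merged axis tuple, as merge_double_reduce builds
one_step = a.sum(axis=(0, 2), keepdims=True)

assert np.array_equal(two_step, one_step)
```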
@@ -157,15 +160,15 @@ reduceop_fusor = PatternMatcher([
   (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> List[LBScheduleItem]:
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
     st_uop = ShapeTracker.from_shape(out.arg).to_uop()
     rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), st_uop))
     wr = UOp(UOps.STORE, dtypes.void, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), st_uop, rd))
-    return [LBScheduleItem(UOp(UOps.SINK, dtypes.void, (wr,)), outs, [x.base for x in out.srcs])]
+    return LBScheduleItem(UOp(UOps.SINK, dtypes.void, (wr,)), outs, [x.base for x in out.srcs])
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
-    return [LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])]
+    return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
   # create the stores
   var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
@@ -173,18 +176,17 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
   ast: List[UOp] = []
   inputs: Dict[LazyBuffer, int] = {}
   for i, out in enumerate(outs):
-    output_st = ShapeTracker.from_shape(out.shape)
-    src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
+    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
-      output_st = out.arg[0].reshape(out.shape)
+      output_st = out.arg[0]
     output_st, vv = output_st.simplify().unbind()
-    if vv: var_vals.update(vv)
+    var_vals.update(vv)
     ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
     ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
   sink = UOp(UOps.SINK, dtypes.void, tuple(ast))
   if AST_REWRITE: sink = graph_rewrite(sink, reduceop_fusor)
-  return [LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))]
+  return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
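Note: two micro-cleanups in this hunk: the walrus operator folds the output_st binding into the _recursive_uop call, and the `if vv:` guard goes away because updating a dict with an empty mapping is already a no-op. Both are plain Python behavior:

```python
def consume(x): return x

# := binds the name in the enclosing scope while the call consumes it
result = consume(intermediate := [1, 2, 3])
assert intermediate is result

# updating with an empty dict changes nothing, so the guard was redundant
var_vals = {"n": 4}
var_vals.update({})
assert var_vals == {"n": 4}
```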
@@ -363,7 +365,7 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
   """create a graph for realizing the outputs"""
   output_groups, realizes, assign_targets = _get_output_groups(outs)
   # preschedule all buffers in realizes
-  prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()])
+  prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
 
   graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
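Note: with _lower_lazybuffer now returning one LBScheduleItem per output group instead of a one-element list, the comprehension here is already flat, so flatten loses its last call site in this file and is dropped from the helpers import in the first hunk. The shape of the change, with hypothetical stand-in functions:

```python
from typing import List

def flatten(l):  # same behavior as the dropped helper: concatenate sublists
  return [item for sub in l for item in sub]

def lower_old(group) -> List[str]: return [f"lsi({group})"]  # old: wrapped in a list
def lower_new(group) -> str: return f"lsi({group})"          # new: the bare item

groups = ["g0", "g1"]
assert flatten([lower_old(g) for g in groups]) == [lower_new(g) for g in groups]
```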