share REDUCE_ALU in multi and schedule [run_process_replay] (#6266)

qazal authored 2024-08-24 21:16:38 +08:00, committed by GitHub
parent 1dc6040877
commit 1b4ad982e5
3 changed files with 6 additions and 4 deletions

tinygrad/engine/schedule.py

@@ -2,7 +2,7 @@ import sys, pickle, atexit, importlib, contextlib
from collections import defaultdict, deque
from dataclasses import dataclass, field, replace
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
-from tinygrad.ops import BUFFER_UOPS, BinaryOps, MetaOps, PatternMatcher, ReduceOps, UNSAFE_PAD_OPS, UPat, UnaryOps, UOp, UOps, graph_rewrite
+from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, PatternMatcher, ReduceOps, UNSAFE_PAD_OPS, UPat, UnaryOps, UOp, UOps, graph_rewrite
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
@@ -86,7 +86,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
rinfo: Optional[Tuple[ShapeTracker, Tuple[int, ...]]] = (ShapeTracker.from_shape(buf.srcs[0].shape), buf.arg) \
if AST_REWRITE else reduce_info.get((buf, st))
rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
-  alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, buf.op)]
+  alu_op = REDUCE_ALU[cast(ReduceOps, buf.op)]
# if we are merging the reduce, skip it
if rinfo is None:
assert rsrc.op is UOps.REDUCE_AXIS and rsrc.arg[0] is alu_op, f"can't merge reduceop {buf.op} with {rsrc}\n{st}"
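
For reference, this hunk replaces an inline dict literal with a lookup into the table the commit adds to tinygrad/ops.py. A minimal self-contained sketch of that lookup (the enums here are stand-ins with only the members this diff uses; the real definitions in tinygrad.ops differ):

from enum import Enum, auto

# stand-in enums, assumption: only the members the diff touches
class ReduceOps(Enum):
  SUM = auto(); PROD = auto(); MAX = auto()
class BinaryOps(Enum):
  ADD = auto(); MUL = auto(); MAX = auto()

# the shared table this commit adds to tinygrad/ops.py
REDUCE_ALU = {ReduceOps.SUM: BinaryOps.ADD, ReduceOps.PROD: BinaryOps.MUL, ReduceOps.MAX: BinaryOps.MAX}

# what the scheduler now does instead of building the dict inline
alu_op = REDUCE_ALU[ReduceOps.SUM]
assert alu_op is BinaryOps.ADD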

tinygrad/multi.py

@@ -3,14 +3,14 @@ from typing import Optional, Union, Any, Tuple, List, Dict
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
from tinygrad.dtype import DType, ConstType
-from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, TernaryOps, ReduceOps
+from tinygrad.ops import REDUCE_ALU, BinaryOps, MetaOps, UnaryOps, TernaryOps, ReduceOps
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import sint

def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
-  bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}[op]
+  bop = REDUCE_ALU[op]
n_lbs, dim = len(lbs), prod(lbs[0].shape)
# Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
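
To illustrate what all_reduce computes with the shared table, a toy non-ring variant (plain lists stand in for LazyBuffer shards; APPLY and naive_all_reduce are illustrative names, not tinygrad API, and the ring scheduling and the 256k heuristic are ignored):

import functools
from enum import Enum, auto

# stand-in enums matching the mapping in this commit
class ReduceOps(Enum):
  SUM = auto(); PROD = auto(); MAX = auto()
class BinaryOps(Enum):
  ADD = auto(); MUL = auto(); MAX = auto()

REDUCE_ALU = {ReduceOps.SUM: BinaryOps.ADD, ReduceOps.PROD: BinaryOps.MUL, ReduceOps.MAX: BinaryOps.MAX}

# toy elementwise semantics for each ALU op (illustration only)
APPLY = {BinaryOps.ADD: lambda a, b: a + b, BinaryOps.MUL: lambda a, b: a * b, BinaryOps.MAX: max}

def naive_all_reduce(op, shards):
  bop = REDUCE_ALU[op]  # the same lookup the real all_reduce now performs
  reduced = [functools.reduce(APPLY[bop], vals) for vals in zip(*shards)]
  # every rank receives a copy of the fully reduced result
  return [reduced[:] for _ in shards]

print(naive_all_reduce(ReduceOps.MAX, [[1, 5], [4, 2]]))  # [[4, 5], [4, 5]]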

tinygrad/ops.py

@@ -33,6 +33,8 @@ Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
# do not preserve f(0) = 0
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
+REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}
# the order of these UOps controls the order of the toposort
class UOps(Enum):
# ops that aren't rendered
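
Aside: the comment above the class notes that the definition order of the UOps enum drives toposort order. A toy illustration of the mechanism that comment relies on (hypothetical members, not tinygrad's actual list):

from enum import Enum, auto

# hypothetical members; auto() numbers them 1, 2, 3, ... in definition order
class UOps(Enum):
  DEFINE_GLOBAL = auto(); LOAD = auto(); ALU = auto(); STORE = auto()

ops = [UOps.STORE, UOps.ALU, UOps.DEFINE_GLOBAL, UOps.LOAD]
# sorting by .value recovers definition order
print(sorted(ops, key=lambda u: u.value))
# [<UOps.DEFINE_GLOBAL: 1>, <UOps.LOAD: 2>, <UOps.ALU: 3>, <UOps.STORE: 4>]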