simpler expand without dont_expand_args [run_process_replay] (#6230)

* simpler expand without dont_expand_args [run_process_replay]

* Revert "simpler expand without dont_expand_args [run_process_replay]"

This reverts commit 81693024c097c31e601f1a199a631e9eda0d9638.

* exclude_args

* why does that fix it

* correct fix

* _swizzle_args should be fast

* add comment

* zip is tuples
This commit is contained in:
George Hotz
2024-08-21 17:48:45 -07:00
committed by GitHub
parent 78c94abe9c
commit 5cdec79469

View File

@@ -370,28 +370,19 @@ def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) ->
def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
@functools.lru_cache(None)
def _swizzle_args(cargs:Tuple[Tuple[int, int], ...], eargs:Tuple[Tuple[int, int], ...], exclude_args:Tuple[int, ...]) -> List[int]:
return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
def do_expand(root:UOp):
expands = [x for x in root.src if x.op is UOps.EXPAND]
if len(expands) == 0: return None
expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
if root.op is UOps.WMMA:
# both the reduce and upcast args are not expanded here
dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in [y[0] for y in flatten(root.arg[-2])])
expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
else:
dont_expand_args = ()
new_srcs: List[UOp] = []
lrpks = _choices_from_args(dont_expand_args)
for rpk in _choices_from_args(expand_args):
new_src: List[UOp] = []
for src in root.src:
if src.op is UOps.EXPAND:
lnew_src = tuple(src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks)
# TODO: is this right for UOps.WMMA? when there's more than one, all lnew_src should be the same
new_src.append(lnew_src[0] if len(lnew_src) == 1 or root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, lnew_src, dont_expand_args))
else:
new_src.append(src)
new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
# NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else ()
expand_args = tuple(x for x in sorted(dedup(flatten([x.arg for x in expands]))) if x[0] not in exclude_args)
esrcs = [[src.src[x] for x in _swizzle_args(expand_args, src.arg, exclude_args)] \
if src.op is UOps.EXPAND else itertools.repeat(src) for src in root.src]
new_srcs = [UOp(root.op, root.dtype, new_src, root.arg) for new_src in zip(*esrcs)]
if root.op is UOps.EXPAND:
# merge two expands
expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args