clean up expand function [run_process_replay] (#5538)

* clean up expand function [run_process_replay]

* lil cleaner

* add a type
This commit is contained in:
George Hotz
2024-07-17 15:02:00 -07:00
committed by GitHub
parent 61ee02e93d
commit a6e70f8a71

View File

@@ -352,13 +352,16 @@ constant_folder = PatternMatcher([
# *** uop expander ***
def _expand_arg_to_idx(arg:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]):
def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) -> int:
idx, mul = 0, 1
for axis,m in arg[::-1]:
for axis,m in args[::-1]:
idx += rpk[axis] * mul
mul *= m
return idx
def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
def do_expand(root:UOp):
if root.op is UOps.REDUCE:
if root.src[0].op is not UOps.EXPAND: return None
@@ -375,21 +378,16 @@ def do_expand(root:UOp):
expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
else:
dont_expand_args = ()
new_srcs = []
for choices in itertools.product(*[range(x[1]) for x in expand_args]):
rpk = dict(zip([x[0] for x in expand_args], choices))
new_src = []
new_srcs: List[UOp] = []
lrpks = _choices_from_args(dont_expand_args)
for rpk in _choices_from_args(expand_args):
new_src: List[UOp] = []
for src in root.src:
if src.op is UOps.EXPAND:
lnew_src = []
for lchoices in itertools.product(*[range(x[1]) for x in dont_expand_args]):
lrpk = {**rpk, **dict(zip([x[0] for x in dont_expand_args], lchoices))}
lnew_src.append(src.src[_expand_arg_to_idx(src.arg, lrpk)])
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks]
if len(dont_expand_args):
if root.op is UOps.WMMA:
new_src.append(lnew_src[0]) # TODO: is this always right? all lnew_src should be the same
else:
new_src.append(UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
# TODO: is this right for UOps.WMMA? all lnew_src should be the same
new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
else:
assert len(lnew_src) == 1
new_src.append(lnew_src[0])
@@ -397,13 +395,9 @@ def do_expand(root:UOp):
new_src.append(src)
new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
if root.op is UOps.EXPAND:
expand_args, old_expand_args = tuple(sorted(root.arg+expand_args)), expand_args
assert len(expand_args) == (len(old_expand_args) + len(root.arg))
new_new_srcs = []
for choices in itertools.product(*[range(x[1]) for x in expand_args]):
rpk = dict(zip([x[0] for x in expand_args], choices))
new_new_srcs.append(new_srcs[_expand_arg_to_idx(old_expand_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)])
new_srcs = new_new_srcs
expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
assert len(expand_args) == (len(old_args) + len(root.arg))
new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)]
assert prod([x[1] for x in expand_args]) == len(new_srcs)
return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)