mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
clean up expand function [run_process_replay] (#5538)
* clean up expand function [run_process_replay] * lil cleaner * add a type
This commit is contained in:
@@ -352,13 +352,16 @@ constant_folder = PatternMatcher([
|
||||
|
||||
# *** uop expander ***
|
||||
|
||||
def _expand_arg_to_idx(arg:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]):
|
||||
def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) -> int:
|
||||
idx, mul = 0, 1
|
||||
for axis,m in arg[::-1]:
|
||||
for axis,m in args[::-1]:
|
||||
idx += rpk[axis] * mul
|
||||
mul *= m
|
||||
return idx
|
||||
|
||||
def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
|
||||
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
|
||||
|
||||
def do_expand(root:UOp):
|
||||
if root.op is UOps.REDUCE:
|
||||
if root.src[0].op is not UOps.EXPAND: return None
|
||||
@@ -375,21 +378,16 @@ def do_expand(root:UOp):
|
||||
expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
|
||||
else:
|
||||
dont_expand_args = ()
|
||||
new_srcs = []
|
||||
for choices in itertools.product(*[range(x[1]) for x in expand_args]):
|
||||
rpk = dict(zip([x[0] for x in expand_args], choices))
|
||||
new_src = []
|
||||
new_srcs: List[UOp] = []
|
||||
lrpks = _choices_from_args(dont_expand_args)
|
||||
for rpk in _choices_from_args(expand_args):
|
||||
new_src: List[UOp] = []
|
||||
for src in root.src:
|
||||
if src.op is UOps.EXPAND:
|
||||
lnew_src = []
|
||||
for lchoices in itertools.product(*[range(x[1]) for x in dont_expand_args]):
|
||||
lrpk = {**rpk, **dict(zip([x[0] for x in dont_expand_args], lchoices))}
|
||||
lnew_src.append(src.src[_expand_arg_to_idx(src.arg, lrpk)])
|
||||
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks]
|
||||
if len(dont_expand_args):
|
||||
if root.op is UOps.WMMA:
|
||||
new_src.append(lnew_src[0]) # TODO: is this always right? all lnew_src should be the same
|
||||
else:
|
||||
new_src.append(UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
|
||||
# TODO: is this right for UOps.WMMA? all lnew_src should be the same
|
||||
new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
|
||||
else:
|
||||
assert len(lnew_src) == 1
|
||||
new_src.append(lnew_src[0])
|
||||
@@ -397,13 +395,9 @@ def do_expand(root:UOp):
|
||||
new_src.append(src)
|
||||
new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
|
||||
if root.op is UOps.EXPAND:
|
||||
expand_args, old_expand_args = tuple(sorted(root.arg+expand_args)), expand_args
|
||||
assert len(expand_args) == (len(old_expand_args) + len(root.arg))
|
||||
new_new_srcs = []
|
||||
for choices in itertools.product(*[range(x[1]) for x in expand_args]):
|
||||
rpk = dict(zip([x[0] for x in expand_args], choices))
|
||||
new_new_srcs.append(new_srcs[_expand_arg_to_idx(old_expand_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)])
|
||||
new_srcs = new_new_srcs
|
||||
expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
|
||||
assert len(expand_args) == (len(old_args) + len(root.arg))
|
||||
new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)]
|
||||
assert prod([x[1] for x in expand_args]) == len(new_srcs)
|
||||
return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user