diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 047c9a756d..ec2c64043b 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -352,13 +352,16 @@ constant_folder = PatternMatcher([ # *** uop expander *** -def _expand_arg_to_idx(arg:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]): +def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) -> int: idx, mul = 0, 1 - for axis,m in arg[::-1]: + for axis,m in args[::-1]: idx += rpk[axis] * mul mul *= m return idx +def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]: + return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])] + def do_expand(root:UOp): if root.op is UOps.REDUCE: if root.src[0].op is not UOps.EXPAND: return None @@ -375,21 +378,16 @@ def do_expand(root:UOp): expand_args = tuple(x for x in expand_args if x not in dont_expand_args) else: dont_expand_args = () - new_srcs = [] - for choices in itertools.product(*[range(x[1]) for x in expand_args]): - rpk = dict(zip([x[0] for x in expand_args], choices)) - new_src = [] + new_srcs: List[UOp] = [] + lrpks = _choices_from_args(dont_expand_args) + for rpk in _choices_from_args(expand_args): + new_src: List[UOp] = [] for src in root.src: if src.op is UOps.EXPAND: - lnew_src = [] - for lchoices in itertools.product(*[range(x[1]) for x in dont_expand_args]): - lrpk = {**rpk, **dict(zip([x[0] for x in dont_expand_args], lchoices))} - lnew_src.append(src.src[_expand_arg_to_idx(src.arg, lrpk)]) + lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks] if len(dont_expand_args): - if root.op is UOps.WMMA: - new_src.append(lnew_src[0]) # TODO: is this always right? all lnew_src should be the same - else: - new_src.append(UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args)) + # TODO: is this right for UOps.WMMA? all lnew_src should be the same + new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args)) else: assert len(lnew_src) == 1 new_src.append(lnew_src[0]) @@ -397,13 +395,9 @@ def do_expand(root:UOp): new_src.append(src) new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg)) if root.op is UOps.EXPAND: - expand_args, old_expand_args = tuple(sorted(root.arg+expand_args)), expand_args - assert len(expand_args) == (len(old_expand_args) + len(root.arg)) - new_new_srcs = [] - for choices in itertools.product(*[range(x[1]) for x in expand_args]): - rpk = dict(zip([x[0] for x in expand_args], choices)) - new_new_srcs.append(new_srcs[_expand_arg_to_idx(old_expand_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)]) - new_srcs = new_new_srcs + expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args + assert len(expand_args) == (len(old_args) + len(root.arg)) + new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)] assert prod([x[1] for x in expand_args]) == len(new_srcs) return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)