diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 595997eb71..b6c43f39fe 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -370,28 +370,19 @@ def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) -> def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]: return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])] +@functools.lru_cache(None) +def _swizzle_args(cargs:Tuple[Tuple[int, int], ...], eargs:Tuple[Tuple[int, int], ...], exclude_args:Tuple[int, ...]) -> List[int]: + return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)] + def do_expand(root:UOp): expands = [x for x in root.src if x.op is UOps.EXPAND] if len(expands) == 0: return None - expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands])))) - if root.op is UOps.WMMA: - # both the reduce and upcast args are not expanded here - dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in [y[0] for y in flatten(root.arg[-2])]) - expand_args = tuple(x for x in expand_args if x not in dont_expand_args) - else: - dont_expand_args = () - new_srcs: List[UOp] = [] - lrpks = _choices_from_args(dont_expand_args) - for rpk in _choices_from_args(expand_args): - new_src: List[UOp] = [] - for src in root.src: - if src.op is UOps.EXPAND: - lnew_src = tuple(src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks) - # TODO: is this right for UOps.WMMA? when there's more than one, all lnew_src should be the same - new_src.append(lnew_src[0] if len(lnew_src) == 1 or root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, lnew_src, dont_expand_args)) - else: - new_src.append(src) - new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg)) + # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct? + exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else () + expand_args = tuple(x for x in sorted(dedup(flatten([x.arg for x in expands]))) if x[0] not in exclude_args) + esrcs = [[src.src[x] for x in _swizzle_args(expand_args, src.arg, exclude_args)] \ + if src.op is UOps.EXPAND else itertools.repeat(src) for src in root.src] + new_srcs = [UOp(root.op, root.dtype, new_src, root.arg) for new_src in zip(*esrcs)] if root.op is UOps.EXPAND: # merge two expands expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args