mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
simpler expand without dont_expand_args [run_process_replay] (#6230)
* simpler expand without dont_expand_args [run_process_replay] * Revert "simpler expand without dont_expand_args [run_process_replay]" This reverts commit 81693024c097c31e601f1a199a631e9eda0d9638. * exclude_args * why does that fix it * correct fix * _swizzle_args should be fast * add comment * zip is tuples
This commit is contained in:
@@ -370,28 +370,19 @@ def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) ->
|
||||
def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
|
||||
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _swizzle_args(cargs:Tuple[Tuple[int, int], ...], eargs:Tuple[Tuple[int, int], ...], exclude_args:Tuple[int, ...]) -> List[int]:
|
||||
return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
|
||||
|
||||
def do_expand(root:UOp):
|
||||
expands = [x for x in root.src if x.op is UOps.EXPAND]
|
||||
if len(expands) == 0: return None
|
||||
expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
|
||||
if root.op is UOps.WMMA:
|
||||
# both the reduce and upcast args are not expanded here
|
||||
dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in [y[0] for y in flatten(root.arg[-2])])
|
||||
expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
|
||||
else:
|
||||
dont_expand_args = ()
|
||||
new_srcs: List[UOp] = []
|
||||
lrpks = _choices_from_args(dont_expand_args)
|
||||
for rpk in _choices_from_args(expand_args):
|
||||
new_src: List[UOp] = []
|
||||
for src in root.src:
|
||||
if src.op is UOps.EXPAND:
|
||||
lnew_src = tuple(src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks)
|
||||
# TODO: is this right for UOps.WMMA? when there's more than one, all lnew_src should be the same
|
||||
new_src.append(lnew_src[0] if len(lnew_src) == 1 or root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, lnew_src, dont_expand_args))
|
||||
else:
|
||||
new_src.append(src)
|
||||
new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
|
||||
# NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
|
||||
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else ()
|
||||
expand_args = tuple(x for x in sorted(dedup(flatten([x.arg for x in expands]))) if x[0] not in exclude_args)
|
||||
esrcs = [[src.src[x] for x in _swizzle_args(expand_args, src.arg, exclude_args)] \
|
||||
if src.op is UOps.EXPAND else itertools.repeat(src) for src in root.src]
|
||||
new_srcs = [UOp(root.op, root.dtype, new_src, root.arg) for new_src in zip(*esrcs)]
|
||||
if root.op is UOps.EXPAND:
|
||||
# merge two expands
|
||||
expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
|
||||
|
||||
Reference in New Issue
Block a user