prepare expand to support multiexpand [run_process_replay] (#5503)

This commit is contained in:
George Hotz
2024-07-15 18:21:24 -07:00
committed by GitHub
parent fd43d33b7d
commit 9d4c3c553c
2 changed files with 11 additions and 11 deletions

View File

@@ -98,7 +98,7 @@ class IndependentLowerer:
# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), i))
self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), ((i,g),)))
# late indexes (group for reduce)
self.ridxs = self.idxs[:]
@@ -149,7 +149,7 @@ class IndependentLowerer:
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)

View File

@@ -78,7 +78,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
for p in parents:
if p.op is UOps.PHI:
wmma_reduce_axes = flatten([x.arg[7] for x in p.parents if x.op is UOps.WMMA])
parent_expands_for_acc = [x.arg for x in p.parents if x in expands and x.arg not in wmma_reduce_axes]
parent_expands_for_acc = [x.arg[0][0] for x in p.parents if x in expands and x.arg[0][0] not in wmma_reduce_axes]
define_accs.append((p.src[0], parent_expands_for_acc))
for x in p.src:
children[x].append(p)
@@ -114,8 +114,8 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
# get replacements by index
replacements: Dict[int, List[int]] = {}
for r in expands:
if r.arg in replacements: assert len(replacements[r.arg]) == len(r.src)
else: replacements[r.arg] = list(range(0, len(r.src)))
if r.arg[0][0] in replacements: assert len(replacements[r.arg[0][0]]) == len(r.src)
else: replacements[r.arg[0][0]] = list(range(0, len(r.src)))
# get nodes on the path from root to the expand node
rps = list(itertools.product(*replacements.values()))
@@ -125,7 +125,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
acc_cache: Dict[Tuple[Tuple[UOp, int, int], ...], UOp] = {}
for rp in rps:
rpk = dict(zip(replacements.keys(), rp))
replace = {r:r.src[rpk[r.arg]] for r in expands}
replace = {r:r.src[rpk[r.arg[0][0]]] for r in expands}
for d, acc_parents in define_accs:
acc_index = tuple((d,x,rpk[x]) for x in acc_parents)
if acc_index in acc_cache:
@@ -149,7 +149,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
# ***** reduce+image+contract handling *****
def expand_wmma(wmma):
expands = [x for x in wmma.parents if x.op is UOps.EXPAND and (x.arg in wmma.arg[-1] or x.arg in wmma.arg[-2])]
expands = [x for x in wmma.parents if x.op is UOps.EXPAND and (x.arg[0][0] in wmma.arg[-1] or x.arg[0][0] in wmma.arg[-2])]
if len(expands) == 0: return None
new_uops = expand_nodes(wmma.sparents, expands, wmma)
# TODO: assert that these are all the same. they have to be
@@ -161,8 +161,8 @@ def replace_reduce(root):
expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
# add other expands for float4. TODO: should be a faster way
expand_args = [x.arg for x in expands]
new_expands = [x for x in root.parents if x.op is UOps.EXPAND and x.arg in expand_args]
expand_args = [x.arg[0][0] for x in expands]
new_expands = [x for x in root.parents if x.op is UOps.EXPAND and x.arg[0][0] in expand_args]
expands = dedup(expands + new_expands)
if len(expands):
@@ -179,7 +179,7 @@ def replace_reduce(root):
def replace_contract(root:UOp):
parents, dtype = root.parents, cast(DType, root.dtype)
expands: List[UOp] = [x for x in parents if x.op is UOps.EXPAND and x.arg in root.arg]
expands: List[UOp] = [x for x in parents if x.op is UOps.EXPAND and x.arg[0][0] in root.arg]
assert all_same(expand_lens := [dtype.count] + [len(x.src) for x in expands]), expand_lens
ret = expand_nodes(parents, expands, root.src[0])
if len(ret) == 1: ret = ret*dtype.count # TODO: why is this needed?
@@ -245,7 +245,7 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype
if idx3 is not None: idx = idx + idx3
if not idx.divides(len(ex.src)): return None
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg,))
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg[0][0],))
return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:])
def no_float4_alu(alu):