Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 13:58:00 -05:00)
prepare expand to support multiexpand [run_process_replay] (#5503)
This commit is contained in:
@@ -98,7 +98,7 @@ class IndependentLowerer:
|
||||
# upcast loops
|
||||
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
|
||||
assert isinstance(g, int), "needs to be int to upcast/unroll"
|
||||
self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), i))
|
||||
self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), ((i,g),)))
|
||||
|
||||
# late indexes (group for reduce)
|
||||
self.ridxs = self.idxs[:]
|
||||
@@ -149,7 +149,7 @@ class IndependentLowerer:
|
||||
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
|
||||
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
|
||||
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
|
||||
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
|
||||
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
|
||||
# NOTE: always using ridxs is fine here
|
||||
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
|
||||
return UOp.alu(x.op, *in_uops)
|
||||
|
||||
@@ -78,7 +78,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
|
||||
for p in parents:
|
||||
if p.op is UOps.PHI:
|
||||
wmma_reduce_axes = flatten([x.arg[7] for x in p.parents if x.op is UOps.WMMA])
|
||||
parent_expands_for_acc = [x.arg for x in p.parents if x in expands and x.arg not in wmma_reduce_axes]
|
||||
parent_expands_for_acc = [x.arg[0][0] for x in p.parents if x in expands and x.arg[0][0] not in wmma_reduce_axes]
|
||||
define_accs.append((p.src[0], parent_expands_for_acc))
|
||||
for x in p.src:
|
||||
children[x].append(p)
|
||||
@@ -114,8 +114,8 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
|
||||
# get replacements by index
|
||||
replacements: Dict[int, List[int]] = {}
|
||||
for r in expands:
|
||||
if r.arg in replacements: assert len(replacements[r.arg]) == len(r.src)
|
||||
else: replacements[r.arg] = list(range(0, len(r.src)))
|
||||
if r.arg[0][0] in replacements: assert len(replacements[r.arg[0][0]]) == len(r.src)
|
||||
else: replacements[r.arg[0][0]] = list(range(0, len(r.src)))
|
||||
|
||||
# get nodes on the path from root to the expand node
|
||||
rps = list(itertools.product(*replacements.values()))
|
||||
@@ -125,7 +125,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
|
||||
acc_cache: Dict[Tuple[Tuple[UOp, int, int], ...], UOp] = {}
|
||||
for rp in rps:
|
||||
rpk = dict(zip(replacements.keys(), rp))
|
||||
replace = {r:r.src[rpk[r.arg]] for r in expands}
|
||||
replace = {r:r.src[rpk[r.arg[0][0]]] for r in expands}
|
||||
for d, acc_parents in define_accs:
|
||||
acc_index = tuple((d,x,rpk[x]) for x in acc_parents)
|
||||
if acc_index in acc_cache:
|
||||
@@ -149,7 +149,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
|
||||
# ***** reduce+image+contract handling *****
|
||||
|
||||
def expand_wmma(wmma):
|
||||
expands = [x for x in wmma.parents if x.op is UOps.EXPAND and (x.arg in wmma.arg[-1] or x.arg in wmma.arg[-2])]
|
||||
expands = [x for x in wmma.parents if x.op is UOps.EXPAND and (x.arg[0][0] in wmma.arg[-1] or x.arg[0][0] in wmma.arg[-2])]
|
||||
if len(expands) == 0: return None
|
||||
new_uops = expand_nodes(wmma.sparents, expands, wmma)
|
||||
# TODO: assert that these are all the same. they have to be
|
||||
@@ -161,8 +161,8 @@ def replace_reduce(root):
|
||||
expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
|
||||
|
||||
# add other expands for float4. TODO: should be a faster way
|
||||
expand_args = [x.arg for x in expands]
|
||||
new_expands = [x for x in root.parents if x.op is UOps.EXPAND and x.arg in expand_args]
|
||||
expand_args = [x.arg[0][0] for x in expands]
|
||||
new_expands = [x for x in root.parents if x.op is UOps.EXPAND and x.arg[0][0] in expand_args]
|
||||
expands = dedup(expands + new_expands)
|
||||
|
||||
if len(expands):
|
||||
@@ -179,7 +179,7 @@ def replace_reduce(root):
|
||||
|
||||
def replace_contract(root:UOp):
|
||||
parents, dtype = root.parents, cast(DType, root.dtype)
|
||||
expands: List[UOp] = [x for x in parents if x.op is UOps.EXPAND and x.arg in root.arg]
|
||||
expands: List[UOp] = [x for x in parents if x.op is UOps.EXPAND and x.arg[0][0] in root.arg]
|
||||
assert all_same(expand_lens := [dtype.count] + [len(x.src) for x in expands]), expand_lens
|
||||
ret = expand_nodes(parents, expands, root.src[0])
|
||||
if len(ret) == 1: ret = ret*dtype.count # TODO: why is this needed?
|
||||
@@ -245,7 +245,7 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype
|
||||
if idx3 is not None: idx = idx + idx3
|
||||
if not idx.divides(len(ex.src)): return None
|
||||
|
||||
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg,))
|
||||
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg[0][0],))
|
||||
return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:])
|
||||
|
||||
def no_float4_alu(alu):
|
||||
|
||||
Reference in New Issue
Block a user