diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 8a41ff1dd9..6bcab93064 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -98,7 +98,7 @@ class IndependentLowerer: # upcast loops for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted): assert isinstance(g, int), "needs to be int to upcast/unroll" - self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), i)) + self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), ((i,g),))) # late indexes (group for reduce) self.ridxs = self.idxs[:] @@ -149,7 +149,7 @@ class IndependentLowerer: UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)), UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)), UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg) - return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2]) + return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),)) # NOTE: always using ridxs is fine here return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op) return UOp.alu(x.op, *in_uops) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 906787b23f..20253e6464 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -78,7 +78,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]: for p in parents: if p.op is UOps.PHI: wmma_reduce_axes = flatten([x.arg[7] for x in p.parents if x.op is UOps.WMMA]) - parent_expands_for_acc = [x.arg for x in p.parents if x in expands and x.arg not in wmma_reduce_axes] + parent_expands_for_acc = [x.arg[0][0] for x in p.parents if x in expands and x.arg[0][0] not in wmma_reduce_axes] define_accs.append((p.src[0], parent_expands_for_acc)) for x in p.src: children[x].append(p) @@ -114,8 +114,8 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]: # get replacements by index replacements: Dict[int, List[int]] = {} for r in expands: - if r.arg in replacements: assert len(replacements[r.arg]) == len(r.src) - else: replacements[r.arg] = list(range(0, len(r.src))) + if r.arg[0][0] in replacements: assert len(replacements[r.arg[0][0]]) == len(r.src) + else: replacements[r.arg[0][0]] = list(range(0, len(r.src))) # get nodes on the path from root to the expand node rps = list(itertools.product(*replacements.values())) @@ -125,7 +125,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]: acc_cache: Dict[Tuple[Tuple[UOp, int, int], ...], UOp] = {} for rp in rps: rpk = dict(zip(replacements.keys(), rp)) - replace = {r:r.src[rpk[r.arg]] for r in expands} + replace = {r:r.src[rpk[r.arg[0][0]]] for r in expands} for d, acc_parents in define_accs: acc_index = tuple((d,x,rpk[x]) for x in acc_parents) if acc_index in acc_cache: @@ -149,7 +149,7 @@ def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]: # ***** reduce+image+contract handling ***** def expand_wmma(wmma): - expands = [x for x in wmma.parents if x.op is UOps.EXPAND and (x.arg in wmma.arg[-1] or x.arg in wmma.arg[-2])] + expands = [x for x in wmma.parents if x.op is UOps.EXPAND and (x.arg[0][0] in wmma.arg[-1] or x.arg[0][0] in wmma.arg[-2])] if len(expands) == 0: return None new_uops = expand_nodes(wmma.sparents, expands, wmma) # TODO: assert that these are all the same. they have to be @@ -161,8 +161,8 @@ def replace_reduce(root): expands = [x for x in root.src[1:] if x.op is UOps.EXPAND] # add other expands for float4. TODO: should be a faster way - expand_args = [x.arg for x in expands] - new_expands = [x for x in root.parents if x.op is UOps.EXPAND and x.arg in expand_args] + expand_args = [x.arg[0][0] for x in expands] + new_expands = [x for x in root.parents if x.op is UOps.EXPAND and x.arg[0][0] in expand_args] expands = dedup(expands + new_expands) if len(expands): @@ -179,7 +179,7 @@ def replace_reduce(root): def replace_contract(root:UOp): parents, dtype = root.parents, cast(DType, root.dtype) - expands: List[UOp] = [x for x in parents if x.op is UOps.EXPAND and x.arg in root.arg] + expands: List[UOp] = [x for x in parents if x.op is UOps.EXPAND and x.arg[0][0] in root.arg] assert all_same(expand_lens := [dtype.count] + [len(x.src) for x in expands]), expand_lens ret = expand_nodes(parents, expands, root.src[0]) if len(ret) == 1: ret = ret*dtype.count # TODO: why is this needed? @@ -245,7 +245,7 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype if idx3 is not None: idx = idx + idx3 if not idx.divides(len(ex.src)): return None - new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg,)) + new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg[0][0],)) return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:]) def no_float4_alu(alu):