diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index d4987c6a77..83b947405b 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -175,7 +175,7 @@ class IndependentLowerer: UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)), UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)), UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg) - return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),)) + return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]//8),)) # NOTE: always using ridxs is fine here return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op) return UOp.alu(x.op, *in_uops) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 4539dac8b3..74f20e12e6 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -148,6 +148,8 @@ constant_folder = PatternMatcher([ # tensor core cleanups (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(2))).name("expand"),)) .name("reduce_allow_any_len"), reduce_before_expand), + (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(4))).name("expand"),)) + .name("reduce_allow_any_len"), reduce_before_expand), (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(8))).name("expand"),)) .name("reduce_allow_any_len"), reduce_before_expand), (UOp.var("add") + UOp(UOps.WMMA).name("wmma"), @@ -269,8 +271,6 @@ constant_folder = PatternMatcher([ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))), lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)), # VECTORIZE-PHI-GEP -> PHI-VECTORIZE - (UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"), - lambda root, val, v0, v1, v2, v3, v4, v5, v6, v7: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3, v4, v5, v6, v7))))), (UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"), lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))), (UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"), @@ -327,7 +327,7 @@ def do_expand(root:UOp): new_src: List[UOp] = [] for src in root.src: if src.op is UOps.EXPAND: - lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})%8] for lrpk in lrpks] + lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks] if len(dont_expand_args): # TODO: is this right for UOps.WMMA? all lnew_src should be the same new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args)) @@ -489,11 +489,14 @@ class UOpGraph: # do graph rewrite sink = graph_rewrite(sink, self.folder) + print(sink) # expand UOpGraph.cnt += 1 if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander+self.folder) + print(sink) + # for PTX only if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)