fix bug for size 8

This commit is contained in:
p4sscode
2024-07-21 15:47:24 -03:00
parent 3ef0904bd9
commit cdb3f5df85
2 changed files with 7 additions and 4 deletions

View File

@@ -175,7 +175,7 @@ class IndependentLowerer:
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]//8),))
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)

View File

@@ -148,6 +148,8 @@ constant_folder = PatternMatcher([
# tensor core cleanups
(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(2))).name("expand"),))
.name("reduce_allow_any_len"), reduce_before_expand),
(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(4))).name("expand"),))
.name("reduce_allow_any_len"), reduce_before_expand),
(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(8))).name("expand"),))
.name("reduce_allow_any_len"), reduce_before_expand),
(UOp.var("add") + UOp(UOps.WMMA).name("wmma"),
@@ -269,8 +271,6 @@ constant_folder = PatternMatcher([
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# VECTORIZE-PHI-GEP -> PHI-VECTORIZE
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
lambda root, val, v0, v1, v2, v3, v4, v5, v6, v7: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3, v4, v5, v6, v7))))),
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"),
@@ -327,7 +327,7 @@ def do_expand(root:UOp):
new_src: List[UOp] = []
for src in root.src:
if src.op is UOps.EXPAND:
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})%8] for lrpk in lrpks]
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks]
if len(dont_expand_args):
# TODO: is this right for UOps.WMMA? all lnew_src should be the same
new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
@@ -489,11 +489,14 @@ class UOpGraph:
# do graph rewrite
sink = graph_rewrite(sink, self.folder)
print(sink)
# expand
UOpGraph.cnt += 1
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander+self.folder)
print(sink)
# for PTX only
if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)