mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
fix bug for size 8
This commit is contained in:
@@ -175,7 +175,7 @@ class IndependentLowerer:
|
||||
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
|
||||
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
|
||||
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
|
||||
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
|
||||
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]//8),))
|
||||
# NOTE: always using ridxs is fine here
|
||||
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
|
||||
return UOp.alu(x.op, *in_uops)
|
||||
|
||||
@@ -148,6 +148,8 @@ constant_folder = PatternMatcher([
|
||||
# tensor core cleanups
|
||||
(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(2))).name("expand"),))
|
||||
.name("reduce_allow_any_len"), reduce_before_expand),
|
||||
(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(4))).name("expand"),))
|
||||
.name("reduce_allow_any_len"), reduce_before_expand),
|
||||
(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(8))).name("expand"),))
|
||||
.name("reduce_allow_any_len"), reduce_before_expand),
|
||||
(UOp.var("add") + UOp(UOps.WMMA).name("wmma"),
|
||||
@@ -269,8 +271,6 @@ constant_folder = PatternMatcher([
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
|
||||
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
|
||||
# VECTORIZE-PHI-GEP -> PHI-VECTORIZE
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
|
||||
lambda root, val, v0, v1, v2, v3, v4, v5, v6, v7: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3, v4, v5, v6, v7))))),
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
|
||||
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"),
|
||||
@@ -327,7 +327,7 @@ def do_expand(root:UOp):
|
||||
new_src: List[UOp] = []
|
||||
for src in root.src:
|
||||
if src.op is UOps.EXPAND:
|
||||
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})%8] for lrpk in lrpks]
|
||||
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks]
|
||||
if len(dont_expand_args):
|
||||
# TODO: is this right for UOps.WMMA? all lnew_src should be the same
|
||||
new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
|
||||
@@ -489,11 +489,14 @@ class UOpGraph:
|
||||
|
||||
# do graph rewrite
|
||||
sink = graph_rewrite(sink, self.folder)
|
||||
print(sink)
|
||||
|
||||
# expand
|
||||
UOpGraph.cnt += 1
|
||||
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander+self.folder)
|
||||
|
||||
print(sink)
|
||||
|
||||
# for PTX only
|
||||
if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user