use UOps.WMMA everywhere (#6255)

* add UOps.WMMA_AXIS

* delete ReduceOps.WMMA from ops
This commit is contained in:
qazal
2024-08-24 03:03:26 +08:00
committed by GitHub
parent 66d0b14a20
commit 0d4887e9df
4 changed files with 12 additions and 13 deletions

View File

@@ -42,7 +42,6 @@ class LazyOp:
@functools.cached_property
def dtype(self) -> DType:
if self.op in BufferOps: return self.arg.dtype
if self.op is ReduceOps.WMMA: return self.arg[3] # WMMA can change the type
if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
@functools.cached_property
@@ -84,7 +83,7 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
for x in op.src: assert_valid(x, st)
# only reduceop is allowed to change shape, limited to turning n to 1
if op.op in ReduceOps:
axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
axis = op.arg
assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
st = ShapeTracker.from_shape(sts[op.src[0]].reduce(axis))
else:

View File

@@ -707,7 +707,7 @@ class Kernel:
# MUL/SUM instead of WMMA
ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (reduceop, wmma_arg[-1]))
else:
ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), (ReduceOps.WMMA, wmma_arg))
ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes)
return replace(op, src=(ret,), arg=(reduceop, new_reduce_axes)) if new_reduce_axes else ret
if self.group_for_reduces:
@@ -782,7 +782,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
return
for x in src: _assert_valid_uop(x, st, sts)
# only reduceuop is allowed to change shape, limited to turning n to 1
if op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[1][-1] if arg[0] is ReduceOps.WMMA else arg[1]))
if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1]))
else:
assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
# movementops are pushed to the edges with SHAPETRACKER

View File

@@ -110,15 +110,15 @@ class IndependentLowerer:
return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[2])) + ((valid,) if has_valid else ()))
in_uops = tuple(self.to_uop(y) for y in x.src)
if x.op is UOps.WMMA:
upcast_axes = x.arg[-2]
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
ret = UOp(UOps.WMMA, dtype=cast(DType, x.dtype).vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
UOp.const(cast(DType, x.dtype).vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, x.dtype, tuple(UOp(UOps.GEP, x.dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axes[2])
if x.op is UOps.REDUCE_AXIS:
if x.arg[0] is ReduceOps.WMMA:
upcast_axes = x.arg[1][-2]
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
ret = UOp(UOps.WMMA, dtype=cast(DType, x.dtype).vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
UOp.const(cast(DType, x.dtype).vec(wmma_sz[2]), 0.0)), arg=x.arg[1])
return UOp(UOps.EXPAND, x.dtype, tuple(UOp(UOps.GEP, x.dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axes[2])
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, x.arg[0])]

View File

@@ -25,7 +25,7 @@ class TernaryOps(Enum):
WHERE = auto(); MULACC = auto() # noqa: E702
class ReduceOps(Enum):
"""A -> B (reduce)"""
SUM = auto(); PROD = auto(); MAX = auto(); WMMA = auto() # noqa: E702
SUM = auto(); PROD = auto(); MAX = auto() # noqa: E702
class MetaOps(Enum):
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]