mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use UOps.WMMA everywhere (#6255)
* add UOps.WMMA_AXIS
* delete ReduceOps.WMMA from ops
This commit is contained in:
@@ -42,7 +42,6 @@ class LazyOp:
   @functools.cached_property
   def dtype(self) -> DType:
     if self.op in BufferOps: return self.arg.dtype
-    if self.op is ReduceOps.WMMA: return self.arg[3] # WMMA can change the type
     if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
     return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
   @functools.cached_property
@@ -84,7 +83,7 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
   for x in op.src: assert_valid(x, st)
   # only reduceop is allowed to change shape, limited to turning n to 1
   if op.op in ReduceOps:
-    axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
+    axis = op.arg
     assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
     st = ShapeTracker.from_shape(sts[op.src[0]].reduce(axis))
   else:
|
||||
@@ -707,7 +707,7 @@ class Kernel:
         # MUL/SUM instead of WMMA
         ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (reduceop, wmma_arg[-1]))
       else:
-        ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), (ReduceOps.WMMA, wmma_arg))
+        ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
       new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes)
       return replace(op, src=(ret,), arg=(reduceop, new_reduce_axes)) if new_reduce_axes else ret
     if self.group_for_reduces:
@@ -782,7 +782,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
     return
   for x in src: _assert_valid_uop(x, st, sts)
   # only reduceuop is allowed to change shape, limited to turning n to 1
-  if op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[1][-1] if arg[0] is ReduceOps.WMMA else arg[1]))
+  if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1]))
   else:
     assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
     # movementops are pushed to the edges with SHAPETRACKER
|
||||
@@ -110,15 +110,15 @@ class IndependentLowerer:
       return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[2])) + ((valid,) if has_valid else ()))
 
     in_uops = tuple(self.to_uop(y) for y in x.src)
+    if x.op is UOps.WMMA:
+      upcast_axes = x.arg[-2]
+      wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
+      ret = UOp(UOps.WMMA, dtype=cast(DType, x.dtype).vec(wmma_sz[2]), src=(
+        UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
+        UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
+        UOp.const(cast(DType, x.dtype).vec(wmma_sz[2]), 0.0)), arg=x.arg)
+      return UOp(UOps.EXPAND, x.dtype, tuple(UOp(UOps.GEP, x.dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axes[2])
     if x.op is UOps.REDUCE_AXIS:
-      if x.arg[0] is ReduceOps.WMMA:
-        upcast_axes = x.arg[1][-2]
-        wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
-        ret = UOp(UOps.WMMA, dtype=cast(DType, x.dtype).vec(wmma_sz[2]), src=(
-          UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
-          UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
-          UOp.const(cast(DType, x.dtype).vec(wmma_sz[2]), 0.0)), arg=x.arg[1])
-        return UOp(UOps.EXPAND, x.dtype, tuple(UOp(UOps.GEP, x.dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axes[2])
       # NOTE: always using ridxs is fine here
       reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
       alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, x.arg[0])]
|
||||
@@ -25,7 +25,7 @@ class TernaryOps(Enum):
   WHERE = auto(); MULACC = auto() # noqa: E702
 class ReduceOps(Enum):
   """A -> B (reduce)"""
-  SUM = auto(); PROD = auto(); MAX = auto(); WMMA = auto() # noqa: E702
+  SUM = auto(); PROD = auto(); MAX = auto() # noqa: E702
 class MetaOps(Enum):
   EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
 Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
|
||||
Reference in New Issue
Block a user