diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 8a67e973d5..ffa4a72e73 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -680,18 +680,19 @@ class Kernel: if self.opts.device == "AMD": reduce_axes = [self.shape_len-self.upcasted] - upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1) + upcast_axis: Tuple[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...]] = \ + (((self.shape_len-self.upcasted, 16),), ((self.shape_len-self.upcasted, 16),), ((self.shape_len-self.upcasted+1, 8),)) # https://gpuopen.com/learn/wmma_on_rdna3/ fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0))) fix_st2 = None elif self.opts.device == "METAL": reduce_axes = [self.shape_len-self.upcasted] - upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1) + upcast_axis = (((self.shape_len-self.upcasted+1, 2),), ((self.shape_len-self.upcasted+1, 2),), ((self.shape_len-self.upcasted+1, 2),)) fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2))) fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3))) elif self.opts.device in {"CUDA", "NV"}: reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1] - upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2) + upcast_axis = (((self.shape_len-self.upcasted, 8),), ((self.shape_len-self.upcasted+2, 4),), ((self.shape_len-self.upcasted+2, 4),)) # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2), ((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5))) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 8717816b7c..d6ebb47f9d 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -173,10 +173,10 @@ class IndependentLowerer: if x.op is ReduceOps.WMMA: wmma_sz, upcast_axis = x.arg[4], x.arg[6] ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=( - UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=((upcast_axis[0], wmma_sz[0]),)), - UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=((upcast_axis[1], wmma_sz[1]),)), + UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axis[0]), + UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axis[1]), UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg) - return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),)) + return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2]) # NOTE: always using ridxs is fine here return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op) return UOp.alu(x.op, *in_uops) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 51d24598bb..156a4145c4 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -327,7 +327,7 @@ def do_expand(root:UOp): if len(expands) == 0: return None expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands])))) if root.op is UOps.WMMA: - dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in root.arg[-2]) + dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in [y[0] for y in flatten(root.arg[-2])]) expand_args = tuple(x for x in expand_args if x not in dont_expand_args) else: dont_expand_args = ()