this fails too

This commit is contained in:
George Hotz
2024-07-24 02:19:55 +00:00
parent fa14f7b4fd
commit df20c4602a
2 changed files with 7 additions and 6 deletions

View File

@@ -680,18 +680,19 @@ class Kernel:
if self.opts.device == "AMD":
reduce_axes = [self.shape_len-self.upcasted]
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
upcast_axis: Tuple[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...]] = \
(((self.shape_len-self.upcasted, 16),), ((self.shape_len-self.upcasted, 16),), ((self.shape_len-self.upcasted+1, 8),))
# https://gpuopen.com/learn/wmma_on_rdna3/
fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
fix_st2 = None
elif self.opts.device == "METAL":
reduce_axes = [self.shape_len-self.upcasted]
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1)
upcast_axis = (((self.shape_len-self.upcasted+1, 2),), ((self.shape_len-self.upcasted+1, 2),), ((self.shape_len-self.upcasted+1, 2),))
fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
elif self.opts.device in {"CUDA", "NV"}:
reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)
upcast_axis = (((self.shape_len-self.upcasted, 8),), ((self.shape_len-self.upcasted+2, 4),), ((self.shape_len-self.upcasted+2, 4),))
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5)))

View File

@@ -173,10 +173,10 @@ class IndependentLowerer:
if x.op is ReduceOps.WMMA:
wmma_sz, upcast_axis = x.arg[4], x.arg[6]
ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=((upcast_axis[0], wmma_sz[0]),)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=((upcast_axis[1], wmma_sz[1]),)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axis[0]),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axis[1]),
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)