simple TC change [run_process_replay] (#5671)

* Revert "Revert "this fails too""

This reverts commit 5de43e7073.

* fix dont_expand_args
This commit is contained in:
George Hotz
2024-07-23 20:11:31 -07:00
committed by GitHub
parent 3060e0be4f
commit 918eebb1b1
3 changed files with 8 additions and 7 deletions

View File

@@ -680,18 +680,19 @@ class Kernel:
if self.opts.device == "AMD":
reduce_axes = [self.shape_len-self.upcasted]
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
upcast_axis: Tuple[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...]] = \
(((self.shape_len-self.upcasted, 16),), ((self.shape_len-self.upcasted, 16),), ((self.shape_len-self.upcasted+1, 8),))
# https://gpuopen.com/learn/wmma_on_rdna3/
fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
fix_st2 = None
elif self.opts.device == "METAL":
reduce_axes = [self.shape_len-self.upcasted]
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1)
upcast_axis = (((self.shape_len-self.upcasted+1, 2),), ((self.shape_len-self.upcasted+1, 2),), ((self.shape_len-self.upcasted+1, 2),))
fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
elif self.opts.device in {"CUDA", "NV"}:
reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)
upcast_axis = (((self.shape_len-self.upcasted, 8),), ((self.shape_len-self.upcasted+2, 4),), ((self.shape_len-self.upcasted+2, 4),))
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5)))

View File

@@ -173,10 +173,10 @@ class IndependentLowerer:
if x.op is ReduceOps.WMMA:
wmma_sz, upcast_axis = x.arg[4], x.arg[6]
ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=((upcast_axis[0], wmma_sz[0]),)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=((upcast_axis[1], wmma_sz[1]),)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axis[0]),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axis[1]),
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)

View File

@@ -327,7 +327,7 @@ def do_expand(root:UOp):
if len(expands) == 0: return None
expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
if root.op is UOps.WMMA:
dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in root.arg[-2])
dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in [y[0] for y in flatten(root.arg[-2])])
expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
else:
dont_expand_args = ()