Mirror of https://github.com/tinygrad/tinygrad.git
simple TC change [run_process_replay] (#5671)
* Revert "Revert "this fails too""
This reverts commit 5de43e7073.
* fix dont_expand_args
@@ -680,18 +680,19 @@ class Kernel:
       if self.opts.device == "AMD":
         reduce_axes = [self.shape_len-self.upcasted]
-        upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
+        upcast_axis: Tuple[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...]] = \
+          (((self.shape_len-self.upcasted, 16),), ((self.shape_len-self.upcasted, 16),), ((self.shape_len-self.upcasted+1, 8),))
         # https://gpuopen.com/learn/wmma_on_rdna3/
         fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
         fix_st2 = None
       elif self.opts.device == "METAL":
         reduce_axes = [self.shape_len-self.upcasted]
-        upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1, self.shape_len-self.upcasted+1)
+        upcast_axis = (((self.shape_len-self.upcasted+1, 2),), ((self.shape_len-self.upcasted+1, 2),), ((self.shape_len-self.upcasted+1, 2),))
         fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
         fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
       elif self.opts.device in {"CUDA", "NV"}:
         reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
-        upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)
+        upcast_axis = (((self.shape_len-self.upcasted, 8),), ((self.shape_len-self.upcasted+2, 4),), ((self.shape_len-self.upcasted+2, 4),))
         # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
         fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,4), (2,2,2,2,2,2),
                                     ((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5)))
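Note: the Kernel change above replaces the bare per-operand axis indices in upcast_axis with tuples of (axis, upcast size) pairs. A minimal sketch of how the two representations relate, using hypothetical shape_len/upcasted values for the AMD branch (the numbers are not taken from a real kernel):

shape_len, upcasted = 6, 2   # hypothetical values for illustration only

# old format: one bare axis index per WMMA operand (A, B, accumulator)
old_upcast_axis = (shape_len-upcasted, shape_len-upcasted, shape_len-upcasted+1)   # (4, 4, 5)

# new format: a tuple of (axis, upcast_size) pairs per operand; the sizes (16, 16, 8)
# match the wmma_sz components that the lowerer previously zipped in itself
new_upcast_axis = (((shape_len-upcasted, 16),), ((shape_len-upcasted, 16),), ((shape_len-upcasted+1, 8),))

assert [pairs[0][0] for pairs in new_upcast_axis] == list(old_upcast_axis)   # axes unchanged, sizes now explicit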
@@ -173,10 +173,10 @@ class IndependentLowerer:
     if x.op is ReduceOps.WMMA:
       wmma_sz, upcast_axis = x.arg[4], x.arg[6]
       ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
-        UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=((upcast_axis[0], wmma_sz[0]),)),
-        UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=((upcast_axis[1], wmma_sz[1]),)),
+        UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axis[0]),
+        UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axis[1]),
         UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
-      return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
+      return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
     # NOTE: always using ridxs is fine here
     return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
   return UOp.alu(x.op, *in_uops)
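Note: because upcast_axis now already carries the (axis, size) pairs, the lowerer passes it straight through as the CONTRACT/EXPAND arg instead of rebuilding it from wmma_sz. A small sketch with hypothetical values (AMD-style sizes assumed) showing that both spellings produce the same arg:

wmma_sz = (16, 16, 8)                              # hypothetical per-operand vector sizes
upcast_axis = (((4, 16),), ((4, 16),), ((5, 8),))  # new-format arg, as built in Kernel above

old_contract_arg = ((upcast_axis[0][0][0], wmma_sz[0]),)  # old code: zip a bare axis with wmma_sz here
new_contract_arg = upcast_axis[0]                         # new code: already in ((axis, size),) form
assert old_contract_arg == new_contract_arg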
@@ -327,7 +327,7 @@ def do_expand(root:UOp):
   if len(expands) == 0: return None
   expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
   if root.op is UOps.WMMA:
-    dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in root.arg[-2])
+    dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in [y[0] for y in flatten(root.arg[-2])])
     expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
   else:
     dont_expand_args = ()
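Note: this is the dont_expand_args fix from the commit message. With the nested upcast_axis format, root.arg[-2] holds (axis, size) tuples, so the old membership test x[0] in root.arg[-2] compared an int against tuples and never matched; the fix flattens the nesting and compares against the axis component only. A standalone sketch with hypothetical arg values:

def flatten(l): return [item for sublist in l for item in sublist]   # stand-in for the helper used above

expand_args = ((3, 2), (4, 16), (5, 8))             # hypothetical (axis, size) pairs gathered from EXPANDs
reduce_axes = (6, 7)                                # hypothetical root.arg[-1]
upcast_axis = (((4, 16),), ((4, 16),), ((5, 8),))   # hypothetical root.arg[-2] in the new nested format

# old test: x[0] in upcast_axis -> always False here (int vs. tuple of tuples)
# fixed test: flatten and compare against the axis component only
dont_expand_args = tuple(x for x in expand_args
                         if x[0] in reduce_axes or x[0] in [y[0] for y in flatten(upcast_axis)])
assert dont_expand_args == ((4, 16), (5, 8))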