mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPS/BLOCKSPARSE] remove unnecessary mask (#1351)
This PR applies a minor patch that removes unnecessary masks in `_dsd_kernel()`.

### Details

`offs_bn` is defined as follows and not updated after that.

```py
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
```

Because `offs_bn = offs_bn % DS0`, this mask is always `True`.

```py
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
```

This PR removes this mask (as well as explicit `mask=True`).
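The claim that the mask is always `True` can be checked outside the kernel. The following is a minimal sketch in plain Python, not the Triton code itself: `wrapped_offsets` and `mask_is_all_true` are hypothetical helper names, a list comprehension stands in for `tl.arange`, and the sketch assumes `pid_m >= 0` and `DS0 > 0` so that the remainder is non-negative.

```python
def wrapped_offsets(pid_m, tile_n, ds0):
    """Plain-Python stand-in for (pid_m * TILE_N + tl.arange(0, TILE_N)) % DS0."""
    return [(pid_m * tile_n + i) % ds0 for i in range(tile_n)]


def mask_is_all_true(pid_m, tile_n, ds0):
    """Evaluate the removed mask, offs_bn < DS0, for every lane of the tile."""
    return all(off < ds0 for off in wrapped_offsets(pid_m, tile_n, ds0))


# Sweep a few tile sizes and dimension sizes, including values of DS0
# that are not a multiple of the tile size (illustrative values only).
for tile_n in (16, 32, 64):
    for ds0 in (1, 7, 16, 33, 100):
        for pid_m in range(8):
            assert mask_is_all_true(pid_m, tile_n, ds0)
```

Since every offset has already been reduced modulo `DS0`, the comparison against `DS0` cannot fail for any lane, which is why dropping the mask does not change which elements are loaded.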
@@ -181,8 +181,8 @@ def _dsd_kernel(
     inc_b = tl.load(pinc)
     inc_b = tl.multiple_of(inc_b, 8)
     for k in range(K, 0, -TILE_K):
-        a = tl.load(pa, mask=True)
-        b = tl.load(pb, mask=offs_bn[None, :] < DS0)
+        a = tl.load(pa)
+        b = tl.load(pb)
         acc += tl.dot(a, b)
         pa += inc_a
         pb += inc_b * stride_bk