[OPS/BLOCKSPARSE] remove unnecessary mask (#1351)

This PR applies a minor patch that removes unnecessary masks in
`_dsd_kernel()`.

### Details

`offs_bn` is defined as follows and not updated after that.
```py
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
```

Because `offs_bn = offs_bn % DS0`, this mask is always `True`.
```py
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
```
This PR removes this mask (as well as explicit `mask=True`).
This commit is contained in:
Shintaro Iwasaki
2023-03-15 19:06:38 -07:00
committed by GitHub
parent c175473bbf
commit 4b774ee4d0

View File

@@ -181,8 +181,8 @@ def _dsd_kernel(
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
for k in range(K, 0, -TILE_K):
a = tl.load(pa, mask=True)
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
a = tl.load(pa)
b = tl.load(pb)
acc += tl.dot(a, b)
pa += inc_a
pb += inc_b * stride_bk