[OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)

* dds layout now internally re-uses the dsd code path for increased code reuse
* the at_mask and kp_mask arguments are dropped from the softmax API. I couldn't think of any case where they are needed beyond is_causal; if one exists, we should probably find a way to implement it statically so that users don't have to materialize masks.
* fixed a bug in blocksparse matmul that caused failures when the layout had a full row/column of zeros
* blocksparse softmax no longer modifies any data in-place
* blocksparse softmax now takes an is_dense argument that enables better performance; passing is_dense=True, is_causal=True is the best way to achieve triangular attention
* unit tests now exercise the backward pass
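The dense, causal fast path computes the same result as an ordinary softmax applied under a lower-triangular mask. A minimal NumPy sketch of that reference semantics (the helper name `causal_softmax` is illustrative, not part of the Triton API):

```python
import numpy as np

def causal_softmax(scores):
    # Reference semantics of a causal (triangular) softmax:
    # position i may only attend to positions j <= i.
    n = scores.shape[-1]
    mask = np.tril(np.ones((n, n), dtype=bool))
    masked = np.where(mask, scores, -np.inf)
    # Subtract the row-wise max for numerical stability; the diagonal
    # is always unmasked, so the max is finite.
    masked = masked - masked.max(axis=-1, keepdims=True)
    e = np.exp(masked)  # exp(-inf) == 0, so masked entries vanish
    return e / e.sum(axis=-1, keepdims=True)
```

The fused kernel never materializes the mask; this sketch only pins down what "triangular attention" means numerically.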
This commit is contained in:
Philippe Tillet
2022-02-06 18:00:45 -08:00
committed by GitHub
parent 69ff52ea1f
commit 5a8a544d10
4 changed files with 311 additions and 361 deletions


@@ -32,6 +32,19 @@ def sparsify_tensor(x, mask, block):
return ret
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
    if data is None:
        data = torch.randn(shape, dtype=torch.float32, device=device)
    ref_ret = data
    ref_ret = ref_ret * alpha + beta
    ref_ret = ref_ret.half().float()
    if trans:
        ref_ret = ref_ret.t().requires_grad_()
    ref_ret = ref_ret.detach().requires_grad_()
    tri_ret = ref_ret.clone().detach().requires_grad_()
    return ref_ret, tri_ret
def cutlass_matmul(a, b):
    if _cutlass is None:
        raise RuntimeError("Cannot find cutlass library")