mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)
* dds layout now internally re-uses dsd code path for increased code * at_mask and kp_mask related things are now dropped from the softmax API. I couldn't think of any case where it was needed beyond is_causal. And if there is any, we should probably find a way to get it implemented statically so that users don't have to materialize masks. * fixed bug in blocksparse matmul that caused troubles when layout had a full row/col of zeros * blocksparse softmax now no longer modifies any data in-place * blocksparse softmax now takes an is_dense arguments that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention. * unit tests now test backward pass
This commit is contained in:
@@ -32,6 +32,19 @@ def sparsify_tensor(x, mask, block):
|
||||
return ret
|
||||
|
||||
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
|
||||
if data is None:
|
||||
data = torch.randn(shape, dtype=torch.float32, device=device)
|
||||
ref_ret = data
|
||||
ref_ret = ref_ret * alpha + beta
|
||||
ref_ret = ref_ret.half().float()
|
||||
if trans:
|
||||
ref_ret = ref_ret.t().requires_grad_()
|
||||
ref_ret = ref_ret.detach().requires_grad_()
|
||||
tri_ret = ref_ret.clone().detach().requires_grad_()
|
||||
return ref_ret, tri_ret
|
||||
|
||||
|
||||
def cutlass_matmul(a, b):
|
||||
if _cutlass is None:
|
||||
raise RuntimeError("Cannot find cutlass library")
|
||||
|
||||
Reference in New Issue
Block a user