mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FA] Upstream FA qk initialization (#328)
This PR replaces qk matrix initialization with upstream version
This commit is contained in:
@@ -200,10 +200,10 @@ def _bwd_kernel(
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
if CAUSAL:
|
||||
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
qk = tl.where(P_SEQ + offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
||||
else:
|
||||
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
|
||||
# compute dv
|
||||
|
||||
Reference in New Issue
Block a user