[FA] Upstream FA qk initialization (#328)

This PR replaces qk matrix initialization with
upstream version
This commit is contained in:
Alexander Efimov
2023-09-14 07:34:14 +02:00
committed by GitHub
parent 23465f3416
commit b25557ad5e

View File

@@ -200,10 +200,10 @@ def _bwd_kernel(
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
if CAUSAL:
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
qk = tl.where(P_SEQ + offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
else:
qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
# compute dv