mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
I've added an option to yapf to do what we want for long lines, see https://github.com/google/yapf/pull/1177. We can now have a real Python formatter, yay! To make this PR, I ran my modified yapf over the repository, then looked over the full diff. Where yapf was mangling the param list of long function decls/calls (mostly kernels), I manually added `#` to put linebreaks where we want. I fixed up other formatting too -- mostly adding or removing a trailing comma from lists. Overall, a trailing `#` was sufficient to get formatting similar to our current code. I didn't have to disable yapf anywhere. --------- Co-authored-by: Phil Tillet <phil@openai.com>
97 lines
3.4 KiB
Python
97 lines
3.4 KiB
Python
import torch
|
|
|
|
from .. import heuristics, jit
|
|
from .. import language as tl
|
|
from .. import next_power_of_2
|
|
|
|
|
|
def num_warps(N):
    """Pick a launch warp count that grows with the row width ``N``.

    Wider rows get more warps so the reduction work per warp stays bounded.
    """
    for bound, warps in ((2048, 4), (8192, 8)):
        if N < bound:
            return warps
    return 16
|
|
|
|
|
|
# Forward kernel: each program instance handles one row of LOGITS.
# It writes the whole row of negative log-probabilities into PROBS and the
# scalar loss (the neg-log-prob at the target index) into LOSS[row].
# BLOCK is the row width rounded up to a power of two; num_warps scales with N.
@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)  # target class index for this row
    # pointers to logit and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs
    # out-of-range lanes load -inf so they contribute exp(-inf) = 0 to the sum
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    # subtract the row max for numerical stability before exponentiating
    logits = logits - tl.max(logits, 0)
    # log-sum-exp minus the shifted logits gives -log(softmax) per column
    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
    tl.store(WRIT_PROBS, probs, mask=cols < N)
    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    tl.debug_barrier()
    # write-back loss: re-read the neg-log-prob at the target index
    probs = tl.load(READ_PROBS)
    tl.store(LOSS + row, probs)
|
|
|
|
|
|
# Backward kernel: each program instance handles one row of PROBS (which
# holds the neg-log-probs written by _forward) and overwrites it in place
# with the gradient of the loss w.r.t. the logits, scaled by DPROBS[row].
@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)  # target class index for this row
    # pointers to probs
    PROBS = PROBS + row * N + cols
    # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
    # and we have -log(p[k]) stored in PROBS, so this is easy
    # out-of-range lanes negate +inf to -inf, so exp() maps them to 0
    probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
    probs = tl.exp(probs.to(tl.float32))  # recover p[k] from -log(p[k])
    delta = cols == idx  # one-hot indicator of the target column
    # write result in-place in PROBS
    dout = tl.load(DPROBS + row)  # upstream gradient for this row's loss
    din = (probs - delta) * dout
    tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
|
|
|
|
|
|
class _cross_entropy(torch.autograd.Function):
    """Autograd bridge around the Triton forward/backward kernels."""

    @classmethod
    def forward(cls, ctx, logits, indices):
        """Return the per-row cross-entropy losses for ``logits`` vs ``indices``."""
        # Kernels index with int64, so reject anything else up front.
        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
        n_cols = logits.shape[-1]
        n_rows = logits.numel() // n_cols
        # Output buffers: one scalar loss per row, plus the full matrix of
        # negative log-probs that the backward pass will reuse.
        losses = torch.empty_like(indices, dtype=logits.dtype, device=logits.device)
        neg_log_probs = torch.empty_like(logits, dtype=logits.dtype, device=logits.device)

        def grid(opt):
            return (n_rows, )

        _forward[grid](logits, neg_log_probs, indices, losses, n_cols)
        # Stash what backward() needs.
        ctx.save_for_backward(neg_log_probs, indices)
        return losses

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """Produce the gradient of the losses w.r.t. the logits.

        Since d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k] and the forward
        pass saved -log(p[k]), exponentiating the saved tensor recovers p[k];
        the kernel then rewrites that same tensor in place with the finished
        gradient, so no extra allocation is needed.
        """
        neg_log_probs, indices = ctx.saved_tensors
        n_cols = neg_log_probs.shape[-1]

        def grid(opt):
            return (neg_log_probs.numel() // n_cols, )

        _backward[grid](neg_log_probs, indices, dneg_logprobs, n_cols)
        # Gradient flows only to the logits; the indices get None.
        return neg_log_probs, None
|
|
|
|
|
|
# Public functional entry point: cross_entropy(logits, indices) -> per-row losses.
cross_entropy = _cross_entropy.apply
|