Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)
* Simplified the `triton.kernel` API to achieve lower latency:
  * Each tensor's `.data_ptr()` must now be passed as the kernel argument; `torch.tensor` arguments are no longer converted implicitly.
  * Compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`.
  * The `torch.device` must now be passed explicitly to `triton.kernel` (it is no longer inferred from `torch.tensor` arguments).
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added the `python/triton/ops/` package for pre-written Triton ops
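__global__ void forward(TYPE* X, TYPE* Y) {
  // TYPE and BLOCK are not defined here; they are supplied as compile-time constants
  int pid = get_program_id(0);                  // index of this program instance
  int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;   // the BLOCK contiguous offsets owned by this instance
  float x[BLOCK] = *(X + off);                  // load the block
  float shifted[BLOCK] = exp(x - x[max]);       // subtract the block maximum for numerical stability
  float sum = shifted[+];                       // sum reduction over the block
  *(Y + off) = shifted / sum;                   // store the normalized block
}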
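The kernel subtracts `x[max]` before exponentiating. That is safe because softmax is unchanged when a constant is subtracted from every input, which the following identity shows for one block of inputs $x_1, \dots, x_{\text{BLOCK}}$ with $m = \max_j x_j$:

$$
\operatorname{softmax}(x)_i
= \frac{e^{x_i}}{\sum_j e^{x_j}}
= \frac{e^{x_i - m}\, e^{m}}{\sum_j e^{x_j - m}\, e^{m}}
= \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
$$

In the kernel, `shifted` holds $e^{x_i - m}$ and `sum` holds $\sum_j e^{x_j - m}$. Every shifted exponent is at most 0, so each term lies in $(0, 1]$ and the sum is at least 1. This keeps `exp` from overflowing. Because each program instance reduces only its own `BLOCK` elements, the normalization is applied per block.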