mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
* Simplified `triton.kernel` API to achieve lower latency:
> .data_ptr() must now be passed explicitly as a kernel argument; torch.tensor
arguments are no longer converted implicitly
> compilation options are now constant attributes, i.e., opt.d('VAR')
becomes opt.VAR
> torch.device must now be passed explicitly to triton.kernel (no
longer inferred from torch.tensor arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
57 lines
1.9 KiB
Python
57 lines
1.9 KiB
Python
import torch
|
|
import triton
|
|
import os
|
|
|
|
class _conv(torch.autograd.Function):
    """Autograd function that launches the Triton kernel defined in ``conv.c``.

    Only ``forward`` is defined here; no ``backward`` is implemented in this
    class, and ``ctx`` is not used.
    """

    # Kernel source, read once at class-creation time from the ``conv.c``
    # file sitting next to this module.
    src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
    # Cache of ``(delta, compiled kernel)`` pairs. The key is built in
    # ``forward`` from every value the pair depends on.
    kernel = dict()

    @staticmethod
    def unpack(IDX, CI, R, S):
        """Split a flat index into its ``(ci, r, s)`` components.

        ``IDX`` is decomposed as ``(ci * R + r) * S + s``. ``CI`` is accepted
        for signature symmetry but is not used, so ``ci`` is not wrapped or
        clamped to ``CI``.

        Works on anything supporting ``%`` and ``//`` (Python ints or
        tensors, as used by ``forward``).
        """
        s = IDX % S
        cr = IDX // S
        r = cr % R
        ci = cr // R
        return ci, r, s

    @staticmethod
    def forward(ctx, a, b, pad, stride):
        """Run the kernel on ``a`` and ``b`` and return the output tensor.

        Args:
            ctx: autograd context (unused).
            a: 4-D tensor whose shape is unpacked as ``(Z, CI, H, W)``.
            b: 4-D tensor whose shape is unpacked as ``(_, R, S, CO)``.
            pad: pair ``(pad_h, pad_w)``.
            stride: pair ``(stride_h, stride_w)``.

        Returns:
            A new tensor of shape ``(Z, CO, P, Q)`` with ``a``'s dtype and
            device, where ``P`` and ``Q`` are the output extents computed
            below.
        """
        dtype = a.dtype
        device = a.device
        # shapes
        Z, CI, H, W = a.shape
        _, R, S, CO = b.shape
        P = (H + 2*pad[0] - R)//stride[0] + 1
        Q = (W + 2*pad[1] - S)//stride[1] + 1
        # compile kernel
        # The compiled kernel bakes H, W, P, Q, R, S into its defines, and the
        # delta table depends on CI, R, S and a's strides. All of them must be
        # part of the cache key, and the same key must be used for the
        # membership test, the store and the load. (Previously the test used
        # (dtype, device) while the store/load used dtype, so the cache never
        # hit and the kernel was rebuilt on every call.)
        key = (dtype, device, H, W, P, Q, R, S, CI,
               a.stride(1), a.stride(2), a.stride(3))
        if key not in _conv.kernel:
            TK = 16
            defines = {
                'TYPE': dtype,
                'TM': [32, 64, 128],
                'TN': [32, 64, 128],
                'TK': [TK],
                'TZ': [1],
                'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
            }
            # For each flat (ci, r, s) index, the offset into `a` (in
            # elements) between that index and the one TK positions later.
            idx = torch.arange(CI*R*S)
            ci, r, s = _conv.unpack(idx, CI, R, S)
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
            # Place the table on the same device the kernel is built for,
            # rather than on whatever the current CUDA device happens to be.
            delta = delta.type(torch.int32).to(device)
            _conv.kernel[key] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines))
        delta, kernel = _conv.kernel[key]
        # allocate output
        c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
        # enqueue
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S,
               pad[0], pad[1], stride[0], stride[1],
               delta.data_ptr(),
               a.stride(0), a.stride(1), a.stride(2), a.stride(3),
               b.stride(0), b.stride(1), b.stride(2), b.stride(3),
               c.stride(0), c.stride(1), c.stride(2), c.stride(3),
               grid = lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)])
        return c
|
|
|
|
# Public entry point: call as conv(a, b, pad, stride).
conv = _conv.apply