Files
ROCm/python/triton/ops/conv.py
Philippe Tillet 269ebc12e5 [PYTHON][TESTS][DOC] Various improvements to the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency:
  > .data_ptr() must now be passed as kernel argument. No more implicit
conversion from torch.tensor
  > compilation options are now constant attributes, i.e., opt.d('VAR')
becomes opt.VAR
  > torch.device must now be passed explicitly to triton.kernel (no
longer inferred from torch.tensor arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in python/tutorials/
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* added python/triton/ops/ package for pre-written Triton ops
2021-07-27 12:38:48 -07:00

57 lines
1.9 KiB
Python

import torch
import triton
import os
class _conv(torch.autograd.Function):
    """Forward 2D convolution implemented with the Triton kernel in ``conv.c``.

    Only ``forward`` is defined here; no ``backward`` is provided in this
    class, so the operation is not differentiable through autograd as written.
    """

    # Kernel source, read once at class-definition time from the `conv.c` file
    # located next to this module.
    src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
    # Cache of `(delta, kernel)` pairs. The key contains every quantity the
    # cached entry was specialized on (see `forward`).
    kernel = dict()

    @staticmethod
    def unpack(IDX, CI, R, S):
        """Split a flattened reduction index into ``(ci, r, s)`` coordinates.

        The index is interpreted as ``(ci * R + r) * S + s``, i.e. ``s`` varies
        fastest. Only ``%`` and ``//`` are used, so ``IDX`` may be a plain
        integer or an integer tensor (the result has the same kind).

        Args:
            IDX: flattened index (or tensor of indices).
            CI: accepted for signature symmetry; not used by the computation.
            R: extent of the ``r`` axis.
            S: extent of the ``s`` axis.

        Returns:
            Tuple ``(ci, r, s)``.
        """
        s = IDX % S
        cr = IDX // S
        r = cr % R
        ci = cr // R
        return ci, r, s

    @staticmethod
    def forward(ctx, a, b, pad, stride):
        """Run the convolution kernel and return the output tensor.

        Args:
            ctx: autograd context (unused).
            a: input tensor whose shape is unpacked as ``(Z, CI, H, W)``.
            b: filter tensor whose shape is unpacked as ``(_, R, S, CO)``;
                the leading dimension is ignored here (presumably ``CI`` --
                confirm against ``conv.c``).
            pad: pair ``(pad_h, pad_w)``.
            stride: pair ``(stride_h, stride_w)``.

        Returns:
            A new tensor of shape ``(Z, CO, P, Q)`` with the dtype and device
            of ``a``.
        """
        dtype = a.dtype
        device = a.device
        # shapes
        Z, CI, H, W = a.shape
        _, R, S, CO = b.shape
        P = (H + 2*pad[0] - R)//stride[0] + 1
        Q = (W + 2*pad[1] - S)//stride[1] + 1
        # The compiled kernel bakes H, W, P, Q, R and S in as defines, and the
        # `delta` table depends on CI*R*S and on the strides of `a`. All of
        # them must therefore be part of the cache key; the lookup and the
        # store must also use the *same* key, otherwise the cache never hits.
        # NOTE: the cache grows by one entry per distinct configuration.
        key = (dtype, device, CI, H, W, R, S, P, Q,
               a.stride(1), a.stride(2), a.stride(3))
        # compile kernel
        if key not in _conv.kernel:
            TK = 16
            defines = {
                'TYPE': dtype,
                'TM': [32, 64, 128],
                'TN': [32, 64, 128],
                'TK': [TK],
                'TZ': [1],
                'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
            }
            # For every flattened (ci, r, s) index, precompute how far the
            # offset into `a` moves when that index advances by TK.
            idx = torch.arange(CI*R*S)
            ci, r, s = _conv.unpack(idx, CI, R, S)
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
            # Place the table on the same device as the input rather than on
            # whatever the current default CUDA device happens to be.
            delta = delta.to(device=device, dtype=torch.int32)
            _conv.kernel[key] = (
                delta,
                triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines),
            )
        delta, kernel = _conv.kernel[key]
        # allocate output
        c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
        # enqueue
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S,
               pad[0], pad[1], stride[0], stride[1],
               delta.data_ptr(),
               a.stride(0), a.stride(1), a.stride(2), a.stride(3),
               b.stride(0), b.stride(1), b.stride(2), b.stride(3),
               c.stride(0), c.stride(1), c.stride(2), c.stride(3),
               grid=lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)])
        return c
# Functional entry point: calling `conv(a, b, pad, stride)` goes through
# `torch.autograd.Function.apply`, which invokes `_conv.forward`.
conv = _conv.apply