Files
ROCm/python/test/backend/test_device_backend.py
Wang Weihan b27a91a113 [FRONTEND] Enable Triton to support registering third-party backends at runtime (#1643)
This PR provides a mechanism for registering a third-party backend
at runtime so that it can generate backend-specific code.

The mechanism provides a common class that abstracts the third-party
backend logic, plus two essential functions to register and retrieve a
third-party backend at runtime.

- `BaseBackend`: A common class to abstract the third-party backend
logic
- `register_backend`: Register a third-party backend with a given device
type
- `get_backend`: Get the third-party backend with a given device type

Generally, a third-party backend must inherit from `BaseBackend` and
implement all the member functions according to the backend
characteristics. Once the backend implementation is ready, the
third-party backend can invoke `register_backend` to register itself under a
given device type. During kernel compilation and execution, the mechanism
retrieves the registered backend to generate the kernel and launcher code
for that device.

This PR adds a dummy backend that simulates a third-party backend and
demonstrates its usage.

-
[test_device_backend.py](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1):
Defines a third-party backend and registers it
-
[ExtensionBackend](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R123):
Inherits from `BaseBackend` and implements backend-specific logic such as
[filtering out some compile
stages](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R129-R135)
- [Register the `ExtensionBackend` for
`CPU`](https://github.com/openai/triton/pull/1643/files#diff-bbe4d50624f2d11bf17c878a1ed4d422918c124c182cf9357b993240c385bea1R279)
  
-
[extension_backend.c](https://github.com/openai/triton/pull/1643/files#diff-169c1d08b3a0a7b343cfa3258fbc32b47e0f6c46305a112652fa1bdaaec89d29):
Provides utility functions to load the kernel binary and get the
backend properties.
2023-06-09 09:09:59 -07:00

263 lines
8.5 KiB
Python

import functools
import hashlib
import importlib
import os
import shutil
import subprocess
import sysconfig
import tempfile
from pathlib import Path
import setuptools
import torch
import triton
import triton.language as tl
from triton.common.backend import BaseBackend, register_backend
from triton.common.build import quiet
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
from triton.runtime.driver import DriverBase
from triton.runtime.jit import version_key
def build_for_backend(name, src, srcdir):
    """Compile the C source file ``src`` into a Python extension module.

    A direct compiler invocation is tried first (``$CC``, else gcc, else
    clang). If that compiler exits with a non-zero status, the module is
    built again with setuptools as a fallback.

    Args:
        name: Extension module name; also the stem of the output file.
        src: Path of the C source file to compile.
        srcdir: Directory added to the include path and used as the output
            (and setuptools build) directory.

    Returns:
        Path of the shared object, ``{srcdir}/{name}{EXT_SUFFIX}``.

    Raises:
        RuntimeError: If ``CC`` is unset and neither gcc nor clang is on PATH.
        Errors raised by the setuptools fallback build propagate unchanged.
    """
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
    # try to avoid setuptools if possible
    cc = os.environ.get("CC")
    if cc is None:
        # TODO: support more things here.
        clang = shutil.which("clang")
        gcc = shutil.which("gcc")
        cc = gcc if gcc is not None else clang
        if cc is None:
            raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
    # This function was renamed and made public in Python 3.10
    if hasattr(sysconfig, 'get_default_scheme'):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == 'posix_local':
        scheme = 'posix_prefix'
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
    # Use subprocess.call rather than check_call: check_call raises
    # CalledProcessError on a non-zero exit status, which made the setuptools
    # fallback below unreachable.
    ret = subprocess.call([cc, src, f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so])
    if ret == 0:
        return so
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = []
    include_dirs = [srcdir]
    libraries = []
    # extra arguments
    extra_link_args = []
    # create extension module
    ext = setuptools.Extension(
        name=name,
        language='c',
        sources=[src],
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args + ['-O3'],
        extra_link_args=extra_link_args,
        library_dirs=library_dirs,
        libraries=libraries,
    )
    # Build the extension quietly; temporary objects and the final shared
    # object both go to srcdir so the returned path is where the .so lands.
    script_args = ['build_ext', '--build-temp=' + srcdir, '--build-lib=' + srcdir, '-q']
    with quiet():
        setuptools.setup(name=name, ext_modules=[ext], script_args=script_args)
    return so
class ExtensionUtils:
    """Process-wide singleton exposing the helpers built from ``extension_backend.c``.

    On first construction, the C file next to this module is compiled, or
    fetched from Triton's cache, which is keyed by the MD5 of the source.
    The module's ``load_binary`` and ``get_device_properties`` functions are
    then bound as attributes.
    """

    # Python re-runs ``__init__`` on the shared instance every time
    # ``ExtensionUtils()`` is called. This flag makes the load happen once;
    # it is only set after a successful load, so a failed attempt is retried.
    _initialized = False

    def __new__(cls):
        if not hasattr(cls, 'instance'):
            cls.instance = super(ExtensionUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        if self._initialized:
            return
        dirname = os.path.dirname(os.path.realpath(__file__))
        src = Path(os.path.join(dirname, "extension_backend.c")).read_text()
        # The source hash is only used as a cache key, not for security.
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = get_cache_manager(key)
        fname = "ext_utils.so"
        cache_path = cache.get_file(fname)
        if cache_path is None:
            # Cache miss: compile in a scratch dir and store the result in the cache.
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = build_for_backend("ext_utils", src_path, tmpdir)
                with open(so, "rb") as f:
                    cache_path = cache.put(f.read(), fname, binary=True)
        import importlib.util
        spec = importlib.util.spec_from_file_location("ext_utils", cache_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
        self._initialized = True
class ExtensionDriver(DriverBase):
    """Singleton driver for the dummy extension backend.

    Every ``ExtensionDriver()`` call yields the same shared instance. Its
    ``utils`` attribute holds the ``ExtensionUtils`` helpers.
    """

    def __new__(cls):
        # Hand back the shared instance once it exists.
        if hasattr(cls, 'instance'):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        # Runs on every ExtensionDriver() call. ExtensionUtils is itself a
        # singleton, so this rebinds the same helper object each time.
        self.utils = ExtensionUtils()
class ExtensionBackend(BaseBackend):
    """Dummy third-party backend used to exercise the backend-registration API.

    Compilation stops after the ``ttgir`` stage, and the generated launcher
    stub only counts how many times it was invoked.
    """

    # Path of the most recently built or cached launcher stub. It is stored on
    # the class so callers can load the same shared object afterwards.
    stub_so_path = ""

    def __init__(self, device_type: str) -> None:
        super(ExtensionBackend, self).__init__(device_type)
        self.driver = ExtensionDriver()

    def add_stages(self, arch, extern_libs, stages):
        """Remove every compile stage except ``ast``, ``ttir`` and ``ttgir`` (in place)."""
        filter_in_stages = ("ast", "ttir", "ttgir")
        # Collect the keys first: the dict must not change size while it is iterated.
        for filter_out_key in [key for key in stages if key not in filter_in_stages]:
            stages.pop(filter_out_key)

    def add_meta_info(self, ir, cur_module, next_module, metadata, asm):
        """Record a fixed kernel name in ``metadata``."""
        metadata["name"] = "extension_backend_name"

    def get_driver(self):
        return self.driver

    def get_stream(self):
        return ""

    # Cached per (instance, device); the device argument is otherwise unused.
    @functools.lru_cache(None)
    def get_device_properties(self, device):
        return self.driver.utils.get_device_properties()

    def get_current_device(self):
        return torch.device("cpu")

    def set_current_device(self, device):
        pass

    def get_load_binary_fn(self):
        return self.driver.utils.load_binary

    def get_kernel_bin(self):
        return "ttgir"

    def get_architecture_descriptor(self, **kwargs):
        return ""

    def make_launcher_stub(self, name, signature, constants):
        """Build (or fetch from cache) the launcher stub and return its path.

        The path is also stored in ``ExtensionBackend.stub_so_path``.
        """
        # The cache key covers the Triton version, signature and constants.
        so_cache_key = make_so_cache_key(version_key(), signature, constants)
        so_cache_manager = get_cache_manager(so_cache_key)
        so_name = f"{name}.so"
        # retrieve stub from cache if it exists
        cache_path = so_cache_manager.get_file(so_name)
        if cache_path is None:
            with tempfile.TemporaryDirectory() as tmpdir:
                src = self._generate_launcher(constants, signature)
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = build_for_backend(name, src_path, tmpdir)
                with open(so, "rb") as f:
                    cache_path = so_cache_manager.put(f.read(), so_name, binary=True)
        type(self).stub_so_path = cache_path
        return cache_path

    def _generate_launcher(self, constants, signature):
        """Return C source for a ``__triton_launcher`` module whose ``launch``
        ignores its arguments and only bumps a counter readable via
        ``launch_counter``."""
        # generate glue code
        src = """
#define __EXTENSION_BACKEND__
#include <Python.h>
#include <stdio.h>

static PyObject* launch_counter(PyObject* self, PyObject* args) {
  // long long + PyLong_FromLongLong: a 64-bit counter would be truncated by
  // PyLong_FromLong on platforms where C long is 32 bits.
  static long long counter = 0;
  counter += 1;
  return PyLong_FromLongLong(counter);
}

static PyObject* launch(PyObject* self, PyObject* args) {
  if (PyErr_Occurred()) {
    return NULL;
  }
  // launch_counter returns a new reference; release it so that every
  // launch does not leak a Python int.
  PyObject* count = launch_counter(self, args);
  if (count == NULL) {
    return NULL;
  }
  Py_DECREF(count);
  // return None
  Py_RETURN_NONE;
}

static PyMethodDef ModuleMethods[] = {
  {"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"},
  {"launch_counter", launch_counter, METH_VARARGS, "Entry point to get launch counter"},
  {NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {
  PyModuleDef_HEAD_INIT,
  "__triton_launcher",
  NULL, //documentation
  -1, //size
  ModuleMethods
};

PyMODINIT_FUNC PyInit___triton_launcher(void) {
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {
    return NULL;
  }
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}
"""
        return src
def test_dummy_backend():
    """Register ExtensionBackend for "cpu", launch a kernel through it and
    check that every launch reaches the generated launcher stub."""
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly.
    import importlib.util

    register_backend("cpu", ExtensionBackend)

    @triton.jit
    def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
        xnumel = 10
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)

    inp = torch.randn(10)
    out = torch.randn(10)
    kernel[(10,)](inp, out, 10, XBLOCK=16)

    # Load the launcher stub built by ExtensionBackend.make_launcher_stub so
    # its launch counter can be queried.
    spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    launch_counter = getattr(mod, "launch_counter")

    # launch_counter() increments the counter itself, so asserting `> 0`
    # would always pass; compare snapshots taken around the launches instead.
    num_launches = 100
    before = launch_counter()
    for _ in range(num_launches):
        kernel[(10,)](inp, out, 10, XBLOCK=16)
    after = launch_counter()
    assert after - before >= num_launches