mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Improve triton hooks (#1256)
Callback interfaces are not changed, just to record more attributes (i.e., `constants`) and simplify invocations
This commit is contained in:
@@ -1246,16 +1246,13 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
PyObject *hook_ret = NULL;
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject *new_args = PyTuple_Pack(1, compiled_kernel);
|
||||
hook_ret = PyObject_CallObject(launch_enter_hook, new_args);
|
||||
Py_DECREF(new_args);
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
|
||||
@@ -1264,19 +1261,9 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
||||
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject *new_args = NULL;
|
||||
if (hook_ret) {{
|
||||
new_args = PyTuple_Pack(2, compiled_kernel, hook_ret);
|
||||
}} else {{
|
||||
new_args = PyTuple_Pack(1, compiled_kernel);
|
||||
}}
|
||||
hook_ret = PyObject_CallObject(launch_exit_hook, new_args);
|
||||
Py_DECREF(new_args);
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
|
||||
if (hook_ret) {{
|
||||
Py_DECREF(hook_ret);
|
||||
}}
|
||||
if(PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
@@ -1543,7 +1530,22 @@ arg_type_pattern = {
|
||||
}
|
||||
|
||||
|
||||
def _get_jsonable_constants(constants):
|
||||
def _is_jsonable(x):
|
||||
try:
|
||||
json.dumps(x)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
return False
|
||||
serialized_constants = {}
|
||||
for constant in constants:
|
||||
if _is_jsonable(constants[constant]):
|
||||
serialized_constants[constant] = constants[constant]
|
||||
return serialized_constants
|
||||
|
||||
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
|
||||
capability = kwargs.get("cc", None)
|
||||
if capability is None:
|
||||
@@ -1616,7 +1618,7 @@ def compile(fn, **kwargs):
|
||||
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
|
||||
metadata = {"num_warps": num_warps, "num_stages": num_stages, "constants": _get_jsonable_constants(constants), "ctime": dict()}
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
metadata["shared"] = kwargs["shared"]
|
||||
@@ -1647,7 +1649,7 @@ def compile(fn, **kwargs):
|
||||
# write-back metadata
|
||||
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
|
||||
# return handle to compiled kernel
|
||||
return CompiledKernel(so_path, metadata, asm)
|
||||
return CompiledKernel(fn, so_path, metadata, asm)
|
||||
|
||||
|
||||
class CompiledKernel:
|
||||
@@ -1656,17 +1658,19 @@ class CompiledKernel:
|
||||
launch_enter_hook = None
|
||||
launch_exit_hook = None
|
||||
|
||||
def __init__(self, so_path, metadata, asm):
|
||||
def __init__(self, fn, so_path, metadata, asm):
|
||||
# initialize launcher
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self.fn = fn
|
||||
spec.loader.exec_module(mod)
|
||||
self.c_wrapper = getattr(mod, "launch")
|
||||
# initialize metadata
|
||||
self.shared = metadata["shared"]
|
||||
self.num_warps = metadata["num_warps"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
self.constants = metadata["constants"]
|
||||
# initialize asm dict
|
||||
self.asm = asm
|
||||
# binaries are lazily initialized
|
||||
|
||||
Reference in New Issue
Block a user