[FRONTEND] switch absolute imports to relative imports in Triton (#1773)

This commit is contained in:
Izzy Putterman
2023-06-14 16:59:24 -07:00
committed by GitHub
parent 700b3cfd6e
commit 71e21f5797
18 changed files with 209 additions and 173 deletions

View File

@@ -13,8 +13,12 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union,
import torch
import triton
from triton.common.backend import get_backend
# import triton
# from .. import compile, CompiledKernel
from ..common.backend import get_backend
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
def get_cuda_stream(idx=None):
@@ -69,7 +73,7 @@ class DependenciesFinder(ast.NodeVisitor):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or lhs is triton:
if lhs is None or lhs.__name__ == "triton":
return None
return getattr(lhs, node.attr)
@@ -104,15 +108,15 @@ def version_key():
with open(__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(*triton.__path__, 'compiler')
compiler_path = os.path.join(TRITON_PATH, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# language
language_path = os.path.join(*triton.__path__, 'language')
language_path = os.path.join(TRITON_PATH, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
@@ -121,7 +125,7 @@ def version_key():
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface(Generic[T]):
@@ -317,6 +321,7 @@ class JITFunction(KernelInterface[T]):
src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
@@ -357,7 +362,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
bin = cache[device].get(key, None)
if bin is not None:
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, {args})
return bin
# kernel not cached -- compile
else:
@@ -375,9 +380,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
self.cache[device][key] = bin
return bin
return None
@@ -390,7 +395,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
"_device_of": self._device_of,
"_pinned_memory_of": self._pinned_memory_of,
"cache": self.cache,
"triton": triton,
"__spec__": __spec__,
"get_backend": get_backend,
"get_current_device": get_current_device,
"set_current_device": set_current_device}