mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] switch absolute imports to relative imports in Triton (#1773)
This commit is contained in:
@@ -13,8 +13,12 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union,
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
from triton.common.backend import get_backend
|
||||
# import triton
|
||||
# from .. import compile, CompiledKernel
|
||||
from ..common.backend import get_backend
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
def get_cuda_stream(idx=None):
|
||||
@@ -69,7 +73,7 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or lhs is triton:
|
||||
if lhs is None or lhs.__name__ == "triton":
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
@@ -104,15 +108,15 @@ def version_key():
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(*triton.__path__, 'compiler')
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# backend
|
||||
with open(triton._C.libtriton.__file__, "rb") as f:
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# language
|
||||
language_path = os.path.join(*triton.__path__, 'language')
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
@@ -121,7 +125,7 @@ def version_key():
|
||||
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
|
||||
except Exception:
|
||||
ptxas_version = ''
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
|
||||
@@ -317,6 +321,7 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
src = f"""
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
@@ -357,7 +362,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
bin = cache[device].get(key, None)
|
||||
if bin is not None:
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, {args})
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
else:
|
||||
@@ -375,9 +380,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
@@ -390,7 +395,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
"_device_of": self._device_of,
|
||||
"_pinned_memory_of": self._pinned_memory_of,
|
||||
"cache": self.cache,
|
||||
"triton": triton,
|
||||
"__spec__": __spec__,
|
||||
"get_backend": get_backend,
|
||||
"get_current_device": get_current_device,
|
||||
"set_current_device": set_current_device}
|
||||
|
||||
Reference in New Issue
Block a user