mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] switch absolute imports to relative v2 (#1833)
This commit is contained in:
@@ -85,10 +85,10 @@ def register_backend(device_type: str, backend_cls: type):
|
||||
|
||||
def get_backend(device_type: str):
|
||||
if device_type not in _backends:
|
||||
device_backend_package_name = f"triton.third_party.{device_type}"
|
||||
if importlib.util.find_spec(device_backend_package_name):
|
||||
device_backend_package_name = f"...third_party.{device_type}"
|
||||
if importlib.util.find_spec(device_backend_package_name, package=__spec__.name):
|
||||
try:
|
||||
importlib.import_module(device_backend_package_name)
|
||||
importlib.import_module(device_backend_package_name, package=__spec__.name)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
|
||||
@@ -11,7 +11,6 @@ from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
# import triton
|
||||
from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin,
|
||||
get_shared_memory_size, ir,
|
||||
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from triton.interpreter import torch_wrapper
|
||||
from . import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# import triton
|
||||
from ..language import core as lcore
|
||||
from . import torch_wrapper
|
||||
from .core import ExecutionContext
|
||||
|
||||
@@ -6,7 +6,6 @@ from functools import wraps
|
||||
from typing import Callable, List, Sequence, TypeVar
|
||||
|
||||
from .._C.libtriton.triton import ir
|
||||
# import triton
|
||||
from ..runtime.jit import jit
|
||||
from . import semantic
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations # remove after python 3.11
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Sequence, Tuple, TypeVar
|
||||
|
||||
from .._C.libtriton.triton import ir
|
||||
from . import core as tl
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@@ -3,9 +3,6 @@ import torch
|
||||
from ... import cdiv, heuristics, jit
|
||||
from ... import language as tl
|
||||
|
||||
# import triton
|
||||
# import language as tl
|
||||
|
||||
# ********************************************************
|
||||
# --------------------------------------------------------
|
||||
# Sparse = Dense x Dense (SDD)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import torch
|
||||
|
||||
# import triton
|
||||
# import language as tl
|
||||
from ... import jit
|
||||
from ... import language as tl
|
||||
from ... import next_power_of_2
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import torch
|
||||
|
||||
# import triton
|
||||
# import language as tl
|
||||
from .. import heuristics, jit
|
||||
from .. import language as tl
|
||||
from .. import next_power_of_2
|
||||
|
||||
@@ -10,9 +10,6 @@ import torch
|
||||
from .. import cdiv, jit
|
||||
from .. import language as tl
|
||||
|
||||
# import triton
|
||||
# import language as tl
|
||||
|
||||
|
||||
@jit
|
||||
def _fwd_kernel(
|
||||
|
||||
@@ -4,9 +4,6 @@ from .. import Config, autotune, cdiv, heuristics, jit
|
||||
from .. import language as tl
|
||||
from .matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# import triton
|
||||
# import language as tl
|
||||
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
@@ -2,7 +2,6 @@ import heapq
|
||||
|
||||
import torch
|
||||
|
||||
# import triton
|
||||
from .. import cdiv
|
||||
from .._C.libtriton.triton import runtime
|
||||
from ..runtime import driver
|
||||
|
||||
@@ -11,8 +11,6 @@ from collections import defaultdict, namedtuple
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
|
||||
# import triton
|
||||
# from .. import compile, CompiledKernel
|
||||
from ..common.backend import get_backend
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -71,7 +69,7 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or getattr(lhs, "__name__", "") == "triton":
|
||||
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
@@ -81,7 +79,7 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if func.__module__ and func.__module__.startswith('triton.'):
|
||||
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
|
||||
return
|
||||
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
|
||||
if func.hash is None:
|
||||
|
||||
Reference in New Issue
Block a user