[FRONTEND] switch absolute imports to relative v2 (#1833)

This commit is contained in:
Izzy Putterman
2023-06-25 21:13:12 -07:00
committed by GitHub
parent a3c39d8fbe
commit 3c400e7818
13 changed files with 7 additions and 26 deletions

View File

@@ -85,10 +85,10 @@ def register_backend(device_type: str, backend_cls: type):
def get_backend(device_type: str):
if device_type not in _backends:
device_backend_package_name = f"triton.third_party.{device_type}"
if importlib.util.find_spec(device_backend_package_name):
device_backend_package_name = f"...third_party.{device_type}"
if importlib.util.find_spec(device_backend_package_name, package=__spec__.name):
try:
importlib.import_module(device_backend_package_name)
importlib.import_module(device_backend_package_name, package=__spec__.name)
except Exception:
return None
else:

View File

@@ -11,7 +11,6 @@ from collections import namedtuple
from pathlib import Path
from typing import Any, Tuple
# import triton
from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin,
get_shared_memory_size, ir,
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import dataclasses
from triton.interpreter import torch_wrapper
from . import torch_wrapper
torch = torch_wrapper.torch

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
# import triton
from ..language import core as lcore
from . import torch_wrapper
from .core import ExecutionContext

View File

@@ -6,7 +6,6 @@ from functools import wraps
from typing import Callable, List, Sequence, TypeVar
from .._C.libtriton.triton import ir
# import triton
from ..runtime.jit import jit
from . import semantic

View File

@@ -3,8 +3,8 @@ from __future__ import annotations # remove after python 3.11
from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar
from .._C.libtriton.triton import ir
from . import core as tl
from triton._C.libtriton.triton import ir
T = TypeVar('T')

View File

@@ -3,9 +3,6 @@ import torch
from ... import cdiv, heuristics, jit
from ... import language as tl
# import triton
# import language as tl
# ********************************************************
# --------------------------------------------------------
# Sparse = Dense x Dense (SDD)

View File

@@ -1,7 +1,5 @@
import torch
# import triton
# import language as tl
from ... import jit
from ... import language as tl
from ... import next_power_of_2

View File

@@ -1,7 +1,5 @@
import torch
# import triton
# import language as tl
from .. import heuristics, jit
from .. import language as tl
from .. import next_power_of_2

View File

@@ -10,9 +10,6 @@ import torch
from .. import cdiv, jit
from .. import language as tl
# import triton
# import language as tl
@jit
def _fwd_kernel(

View File

@@ -4,9 +4,6 @@ from .. import Config, autotune, cdiv, heuristics, jit
from .. import language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time
# import triton
# import language as tl
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()

View File

@@ -2,7 +2,6 @@ import heapq
import torch
# import triton
from .. import cdiv
from .._C.libtriton.triton import runtime
from ..runtime import driver

View File

@@ -11,8 +11,6 @@ from collections import defaultdict, namedtuple
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
overload)
# import triton
# from .. import compile, CompiledKernel
from ..common.backend import get_backend
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -71,7 +69,7 @@ class DependenciesFinder(ast.NodeVisitor):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or getattr(lhs, "__name__", "") == "triton":
if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")):
return None
return getattr(lhs, node.attr)
@@ -81,7 +79,7 @@ class DependenciesFinder(ast.NodeVisitor):
return
if inspect.isbuiltin(func):
return
if func.__module__ and func.__module__.startswith('triton.'):
if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__):
return
assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this"
if func.hash is None: