mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] switch absolute imports to relative imports in Triton (#1773)
This commit is contained in:
@@ -4,7 +4,7 @@ import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ._C.libtriton.triton import runtime
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
@@ -281,7 +281,7 @@ def get_dram_gbps(backend=None, device=None):
|
||||
|
||||
from .runtime import driver
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
backend = runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
mem_clock_khz = driver.utils.get_device_properties(device)["mem_clock_rate"] # in kHz
|
||||
@@ -295,7 +295,7 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None)
|
||||
|
||||
from .runtime import driver
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
backend = runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
@@ -398,7 +398,7 @@ def get_max_simd_tflops(dtype, backend=None, device=None):
|
||||
|
||||
from .runtime import driver
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
backend = runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user