Files
ROCm/python/triton/runtime/driver.py
Justin Lebar df08301e76 Reformat Python code with yapf. (#2589)
I've added an option to yapf to do what we want for long lines, see
https://github.com/google/yapf/pull/1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
2023-11-02 20:44:17 -07:00

192 lines
5.4 KiB
Python

import abc
import hashlib
import os
import tempfile
from pathlib import Path
from ..common.build import _build
from .cache import get_cache_manager
class DriverBase(metaclass=abc.ABCMeta):
    """Common base class for the backend driver objects in this module.

    The two class attributes are the integer tags that the concrete drivers
    store in their ``backend`` attribute.
    """

    # Backend identifiers.
    CUDA = 0
    HIP = 1

    @staticmethod
    def third_party_dir():
        """Return ``<directory of this file>/../third_party``.

        The ``..`` component is kept verbatim; the path is not normalized.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(module_dir, "..", "third_party")

    def __init__(self) -> None:
        """Do nothing; the base class sets up no per-instance state."""
# -----------------------------
# CUDA
# -----------------------------
class CudaUtils(object):
    """Singleton exposing helpers from the module compiled out of ``backends/cuda.c``.

    On construction the C source next to this file is built into
    ``cuda_utils.so`` unless the cache manager already returns a path for it,
    the resulting module is loaded, and selected symbols are copied onto the
    instance.
    """

    def __new__(cls):
        # One object is stored on the class under ``instance`` and handed back
        # on every construction.
        if hasattr(cls, "instance"):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        """Build (when not cached) and load ``cuda_utils``, then bind its symbols."""
        here = os.path.dirname(os.path.realpath(__file__))
        c_source = Path(os.path.join(here, "backends", "cuda.c")).read_text()

        # The cache manager is keyed by the MD5 hex digest of the C source text.
        digest = hashlib.md5(c_source.encode("utf-8")).hexdigest()
        cache = get_cache_manager(digest)
        lib_name = "cuda_utils.so"
        lib_path = cache.get_file(lib_name)

        if lib_path is None:
            # Nothing cached: compile inside a scratch directory and hand the
            # bytes of the built library to the cache.
            with tempfile.TemporaryDirectory() as build_dir:
                c_path = os.path.join(build_dir, "main.c")
                with open(c_path, "w") as c_file:
                    c_file.write(c_source)
                built_lib = _build("cuda_utils", c_path, build_dir)
                with open(built_lib, "rb") as lib_file:
                    lib_path = cache.put(lib_file.read(), lib_name, binary=True)

        import importlib.util

        spec = importlib.util.spec_from_file_location("cuda_utils", lib_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Re-export these module attributes as instance attributes, in order.
        exported_symbols = (
            "load_binary",
            "get_device_properties",
            "CUtensorMapDataType",
            "CUtensorMapInterleave",
            "CUtensorMapSwizzle",
            "CUtensorMapL2promotion",
            "CUtensorMapFloatOOBfill",
            "cuTensorMapEncodeTiled",
            "cuMemAlloc",
            "cuMemcpyHtoD",
            "cuMemFree",
        )
        for symbol in exported_symbols:
            setattr(self, symbol, getattr(module, symbol))
class CudaDriver(DriverBase):
    """Singleton driver whose ``utils`` is ``CudaUtils()`` and whose ``backend`` is ``CUDA``."""

    def __new__(cls):
        # Hand back the object stored on the class, creating it on first use.
        if hasattr(cls, "instance"):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        """Attach the CUDA utilities object and tag the backend."""
        self.utils = CudaUtils()
        self.backend = self.CUDA
# -----------------------------
# HIP
# -----------------------------
class HIPUtils(object):
    """Singleton exposing helpers from the module compiled out of ``backends/hip.c``.

    On construction the C source next to this file is built into
    ``hip_utils.so`` unless the cache manager already returns a path for it,
    the resulting module is loaded, and two of its symbols are copied onto the
    instance.
    """

    def __new__(cls):
        # One object is stored on the class under ``instance`` and handed back
        # on every construction.
        if hasattr(cls, "instance"):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        """Build (when not cached) and load ``hip_utils``, then bind its symbols."""
        here = os.path.dirname(os.path.realpath(__file__))
        c_source = Path(os.path.join(here, "backends", "hip.c")).read_text()

        # The cache manager is keyed by the MD5 hex digest of the C source text.
        digest = hashlib.md5(c_source.encode("utf-8")).hexdigest()
        cache = get_cache_manager(digest)
        lib_name = "hip_utils.so"
        lib_path = cache.get_file(lib_name)

        if lib_path is None:
            # Nothing cached: compile inside a scratch directory and hand the
            # bytes of the built library to the cache.
            with tempfile.TemporaryDirectory() as build_dir:
                c_path = os.path.join(build_dir, "main.c")
                with open(c_path, "w") as c_file:
                    c_file.write(c_source)
                built_lib = _build("hip_utils", c_path, build_dir)
                with open(built_lib, "rb") as lib_file:
                    lib_path = cache.put(lib_file.read(), lib_name, binary=True)

        import importlib.util

        spec = importlib.util.spec_from_file_location("hip_utils", lib_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Re-export these module attributes as instance attributes, in order.
        for symbol in ("load_binary", "get_device_properties"):
            setattr(self, symbol, getattr(module, symbol))
class HIPDriver(DriverBase):
    """Singleton driver whose ``utils`` is ``HIPUtils()`` and whose ``backend`` is ``HIP``."""

    def __new__(cls):
        # Hand back the object stored on the class, creating it on first use.
        if hasattr(cls, "instance"):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        """Attach the HIP utilities object and tag the backend."""
        self.utils = HIPUtils()
        self.backend = self.HIP
class UnsupportedDriver(DriverBase):
    """Singleton placeholder driver: both ``utils`` and ``backend`` are ``None``."""

    def __new__(cls):
        # Hand back the object stored on the class, creating it on first use.
        if hasattr(cls, "instance"):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        """Record that no utilities object and no backend tag are available."""
        self.utils = None
        self.backend = None
# -----------------------------
# Driver
# -----------------------------
class LazyProxy:
    """Stand-in object that builds its target on first use and forwards to it.

    ``init_fn`` is a zero-argument callable. It is invoked the first time an
    attribute is read, set or deleted through the proxy, or when ``str()`` is
    applied to it; the returned object is stored and every later operation is
    forwarded to it. ``repr()`` never triggers construction.

    ``None`` doubles as the "not built yet" marker, so an ``init_fn`` that
    returns ``None`` is called again on the next use.
    """

    def __init__(self, init_fn):
        """Remember ``init_fn`` without calling it."""
        self._init_fn = init_fn
        self._obj = None

    def _initialize_obj(self):
        """Build the target via ``init_fn`` if it has not been built yet."""
        if self._obj is None:
            self._obj = self._init_fn()

    def __getattr__(self, name):
        """Forward reads of attributes not found on the proxy itself.

        Raises:
            AttributeError: if ``name`` is one of the proxy's own fields
                (``_init_fn`` / ``_obj``) and that field is missing, or if the
                target has no such attribute.
        """
        # __getattr__ only runs when normal lookup fails. For the proxy's own
        # fields that means the instance was created without __init__ (for
        # example through cls.__new__, copy or pickle). Falling through to
        # _initialize_obj() would read self._obj and re-enter this method
        # without bound, so report the attribute as missing instead.
        if name in ("_init_fn", "_obj"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        self._initialize_obj()
        return getattr(self._obj, name)

    def __setattr__(self, name, value):
        """Store the proxy's own two fields locally; forward all other writes."""
        if name in ["_init_fn", "_obj"]:
            super().__setattr__(name, value)
        else:
            self._initialize_obj()
            setattr(self._obj, name, value)

    def __delattr__(self, name):
        """Delete ``name`` on the target, building the target first if needed."""
        self._initialize_obj()
        delattr(self._obj, name)

    def __repr__(self):
        """Return the target's ``repr``, or a placeholder while it is unbuilt."""
        if self._obj is None:
            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
        return repr(self._obj)

    def __str__(self):
        """Return ``str()`` of the target, building the target first if needed."""
        self._initialize_obj()
        return str(self._obj)
def initialize_driver():
    """Construct the driver that matches the installed PyTorch build.

    Returns:
        ``HIPDriver()`` when ``torch.version.hip`` is not ``None``; otherwise
        ``CudaDriver()`` when ``torch.cuda.is_available()`` is true; otherwise
        ``UnsupportedDriver()``.
    """
    # Imported inside the function, so torch is only imported when this runs.
    import torch

    # The HIP check comes first and does not consult device availability.
    if torch.version.hip is not None:
        return HIPDriver()
    if torch.cuda.is_available():
        return CudaDriver()
    return UnsupportedDriver()
# Module-level driver handle. Wrapping `initialize_driver` in `LazyProxy`
# defers the `import torch` and the backend selection until `driver` is first
# used (attribute read, write or delete, or `str()`).
driver = LazyProxy(initialize_driver)