diff --git a/autogen_stubs.sh b/autogen_stubs.sh index bec7fa8c68..3a9d8a9297 100755 --- a/autogen_stubs.sh +++ b/autogen_stubs.sh @@ -41,11 +41,15 @@ generate_hip() { #sed -i "s\import ctypes\import ctypes, ctypes.util\g" $BASE/hip.py #sed -i "s\ctypes.CDLL('/opt/rocm/lib/libhiprtc.so')\ctypes.CDLL(ctypes.util.find_library('hiprtc'))\g" $BASE/hip.py #sed -i "s\ctypes.CDLL('/opt/rocm/lib/libamdhip64.so')\ctypes.CDLL(ctypes.util.find_library('amdhip64'))\g" $BASE/hip.py + sed -i "s\import ctypes\import ctypes, os\g" $BASE/hip.py + sed -i "s\'/opt/rocm/\os.getenv('ROCM_PATH', '/opt/rocm/')+'/\g" $BASE/hip.py python3 -c "import tinygrad.runtime.autogen.hip" clang2py /opt/rocm/include/amd_comgr/amd_comgr.h \ --clang-args="-D__HIP_PLATFORM_AMD__ -I/opt/rocm/include -x c++" -o $BASE/comgr.py -l /opt/rocm/lib/libamd_comgr.so fixup $BASE/comgr.py + sed -i "s\import ctypes\import ctypes, os\g" $BASE/comgr.py + sed -i "s\'/opt/rocm/\os.getenv('ROCM_PATH', '/opt/rocm/')+'/\g" $BASE/comgr.py python3 -c "import tinygrad.runtime.autogen.comgr" } @@ -66,6 +70,8 @@ generate_hsa() { --clang-args="-I/opt/rocm/include" \ -o $BASE/hsa.py -l /opt/rocm/lib/libhsa-runtime64.so fixup $BASE/hsa.py + sed -i "s\import ctypes\import ctypes, os\g" $BASE/hsa.py + sed -i "s\'/opt/rocm/\os.getenv('ROCM_PATH', '/opt/rocm/')+'/\g" $BASE/hsa.py python3 -c "import tinygrad.runtime.autogen.hsa" } diff --git a/tinygrad/runtime/autogen/comgr.py b/tinygrad/runtime/autogen/comgr.py index 582107c4ce..7aa19293ba 100644 --- a/tinygrad/runtime/autogen/comgr.py +++ b/tinygrad/runtime/autogen/comgr.py @@ -6,7 +6,7 @@ # POINTER_SIZE is: 8 # LONGDOUBLE_SIZE is: 16 # -import ctypes +import ctypes, os def string_cast(char_pointer, encoding='utf-8', errors='strict'): @@ -29,7 +29,7 @@ def char_pointer_cast(string, encoding='utf-8'): _libraries = {} -_libraries['libamd_comgr.so'] = ctypes.CDLL('/opt/rocm/lib/libamd_comgr.so') +_libraries['libamd_comgr.so'] = ctypes.CDLL(os.getenv('ROCM_PATH', '/opt/rocm/')+'/lib/libamd_comgr.so') c_int128 = ctypes.c_ubyte*16 c_uint128 = c_int128 void = None diff --git a/tinygrad/runtime/autogen/hip.py b/tinygrad/runtime/autogen/hip.py index b2e1a7de8d..fa8dbd1570 100644 --- a/tinygrad/runtime/autogen/hip.py +++ b/tinygrad/runtime/autogen/hip.py @@ -6,7 +6,7 @@ # POINTER_SIZE is: 8 # LONGDOUBLE_SIZE is: 16 # -import ctypes +import ctypes, os class AsDictMixin: @@ -155,7 +155,7 @@ def char_pointer_cast(string, encoding='utf-8'): -_libraries['libamdhip64.so'] = ctypes.CDLL('/opt/rocm/lib/libamdhip64.so') +_libraries['libamdhip64.so'] = ctypes.CDLL(os.getenv('ROCM_PATH', '/opt/rocm/')+'/lib/libamdhip64.so') diff --git a/tinygrad/runtime/autogen/hsa.py b/tinygrad/runtime/autogen/hsa.py index f7e6c9bb35..ce957f81b1 100644 --- a/tinygrad/runtime/autogen/hsa.py +++ b/tinygrad/runtime/autogen/hsa.py @@ -6,7 +6,7 @@ # POINTER_SIZE is: 8 # LONGDOUBLE_SIZE is: 16 # -import ctypes +import ctypes, os def string_cast(char_pointer, encoding='utf-8', errors='strict'): @@ -29,7 +29,7 @@ def char_pointer_cast(string, encoding='utf-8'): _libraries = {} -_libraries['libhsa-runtime64.so'] = ctypes.CDLL('/opt/rocm/lib/libhsa-runtime64.so') +_libraries['libhsa-runtime64.so'] = ctypes.CDLL(os.getenv('ROCM_PATH', '/opt/rocm/')+'/lib/libhsa-runtime64.so') class AsDictMixin: @classmethod def as_dict(cls, self):