fix hip invalid ordinal (#3440)

This commit is contained in:
nimlgen
2024-02-18 16:31:44 +03:00
committed by GitHub
parent 8c0e85fdaf
commit 5647148937

View File

@@ -2,7 +2,7 @@ import ctypes
from typing import Tuple
import tinygrad.runtime.autogen.hip as hip
from tinygrad.helpers import init_c_var, time_execution_cuda_style
from tinygrad.runtime.ops_hip import check
from tinygrad.runtime.ops_hip import check, hip_set_device
from tinygrad.runtime.graph.cuda import CUDAGraph
# TODO: this is only used in graph
@@ -12,7 +12,7 @@ class HIPGraph(CUDAGraph):
def __del__(self):
if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph))
if hasattr(self, 'instance'): check(hip.hipGraphExecDestroy(self.instance))
def set_device(self): check(hip.hipSetDevice(self.device))
def set_device(self): hip_set_device(self.device)
def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
def graph_instantiate(self, graph):