mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
support symbolic jit in HIP (#1877)
This commit is contained in:
@@ -60,8 +60,8 @@ class HIPProgram:
|
||||
start, end = hip.hipEventCreate(), hip.hipEventCreate()
|
||||
hip.hipEventRecord(start)
|
||||
class PackageStruct(ctypes.Structure):
|
||||
_fields_ = [(f'field{idx}', ctypes.c_void_p) for idx in range(len(args))]
|
||||
struct = PackageStruct(*[data._buf for data in args])
|
||||
_fields_ = [(f'field{idx}', ctypes.c_void_p if not isinstance(args[idx], int) else ctypes.c_int) for idx in range(len(args))]
|
||||
struct = PackageStruct(*[data._buf if not isinstance(data, int) else np.int32(data) for data in args])
|
||||
hip.hipModuleLaunchKernel(self.prgs[args[0]._device], global_size[0], global_size[1], global_size[2], local_size[0], local_size[1], local_size[2], 0, 0, struct)
|
||||
if wait:
|
||||
hip.hipEventRecord(end)
|
||||
@@ -80,7 +80,7 @@ typedef float float8 __attribute__((ext_vector_type(8)));
|
||||
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||
extern "C" __global__
|
||||
""", launch_bounds=True,
|
||||
smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4", uses_vload=True, uses_ptr_arithmetic=True,
|
||||
smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4", uses_vload=True, uses_ptr_arithmetic=True, arg_int_prefix = "const int",
|
||||
half_prekernel = "#include <hip/hip_fp16.h>\nusing half4 = HIP_vector_type<half, 4>;" + """
|
||||
__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
|
||||
__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
|
||||
|
||||
Reference in New Issue
Block a user