mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
added parallel cuda execution
This commit is contained in:
123
dSHARK/dshark_driver_module.c
Normal file
123
dSHARK/dshark_driver_module.c
Normal file
@@ -0,0 +1,123 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#include <inttypes.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "iree/base/api.h"
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/tracing.h"
|
||||
#include "iree/hal/cuda/api.h"
|
||||
|
||||
#define IREE_HAL_CUDA_DRIVER_ID 0x43554441u // CUDA
|
||||
|
||||
// Force using CUDA streams until we support command buffer caching to avoid the
|
||||
// overhead of graph creation.
|
||||
IREE_FLAG(
|
||||
bool, cuda_use_streams, true,
|
||||
"Use CUDA streams for executing command buffers (instead of graphs).");
|
||||
|
||||
IREE_FLAG(bool, cuda_allow_inline_execution, false,
|
||||
"Allow command buffers to execute inline against CUDA streams when "
|
||||
"possible.");
|
||||
|
||||
IREE_FLAG(int32_t, cuda_default_index, 0, "Index of the default CUDA device.");
|
||||
|
||||
static iree_status_t iree_hal_cuda_driver_factory_enumerate(
|
||||
void* self, const iree_hal_driver_info_t** out_driver_infos,
|
||||
iree_host_size_t* out_driver_info_count) {
|
||||
// NOTE: we could query supported cuda versions or featuresets here.
|
||||
static const iree_hal_driver_info_t driver_infos[1] = {{
|
||||
.driver_id = IREE_HAL_CUDA_DRIVER_ID,
|
||||
.driver_name = iree_string_view_literal("cuda"),
|
||||
.full_name = iree_string_view_literal("CUDA (dynamic)"),
|
||||
}};
|
||||
*out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
|
||||
*out_driver_infos = driver_infos;
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
static iree_status_t iree_hal_cuda_driver_factory_try_create0(
|
||||
void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
|
||||
iree_hal_driver_t** out_driver) {
|
||||
IREE_ASSERT_ARGUMENT(out_driver);
|
||||
*out_driver = NULL;
|
||||
if (driver_id != IREE_HAL_CUDA_DRIVER_ID) {
|
||||
return iree_make_status(IREE_STATUS_UNAVAILABLE,
|
||||
"no driver with ID %016" PRIu64
|
||||
" is provided by this factory",
|
||||
driver_id);
|
||||
}
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
|
||||
iree_hal_cuda_device_params_t default_params;
|
||||
iree_hal_cuda_device_params_initialize(&default_params);
|
||||
if (FLAG_cuda_use_streams) {
|
||||
default_params.command_buffer_mode =
|
||||
IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM;
|
||||
}
|
||||
default_params.allow_inline_execution = FLAG_cuda_allow_inline_execution;
|
||||
|
||||
iree_hal_cuda_driver_options_t driver_options;
|
||||
iree_hal_cuda_driver_options_initialize(&driver_options);
|
||||
driver_options.default_device_index = 0;
|
||||
|
||||
iree_string_view_t identifier = iree_make_cstring_view("cuda");
|
||||
iree_status_t status = iree_hal_cuda_driver_create(
|
||||
identifier, &default_params, &driver_options, allocator, out_driver);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return status;
|
||||
}
|
||||
static iree_status_t iree_hal_cuda_driver_factory_try_create1(
|
||||
void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
|
||||
iree_hal_driver_t** out_driver) {
|
||||
IREE_ASSERT_ARGUMENT(out_driver);
|
||||
*out_driver = NULL;
|
||||
if (driver_id != IREE_HAL_CUDA_DRIVER_ID) {
|
||||
return iree_make_status(IREE_STATUS_UNAVAILABLE,
|
||||
"no driver with ID %016" PRIu64
|
||||
" is provided by this factory",
|
||||
driver_id);
|
||||
}
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
|
||||
iree_hal_cuda_device_params_t default_params;
|
||||
iree_hal_cuda_device_params_initialize(&default_params);
|
||||
if (FLAG_cuda_use_streams) {
|
||||
default_params.command_buffer_mode =
|
||||
IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM;
|
||||
}
|
||||
default_params.allow_inline_execution = FLAG_cuda_allow_inline_execution;
|
||||
|
||||
iree_hal_cuda_driver_options_t driver_options;
|
||||
iree_hal_cuda_driver_options_initialize(&driver_options);
|
||||
driver_options.default_device_index = 1;
|
||||
|
||||
iree_string_view_t identifier = iree_make_cstring_view("cuda");
|
||||
iree_status_t status = iree_hal_cuda_driver_create(
|
||||
identifier, &default_params, &driver_options, allocator, out_driver);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return status;
|
||||
}
|
||||
|
||||
IREE_API_EXPORT iree_status_t
|
||||
iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry, int index) {
|
||||
if(index == 1){
|
||||
static const iree_hal_driver_factory_t factory = {
|
||||
.self = NULL,
|
||||
.enumerate = iree_hal_cuda_driver_factory_enumerate,
|
||||
.try_create = iree_hal_cuda_driver_factory_try_create1,
|
||||
};
|
||||
return iree_hal_driver_registry_register_factory(registry, &factory);
|
||||
} else {
|
||||
static const iree_hal_driver_factory_t factory = {
|
||||
.self = NULL,
|
||||
.enumerate = iree_hal_cuda_driver_factory_enumerate,
|
||||
.try_create = iree_hal_cuda_driver_factory_try_create0,
|
||||
};
|
||||
return iree_hal_driver_registry_register_factory(registry, &factory);
|
||||
}
|
||||
}
|
||||
24
dSHARK/dshark_driver_module.h
Normal file
24
dSHARK/dshark_driver_module.h
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#ifndef IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
|
||||
#define IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
|
||||
|
||||
#include "iree/base/api.h"
|
||||
#include "iree/hal/api.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
IREE_API_EXPORT iree_status_t
|
||||
iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry, int index);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
#include <libwebsockets.h>
|
||||
#include "run-module.c"
|
||||
#include "iree/tools/iree_translate_lib.h"
|
||||
//#include "iree/compiler/Translation/HALExecutable.h"
|
||||
#include <string.h>
|
||||
#include <signal.h>
|
||||
#if !defined(WIN32)
|
||||
@@ -222,10 +224,16 @@ void sigint_handler(int sig)
|
||||
interrupted = 1;
|
||||
}
|
||||
|
||||
void* run_module_thread(void *i)
|
||||
void* run_module_thread0(void *i)
|
||||
{
|
||||
int index = *(int *)i;
|
||||
int x = run_module(index);
|
||||
int index = *((int *)i);
|
||||
int x = run_module(0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void* run_module_thread1(void *i)
|
||||
{
|
||||
int x = run_module(1);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@@ -233,11 +241,9 @@ int main(int argc, const char **argv)
|
||||
{
|
||||
pthread_t t1, t2;
|
||||
int i1, i2;
|
||||
int32_t ci0 = 0;
|
||||
int32_t ci1 = 1;
|
||||
printf("Before Thread\n");
|
||||
i1 = pthread_create(&t1, NULL, run_module_thread, (void *)0);
|
||||
i2 = pthread_create(&t2, NULL, run_module_thread, (void *)1);
|
||||
i1 = pthread_create(&t1, NULL, run_module_thread0, NULL);
|
||||
i2 = pthread_create(&t2, NULL, run_module_thread1, NULL);
|
||||
pthread_join(t1, NULL);
|
||||
pthread_join(t2, NULL);
|
||||
printf("After Thread\n");
|
||||
|
||||
@@ -20,10 +20,13 @@
|
||||
#include "iree/vm/api.h"
|
||||
#include "iree/vm/bytecode_module.h"
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/hal/cuda/registration/driver_module.h"
|
||||
#include "dshark_driver_module.c"
|
||||
#include "iree/base/tracing.h"
|
||||
#include "iree/hal/cuda/api.h"
|
||||
#include "simple_embedding_test_bytecode_module_cuda_c.h"
|
||||
#include "iree/hal/cuda/cuda_driver.c"
|
||||
#include "iree/hal/cuda/cuda_device.h"
|
||||
#include "iree/hal/cuda/cuda_device.c"
|
||||
|
||||
iree_status_t create_sample_device(iree_allocator_t host_allocator,
|
||||
iree_hal_device_t** out_device, int index) {
|
||||
@@ -34,14 +37,18 @@ iree_status_t create_sample_device(iree_allocator_t host_allocator,
|
||||
// Create the HAL driver from the name.
|
||||
iree_hal_driver_t* driver = NULL;
|
||||
iree_string_view_t identifier = iree_make_cstring_view("cuda");
|
||||
iree_status_t status = iree_hal_driver_registry_try_create_by_name(
|
||||
iree_hal_driver_registry_default(), identifier, host_allocator, &driver);
|
||||
iree_hal_cuda_device_params_t* params;
|
||||
iree_hal_cuda_device_params_initialize(params);
|
||||
//iree_status_t status = iree_hal_cuda_driver_select_default_device(driver, NULL, index, host_allocator, out_device);
|
||||
|
||||
//if (iree_status_is_ok(status)) {
|
||||
// iree_status_t status = iree_hal_driver_registry_try_create_by_name(
|
||||
// iree_hal_driver_registry_default(), identifier, host_allocator, &driver);
|
||||
//}
|
||||
// Create the default device (primary GPU).
|
||||
if (iree_status_is_ok(status)) {
|
||||
status = iree_hal_driver_create_default_device(driver, host_allocator,
|
||||
|
||||
iree_status_t status = iree_hal_cuda_device_create(driver, identifier, params, NULL, index, host_allocator,
|
||||
out_device);
|
||||
}
|
||||
|
||||
iree_hal_driver_release(driver);
|
||||
return iree_ok_status();
|
||||
@@ -65,6 +72,7 @@ iree_status_t Run(int index) {
|
||||
iree_hal_device_t* device = NULL;
|
||||
IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device, index),
|
||||
"create device");
|
||||
//device = (iree_hal_device_t)device;
|
||||
iree_vm_module_t* hal_module = NULL;
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_hal_module_create(device, iree_allocator_system(), &hal_module));
|
||||
@@ -104,7 +112,7 @@ iree_status_t Run(int index) {
|
||||
iree_hal_buffer_view_t* arg0_buffer_view = NULL;
|
||||
iree_hal_buffer_view_t* arg1_buffer_view = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
|
||||
iree_hal_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
|
||||
iree_hal_cuda_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
|
||||
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
|
||||
(iree_hal_buffer_params_t){
|
||||
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
|
||||
@@ -115,7 +123,7 @@ iree_status_t Run(int index) {
|
||||
},
|
||||
iree_make_const_byte_span(kFloat4, sizeof(kFloat4)), &arg0_buffer_view));
|
||||
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
|
||||
iree_hal_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
|
||||
iree_hal_cuda_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
|
||||
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
|
||||
(iree_hal_buffer_params_t){
|
||||
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
|
||||
|
||||
Reference in New Issue
Block a user