added parallel cuda execution

This commit is contained in:
Elias
2022-03-15 00:01:57 +00:00
parent e6115da192
commit 2416936ecd
4 changed files with 176 additions and 15 deletions

View File

@@ -0,0 +1,123 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <inttypes.h>
#include <stddef.h>
#include "iree/base/api.h"
#include "iree/base/internal/flags.h"
#include "iree/base/tracing.h"
#include "iree/hal/cuda/api.h"
#define IREE_HAL_CUDA_DRIVER_ID 0x43554441u // CUDA
// Force using CUDA streams until we support command buffer caching to avoid the
// overhead of graph creation.
IREE_FLAG(
bool, cuda_use_streams, true,
"Use CUDA streams for executing command buffers (instead of graphs).");
IREE_FLAG(bool, cuda_allow_inline_execution, false,
"Allow command buffers to execute inline against CUDA streams when "
"possible.");
IREE_FLAG(int32_t, cuda_default_index, 0, "Index of the default CUDA device.");
static iree_status_t iree_hal_cuda_driver_factory_enumerate(
void* self, const iree_hal_driver_info_t** out_driver_infos,
iree_host_size_t* out_driver_info_count) {
// NOTE: we could query supported cuda versions or featuresets here.
static const iree_hal_driver_info_t driver_infos[1] = {{
.driver_id = IREE_HAL_CUDA_DRIVER_ID,
.driver_name = iree_string_view_literal("cuda"),
.full_name = iree_string_view_literal("CUDA (dynamic)"),
}};
*out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
*out_driver_infos = driver_infos;
return iree_ok_status();
}
static iree_status_t iree_hal_cuda_driver_factory_try_create0(
void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);
*out_driver = NULL;
if (driver_id != IREE_HAL_CUDA_DRIVER_ID) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver with ID %016" PRIu64
" is provided by this factory",
driver_id);
}
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_cuda_device_params_t default_params;
iree_hal_cuda_device_params_initialize(&default_params);
if (FLAG_cuda_use_streams) {
default_params.command_buffer_mode =
IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM;
}
default_params.allow_inline_execution = FLAG_cuda_allow_inline_execution;
iree_hal_cuda_driver_options_t driver_options;
iree_hal_cuda_driver_options_initialize(&driver_options);
driver_options.default_device_index = 0;
iree_string_view_t identifier = iree_make_cstring_view("cuda");
iree_status_t status = iree_hal_cuda_driver_create(
identifier, &default_params, &driver_options, allocator, out_driver);
IREE_TRACE_ZONE_END(z0);
return status;
}
static iree_status_t iree_hal_cuda_driver_factory_try_create1(
void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);
*out_driver = NULL;
if (driver_id != IREE_HAL_CUDA_DRIVER_ID) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver with ID %016" PRIu64
" is provided by this factory",
driver_id);
}
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_cuda_device_params_t default_params;
iree_hal_cuda_device_params_initialize(&default_params);
if (FLAG_cuda_use_streams) {
default_params.command_buffer_mode =
IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM;
}
default_params.allow_inline_execution = FLAG_cuda_allow_inline_execution;
iree_hal_cuda_driver_options_t driver_options;
iree_hal_cuda_driver_options_initialize(&driver_options);
driver_options.default_device_index = 1;
iree_string_view_t identifier = iree_make_cstring_view("cuda");
iree_status_t status = iree_hal_cuda_driver_create(
identifier, &default_params, &driver_options, allocator, out_driver);
IREE_TRACE_ZONE_END(z0);
return status;
}
IREE_API_EXPORT iree_status_t
iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry, int index) {
if(index == 1){
static const iree_hal_driver_factory_t factory = {
.self = NULL,
.enumerate = iree_hal_cuda_driver_factory_enumerate,
.try_create = iree_hal_cuda_driver_factory_try_create1,
};
return iree_hal_driver_registry_register_factory(registry, &factory);
} else {
static const iree_hal_driver_factory_t factory = {
.self = NULL,
.enumerate = iree_hal_cuda_driver_factory_enumerate,
.try_create = iree_hal_cuda_driver_factory_try_create0,
};
return iree_hal_driver_registry_register_factory(registry, &factory);
}
}

View File

@@ -0,0 +1,24 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
#define IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
#include "iree/base/api.h"
#include "iree/hal/api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
IREE_API_EXPORT iree_status_t
iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry, int index);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_

View File

@@ -14,6 +14,8 @@
#include <libwebsockets.h>
#include "run-module.c"
#include "iree/tools/iree_translate_lib.h"
//#include "iree/compiler/Translation/HALExecutable.h"
#include <string.h>
#include <signal.h>
#if !defined(WIN32)
@@ -222,10 +224,16 @@ void sigint_handler(int sig)
interrupted = 1;
}
void* run_module_thread(void *i)
void* run_module_thread0(void *i)
{
int index = *(int *)i;
int x = run_module(index);
int index = *((int *)i);
int x = run_module(0);
return NULL;
}
void* run_module_thread1(void *i)
{
int x = run_module(1);
return NULL;
}
@@ -233,11 +241,9 @@ int main(int argc, const char **argv)
{
pthread_t t1, t2;
int i1, i2;
int32_t ci0 = 0;
int32_t ci1 = 1;
printf("Before Thread\n");
i1 = pthread_create(&t1, NULL, run_module_thread, (void *)0);
i2 = pthread_create(&t2, NULL, run_module_thread, (void *)1);
i1 = pthread_create(&t1, NULL, run_module_thread0, NULL);
i2 = pthread_create(&t2, NULL, run_module_thread1, NULL);
pthread_join(t1, NULL);
pthread_join(t2, NULL);
printf("After Thread\n");

View File

@@ -20,10 +20,13 @@
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "iree/base/internal/flags.h"
#include "iree/hal/cuda/registration/driver_module.h"
#include "dshark_driver_module.c"
#include "iree/base/tracing.h"
#include "iree/hal/cuda/api.h"
#include "simple_embedding_test_bytecode_module_cuda_c.h"
#include "iree/hal/cuda/cuda_driver.c"
#include "iree/hal/cuda/cuda_device.h"
#include "iree/hal/cuda/cuda_device.c"
iree_status_t create_sample_device(iree_allocator_t host_allocator,
iree_hal_device_t** out_device, int index) {
@@ -34,14 +37,18 @@ iree_status_t create_sample_device(iree_allocator_t host_allocator,
// Create the HAL driver from the name.
iree_hal_driver_t* driver = NULL;
iree_string_view_t identifier = iree_make_cstring_view("cuda");
iree_status_t status = iree_hal_driver_registry_try_create_by_name(
iree_hal_driver_registry_default(), identifier, host_allocator, &driver);
iree_hal_cuda_device_params_t* params;
iree_hal_cuda_device_params_initialize(params);
//iree_status_t status = iree_hal_cuda_driver_select_default_device(driver, NULL, index, host_allocator, out_device);
//if (iree_status_is_ok(status)) {
// iree_status_t status = iree_hal_driver_registry_try_create_by_name(
// iree_hal_driver_registry_default(), identifier, host_allocator, &driver);
//}
// Create the default device (primary GPU).
if (iree_status_is_ok(status)) {
status = iree_hal_driver_create_default_device(driver, host_allocator,
iree_status_t status = iree_hal_cuda_device_create(driver, identifier, params, NULL, index, host_allocator,
out_device);
}
iree_hal_driver_release(driver);
return iree_ok_status();
@@ -65,6 +72,7 @@ iree_status_t Run(int index) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device, index),
"create device");
//device = (iree_hal_device_t)device;
iree_vm_module_t* hal_module = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_module_create(device, iree_allocator_system(), &hal_module));
@@ -104,7 +112,7 @@ iree_status_t Run(int index) {
iree_hal_buffer_view_t* arg0_buffer_view = NULL;
iree_hal_buffer_view_t* arg1_buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
iree_hal_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
iree_hal_cuda_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
@@ -115,7 +123,7 @@ iree_status_t Run(int index) {
},
iree_make_const_byte_span(kFloat4, sizeof(kFloat4)), &arg0_buffer_view));
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
iree_hal_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
iree_hal_cuda_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |