Clean up SDL linking

This commit is contained in:
Anush Elangovan
2022-09-14 13:18:55 -07:00
parent cfd9733c2b
commit 174b171913
2 changed files with 80 additions and 84 deletions

View File

@@ -27,27 +27,6 @@ if(NOT SDL2_FOUND)
return() return()
endif() endif()
# Compile simple_mul.mlir to simple_mul.vmfb.
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
set(_COMPILE_ARGS)
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=vulkan-spirv")
list(APPEND _COMPILE_ARGS "${CMAKE_CURRENT_SOURCE_DIR}/simple_mul.mlir")
list(APPEND _COMPILE_ARGS "-o")
list(APPEND _COMPILE_ARGS "simple_mul.vmfb")
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/simple_mul.vmfb
COMMAND ${_COMPILE_TOOL_EXECUTABLE} ${_COMPILE_ARGS}
DEPENDS ${_COMPILE_TOOL_EXECUTABLE} "simple_mul.mlir"
)
# Embed simple_mul.vmfb into a C file as simple_mul_bytecode_module_c.[h/c]
set(_EMBED_DATA_EXECUTABLE $<TARGET_FILE:generate_embed_data>)
set(_EMBED_ARGS)
list(APPEND _EMBED_ARGS "--output_header=simple_mul_bytecode_module_c.h")
list(APPEND _EMBED_ARGS "--output_impl=simple_mul_bytecode_module_c.c")
list(APPEND _EMBED_ARGS "--identifier=iree_samples_vulkan_gui_simple_mul_bytecode_module")
list(APPEND _EMBED_ARGS "--flatten")
list(APPEND _EMBED_ARGS "${CMAKE_CURRENT_BINARY_DIR}/simple_mul.vmfb")
FetchContent_Declare( FetchContent_Declare(
imgui imgui
GIT_REPOSITORY https://github.com/ocornut/imgui GIT_REPOSITORY https://github.com/ocornut/imgui
@@ -77,11 +56,10 @@ target_sources(${_NAME}
) )
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-samples-vulkan-gui") set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-samples-vulkan-gui")
target_include_directories(${_NAME} PUBLIC target_include_directories(${_NAME} PUBLIC
/usr/include/SDL2
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}> $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
) )
target_link_libraries(${_NAME} target_link_libraries(${_NAME}
SDL2 SDL2::SDL2
Vulkan::Vulkan Vulkan::Vulkan
iree_runtime_runtime iree_runtime_runtime
iree_base_internal_main iree_base_internal_main
@@ -91,11 +69,13 @@ target_link_libraries(${_NAME}
iree_vm_bytecode_module iree_vm_bytecode_module
iree_vm_cc iree_vm_cc
) )
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows") if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS") set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
else() else()
set(_GUI_LINKOPTS "") set(_GUI_LINKOPTS "")
endif() endif()
target_link_options(${_NAME} target_link_options(${_NAME}
PRIVATE PRIVATE
${_GUI_LINKOPTS} ${_GUI_LINKOPTS}

View File

@@ -13,14 +13,16 @@
#include <imgui_impl_vulkan.h> #include <imgui_impl_vulkan.h>
#include <vulkan/vulkan.h> #include <vulkan/vulkan.h>
#include <cstring> #include <cstring>
#include <set> #include <set>
#include <vector> #include <vector>
#include "iree/hal/drivers/vulkan/api.h"
// IREE's C API: // IREE's C API:
#include "iree/base/api.h" #include "iree/base/api.h"
#include "iree/hal/api.h" #include "iree/hal/api.h"
#include "iree/hal/drivers/vulkan/api.h"
#include "iree/hal/drivers/vulkan/registration/driver_module.h" #include "iree/hal/drivers/vulkan/registration/driver_module.h"
#include "iree/modules/hal/module.h" #include "iree/modules/hal/module.h"
#include "iree/vm/api.h" #include "iree/vm/api.h"
@@ -30,26 +32,23 @@
// Other dependencies (helpers, etc.) // Other dependencies (helpers, etc.)
#include "iree/base/internal/main.h" #include "iree/base/internal/main.h"
#define IMGUI_UNLIMITED_FRAME_RATE
#define STB_IMAGE_IMPLEMENTATION #define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h" #include "stb_image.h"
// Compiled module embedded here to avoid file IO:
//#include "simple_mul_bytecode_module_c.h"
typedef struct iree_file_toc_t { typedef struct iree_file_toc_t {
const char* name; // the file's original name const char* name; // the file's original name
char* data; // beginning of the file char* data; // beginning of the file
size_t size; // length of the file size_t size; // length of the file
} iree_file_toc_t; } iree_file_toc_t;
bool load_file(const char* filename, char** pOut, size_t* pSize) bool load_file(const char* filename, char** pOut, size_t* pSize)
{ {
FILE* f = fopen(filename, "rb"); FILE* f = fopen(filename, "rb");
if (f == NULL) if (f == NULL)
{ {
printf("Can't open %s\n", filename); fprintf(stderr, "Can't open %s\n", filename);
return false; return false;
} }
@@ -63,7 +62,7 @@ bool load_file(const char* filename, char** pOut, size_t* pSize)
fclose(f); fclose(f);
return *pSize==size; return size != 0;
} }
static VkAllocationCallbacks* g_Allocator = NULL; static VkAllocationCallbacks* g_Allocator = NULL;
@@ -722,6 +721,9 @@ static void CleanupVulkanWindow() {
namespace iree { namespace iree {
extern "C" int iree_main(int argc, char** argv) { extern "C" int iree_main(int argc, char** argv) {
fprintf(stdout, "starting yo\n");
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// Create a window. // Create a window.
if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) { if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
@@ -738,7 +740,14 @@ extern "C" int iree_main(int argc, char** argv) {
SDL_Window* window = SDL_CreateWindow( SDL_Window* window = SDL_CreateWindow(
"IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED, "IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags); SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
if (window == nullptr)
{
const char* sdl_err = SDL_GetError();
fprintf(stderr, "Error, SDL_CreateWindow returned: %s\n", sdl_err);
abort();
return 1;
}
// Setup Vulkan // Setup Vulkan
iree_hal_vulkan_features_t iree_vulkan_features = iree_hal_vulkan_features_t iree_vulkan_features =
static_cast<iree_hal_vulkan_features_t>( static_cast<iree_hal_vulkan_features_t>(
@@ -757,7 +766,8 @@ extern "C" int iree_main(int argc, char** argv) {
VkSurfaceKHR surface; VkSurfaceKHR surface;
VkResult err; VkResult err;
if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) { if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
printf("Failed to create Vulkan surface.\n"); fprintf(stderr, "Failed to create Vulkan surface.\n");
abort();
return 1; return 1;
} }
@@ -887,31 +897,34 @@ extern "C" int iree_main(int argc, char** argv) {
IREE_HAL_MODULE_FLAG_NONE, IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &hal_module)); iree_allocator_system(), &hal_module));
// Load bytecode module from embedded data.
fprintf(stdout, "Loading simple_mul.mlir...\n");
/*
const struct iree_file_toc_t* module_file_toc =
iree_samples_vulkan_gui_simple_mul_bytecode_module_create();
*/
// Load bytecode module
iree_file_toc_t module_file_toc; iree_file_toc_t module_file_toc;
load_file("resnet50_tf.vmfb", &module_file_toc.data, &module_file_toc.size); const char network_model[] = "amd-resnet50.vmfb";
fprintf(stdout, "module size: %lu\n", module_file_toc.size); fprintf(stdout, "Loading: %s\n", network_model);
if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
{
abort();
return 1;
}
fprintf(stdout, "module size: %zu\n", module_file_toc.size);
static float input_res50[224*224*3]; static float input_res50[224*224*3];
static float output_res50[1000]; static float output_res50[1000];
char filename[] = "vulkan_gui/dog_imagenet.jpg"; char filename[] = "dog_imagenet.jpg";
fprintf(stdout, "loading: %s\n", filename); fprintf(stdout, "loading: %s\n", filename);
int x,y,n; int x,y,n;
unsigned char *image_raw = stbi_load(filename, &x, &y, &n, 0); unsigned char *image_raw = stbi_load(filename, &x, &y, &n, 3);
fprintf(stdout, "res: %i x %i x %i\n", x, y, n); fprintf(stdout, "res: %i x %i x %i\n", x, y, n);
//convert image into floating point format
for(int i=0;i<224*224*3;i++) for(int i=0;i<224*224*3;i++)
{ {
input_res50[i]= ((float)image_raw[i])/255.0f; input_res50[i]= ((float)image_raw[i])/255.0f;
} }
// load image again so imgui can display it
int my_image_width = 0; int my_image_width = 0;
int my_image_height = 0; int my_image_height = 0;
VkDescriptorSet my_image_texture = 0; VkDescriptorSet my_image_texture = 0;
@@ -954,7 +967,7 @@ extern "C" int iree_main(int argc, char** argv) {
// Lookup the entry point function. // Lookup the entry point function.
iree_vm_function_t main_function; iree_vm_function_t main_function;
const char kMainFunctionName[] = "module.forward"; const char kMainFunctionName[] = "module.predict";
IREE_CHECK_OK(iree_vm_context_resolve_function( IREE_CHECK_OK(iree_vm_context_resolve_function(
iree_context, iree_context,
iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1}, iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1},
@@ -962,8 +975,46 @@ extern "C" int iree_main(int argc, char** argv) {
iree_string_view_t main_function_name = iree_vm_function_name(&main_function); iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
fprintf(stdout, "Resolved main function named '%.*s'\n", fprintf(stdout, "Resolved main function named '%.*s'\n",
(int)main_function_name.size, main_function_name.data); (int)main_function_name.size, main_function_name.data);
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// Write inputs into mappable buffers.
iree_hal_allocator_t* allocator =
iree_hal_device_allocator(iree_vk_device);
iree_hal_memory_type_t input_memory_type =
static_cast<iree_hal_memory_type_t>(
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
iree_hal_buffer_usage_t input_buffer_usage =
static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
iree_hal_buffer_params_t buffer_params;
buffer_params.type = input_memory_type;
buffer_params.usage = input_buffer_usage;
buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
// Wrap input buffers in buffer views.
iree_hal_buffer_view_t* input0_buffer_view = nullptr;
constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
allocator,
/*shape_rank=*/4, /*shape=*/input_buffer_shape,
IREE_HAL_ELEMENT_TYPE_FLOAT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
&input0_buffer_view));
vm::ref<iree_vm_list_t> inputs;
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
// Prepare outputs list to accept results from the invocation.
vm::ref<iree_vm_list_t> outputs;
constexpr iree_hal_dim_t kOutputCount = 1000;
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// Main loop. // Main loop.
bool done = false; bool done = false;
@@ -1009,46 +1060,8 @@ extern "C" int iree_main(int argc, char** argv) {
// ImGui Inputs for two input tensors. // ImGui Inputs for two input tensors.
// Run computation whenever any of the values changes. // Run computation whenever any of the values changes.
static bool dirty = true; static bool dirty = false;
if (dirty) { if (dirty) {
// Some input values changed, run the computation.
// This is synchronous and doesn't reuse buffers for now.
// Write inputs into mappable buffers.
iree_hal_allocator_t* allocator =
iree_hal_device_allocator(iree_vk_device);
iree_hal_memory_type_t input_memory_type =
static_cast<iree_hal_memory_type_t>(
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
iree_hal_buffer_usage_t input_buffer_usage =
static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
iree_hal_buffer_params_t buffer_params;
buffer_params.type = input_memory_type;
buffer_params.usage = input_buffer_usage;
// Wrap input buffers in buffer views.
iree_hal_buffer_view_t* input0_buffer_view = nullptr;
constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
allocator,
/*shape_rank=*/4, /*shape=*/input_buffer_shape,
IREE_HAL_ELEMENT_TYPE_FLOAT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
&input0_buffer_view));
vm::ref<iree_vm_list_t> inputs;
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
// Prepare outputs list to accept results from the invocation.
vm::ref<iree_vm_list_t> outputs;
constexpr iree_hal_dim_t kOutputCount = 1000;
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
// Synchronously invoke the function. // Synchronously invoke the function.
IREE_CHECK_OK(iree_vm_invoke(iree_context, main_function, IREE_CHECK_OK(iree_vm_invoke(iree_context, main_function,
@@ -1068,9 +1081,11 @@ extern "C" int iree_main(int argc, char** argv) {
output_res50, sizeof(output_res50), output_res50, sizeof(output_res50),
IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout())); IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
// we want to run continuously so we can use tools like RenderDoc, RGP, etc...
dirty = true; dirty = true;
} }
// find maxarg from results
float max = 0.0f; float max = 0.0f;
int max_idx = -1; int max_idx = -1;
for(int i=0;i<1000;i++) for(int i=0;i<1000;i++)
@@ -1090,6 +1105,7 @@ extern "C" int iree_main(int argc, char** argv) {
ImGui::Text("Max idx = [%i]", max_idx); ImGui::Text("Max idx = [%i]", max_idx);
ImGui::Text("Max value = [%f]", max); ImGui::Text("Max value = [%f]", max);
ImGui::Text("Resnet50 categories:");
ImGui::PlotHistogram("Histogram", output_res50, IM_ARRAYSIZE(output_res50), 0, NULL, 0.0f, 1.0f, ImVec2(0,80)); ImGui::PlotHistogram("Histogram", output_res50, IM_ARRAYSIZE(output_res50), 0, NULL, 0.0f, 1.0f, ImVec2(0,80));
ImGui::Separator(); ImGui::Separator();