mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 05:24:00 -05:00
Add iree-run-module like tool for running in a vulkan session (#398)
This commit is contained in:
@@ -40,45 +40,77 @@ set(IMGUI_DIR ${CMAKE_BINARY_DIR}/_deps/imgui-src)
|
||||
message("Looking for Imgui in ${IMGUI_DIR}")
|
||||
include_directories(${IMGUI_DIR} ${IMGUI_DIR}/backends ..)
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "iree-samples-vulkan-gui")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
vulkan_inference_gui.cc
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-samples-vulkan-gui")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
|
||||
function(iree_vulkan_sample)
|
||||
|
||||
cmake_parse_arguments(
|
||||
_RULE
|
||||
""
|
||||
"NAME"
|
||||
"SRCS"
|
||||
${ARGN}
|
||||
)
|
||||
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "${_RULE_NAME}")
|
||||
set(SRCS "${_RULE_SRCS}")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
${SRCS}
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_NAME}")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
iree_tooling_vm_util_cc
|
||||
iree_tooling_context_util
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
)
|
||||
endfunction()
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-samples-resnet-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_resnet_inference_gui.cc
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-vulkan-gui
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
SRCS
|
||||
vulkan_inference_gui.cc
|
||||
)
|
||||
|
||||
message(STATUS "Configured vulkan_gui sample successfully")
|
||||
|
||||
@@ -18,6 +18,12 @@
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <array>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "iree/hal/drivers/vulkan/api.h"
|
||||
|
||||
@@ -30,6 +36,15 @@
|
||||
#include "iree/vm/bytecode_module.h"
|
||||
#include "iree/vm/ref_cc.h"
|
||||
|
||||
// iree-run-module
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/status_cc.h"
|
||||
#include "iree/base/tracing.h"
|
||||
#include "iree/modules/hal/types.h"
|
||||
#include "iree/tooling/comparison.h"
|
||||
#include "iree/tooling/context_util.h"
|
||||
#include "iree/tooling/vm_util_cc.h"
|
||||
|
||||
// Other dependencies (helpers, etc.)
|
||||
#include "iree/base/internal/main.h"
|
||||
|
||||
@@ -38,6 +53,49 @@
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
IREE_FLAG(string, entry_function, "",
|
||||
"Name of a function contained in the module specified by module_file "
|
||||
"to run.");
|
||||
|
||||
// TODO(benvanik): move --function_input= flag into a util.
|
||||
static iree_status_t parse_function_io(iree_string_view_t flag_name,
|
||||
void* storage,
|
||||
iree_string_view_t value) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
list->push_back(std::string(value.data, value.size));
|
||||
return iree_ok_status();
|
||||
}
|
||||
static void print_function_io(iree_string_view_t flag_name, void* storage,
|
||||
FILE* file) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
if (list->empty()) {
|
||||
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
|
||||
} else {
|
||||
for (size_t i = 0; i < list->size(); ++i) {
|
||||
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
|
||||
list->at(i).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
static std::vector<std::string> FLAG_function_inputs;
|
||||
IREE_FLAG_CALLBACK(
|
||||
parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
|
||||
"An input (a) value or (b) buffer of the format:\n"
|
||||
" (a) scalar value\n"
|
||||
" value\n"
|
||||
" e.g.: --function_input=\"3.14\"\n"
|
||||
" (b) buffer:\n"
|
||||
" [shape]xtype=[value]\n"
|
||||
" e.g.: --function_input=\"2x2xi32=1 2 3 4\"\n"
|
||||
"Optionally, brackets may be used to separate the element values:\n"
|
||||
" 2x2xi32=[[1 2][3 4]]\n"
|
||||
"Raw binary files can be read to provide buffer contents:\n"
|
||||
" 2x2xi32=@some/file.bin\n"
|
||||
"numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
|
||||
" @some.npy\n"
|
||||
"Each occurrence of the flag indicates an input in the order they were\n"
|
||||
"specified on the command line.");
|
||||
|
||||
typedef struct iree_file_toc_t {
|
||||
const char* name; // the file's original name
|
||||
char* data; // beginning of the file
|
||||
@@ -87,225 +145,6 @@ static void check_vk_result(VkResult err) {
|
||||
abort();
|
||||
}
|
||||
|
||||
// Helper function to find Vulkan memory type bits. See ImGui_ImplVulkan_MemoryType() in imgui_impl_vulkan.cpp
|
||||
uint32_t findMemoryType(uint32_t type_filter, VkMemoryPropertyFlags properties)
|
||||
{
|
||||
VkPhysicalDeviceMemoryProperties mem_properties;
|
||||
vkGetPhysicalDeviceMemoryProperties(g_PhysicalDevice, &mem_properties);
|
||||
|
||||
for (uint32_t i = 0; i < mem_properties.memoryTypeCount; i++)
|
||||
{
|
||||
if ((type_filter & (1 << i)) && (mem_properties.memoryTypes[i].propertyFlags & properties) == properties)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return 0xFFFFFFFF; // Unable to find memoryType
|
||||
}
|
||||
|
||||
// Helper function to load an image with common settings and return a VkDescriptorSet as a sort of Vulkan pointer
|
||||
bool LoadTextureFromFile(const char* filename, VkDescriptorSet* img_ds, int* image_width, int* image_height)
|
||||
{
|
||||
// Specifying 4 channels forces stb to load the image in RGBA which is an easy format for Vulkan
|
||||
int image_channels = 4;
|
||||
unsigned char* image_data = stbi_load(filename, image_width, image_height, 0, image_channels);
|
||||
|
||||
if (image_data == NULL)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Calculate allocation size (in number of bytes)
|
||||
size_t image_size = (*image_width)*(*image_height)*image_channels;
|
||||
|
||||
VkResult err;
|
||||
|
||||
// Create the Vulkan image.
|
||||
VkImage texture_image;
|
||||
VkDeviceMemory texture_image_memory;
|
||||
{
|
||||
VkImageCreateInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
|
||||
info.imageType = VK_IMAGE_TYPE_2D;
|
||||
info.format = VK_FORMAT_R8G8B8A8_UNORM;
|
||||
info.extent.width = *image_width;
|
||||
info.extent.height = *image_height;
|
||||
info.extent.depth = 1;
|
||||
info.mipLevels = 1;
|
||||
info.arrayLayers = 1;
|
||||
info.samples = VK_SAMPLE_COUNT_1_BIT;
|
||||
info.tiling = VK_IMAGE_TILING_OPTIMAL;
|
||||
info.usage = VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT;
|
||||
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
|
||||
info.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED;
|
||||
err = vkCreateImage(g_Device, &info, g_Allocator, &texture_image);
|
||||
check_vk_result(err);
|
||||
VkMemoryRequirements req;
|
||||
vkGetImageMemoryRequirements(g_Device, texture_image, &req);
|
||||
VkMemoryAllocateInfo alloc_info = {};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
|
||||
alloc_info.allocationSize = req.size;
|
||||
alloc_info.memoryTypeIndex = findMemoryType(req.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
|
||||
err = vkAllocateMemory(g_Device, &alloc_info, g_Allocator, &texture_image_memory);
|
||||
check_vk_result(err);
|
||||
err = vkBindImageMemory(g_Device, texture_image, texture_image_memory, 0);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create the Image View
|
||||
VkImageView image_view;
|
||||
{
|
||||
VkImageViewCreateInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO;
|
||||
info.image = texture_image;
|
||||
info.viewType = VK_IMAGE_VIEW_TYPE_2D;
|
||||
info.format = VK_FORMAT_R8G8B8A8_UNORM;
|
||||
info.subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
info.subresourceRange.levelCount = 1;
|
||||
info.subresourceRange.layerCount = 1;
|
||||
err = vkCreateImageView(g_Device, &info, g_Allocator, &image_view);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create Sampler
|
||||
VkSampler sampler;
|
||||
{
|
||||
VkSamplerCreateInfo sampler_info{};
|
||||
sampler_info.sType = VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO;
|
||||
sampler_info.magFilter = VK_FILTER_LINEAR;
|
||||
sampler_info.minFilter = VK_FILTER_LINEAR;
|
||||
sampler_info.mipmapMode = VK_SAMPLER_MIPMAP_MODE_LINEAR;
|
||||
sampler_info.addressModeU = VK_SAMPLER_ADDRESS_MODE_REPEAT; // outside image bounds just use border color
|
||||
sampler_info.addressModeV = VK_SAMPLER_ADDRESS_MODE_REPEAT;
|
||||
sampler_info.addressModeW = VK_SAMPLER_ADDRESS_MODE_REPEAT;
|
||||
sampler_info.minLod = -1000;
|
||||
sampler_info.maxLod = 1000;
|
||||
sampler_info.maxAnisotropy = 1.0f;
|
||||
err = vkCreateSampler(g_Device, &sampler_info, g_Allocator, &sampler);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Create Descriptor Set using ImGUI's implementation
|
||||
*img_ds = ImGui_ImplVulkan_AddTexture(sampler, image_view, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
|
||||
|
||||
// Create Upload Buffer
|
||||
VkBuffer upload_buffer;
|
||||
VkDeviceMemory upload_buffer_memory;
|
||||
{
|
||||
VkBufferCreateInfo buffer_info = {};
|
||||
buffer_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
|
||||
buffer_info.size = image_size;
|
||||
buffer_info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
|
||||
buffer_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
|
||||
err = vkCreateBuffer(g_Device, &buffer_info, g_Allocator, &upload_buffer);
|
||||
check_vk_result(err);
|
||||
VkMemoryRequirements req;
|
||||
vkGetBufferMemoryRequirements(g_Device, upload_buffer, &req);
|
||||
VkMemoryAllocateInfo alloc_info = {};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
|
||||
alloc_info.allocationSize = req.size;
|
||||
alloc_info.memoryTypeIndex = findMemoryType(req.memoryTypeBits, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
|
||||
err = vkAllocateMemory(g_Device, &alloc_info, g_Allocator, &upload_buffer_memory);
|
||||
check_vk_result(err);
|
||||
err = vkBindBufferMemory(g_Device, upload_buffer, upload_buffer_memory, 0);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Upload to Buffer:
|
||||
{
|
||||
void* map = NULL;
|
||||
err = vkMapMemory(g_Device, upload_buffer_memory, 0, image_size, 0, &map);
|
||||
check_vk_result(err);
|
||||
memcpy(map, image_data, image_size);
|
||||
VkMappedMemoryRange range[1] = {};
|
||||
range[0].sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
|
||||
range[0].memory = upload_buffer_memory;
|
||||
range[0].size = image_size;
|
||||
err = vkFlushMappedMemoryRanges(g_Device, 1, range);
|
||||
check_vk_result(err);
|
||||
vkUnmapMemory(g_Device, upload_buffer_memory);
|
||||
}
|
||||
|
||||
// Release image memory using stb
|
||||
stbi_image_free(image_data);
|
||||
|
||||
// Create a command buffer that will perform following steps when hit in the command queue.
|
||||
// TODO: this works in the example, but may need input if this is an acceptable way to access the pool/create the command buffer.
|
||||
VkCommandPool command_pool = g_MainWindowData.Frames[g_MainWindowData.FrameIndex].CommandPool;
|
||||
VkCommandBuffer command_buffer;
|
||||
{
|
||||
VkCommandBufferAllocateInfo alloc_info{};
|
||||
alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
|
||||
alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
|
||||
alloc_info.commandPool = command_pool;
|
||||
alloc_info.commandBufferCount = 1;
|
||||
|
||||
err = vkAllocateCommandBuffers(g_Device, &alloc_info, &command_buffer);
|
||||
check_vk_result(err);
|
||||
|
||||
VkCommandBufferBeginInfo begin_info = {};
|
||||
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(command_buffer, &begin_info);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Copy to Image
|
||||
{
|
||||
VkImageMemoryBarrier copy_barrier[1] = {};
|
||||
copy_barrier[0].sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
|
||||
copy_barrier[0].dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
|
||||
copy_barrier[0].oldLayout = VK_IMAGE_LAYOUT_UNDEFINED;
|
||||
copy_barrier[0].newLayout = VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
|
||||
copy_barrier[0].srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
copy_barrier[0].dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
copy_barrier[0].image = texture_image;
|
||||
copy_barrier[0].subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
copy_barrier[0].subresourceRange.levelCount = 1;
|
||||
copy_barrier[0].subresourceRange.layerCount = 1;
|
||||
vkCmdPipelineBarrier(command_buffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 0, NULL, 0, NULL, 1, copy_barrier);
|
||||
|
||||
VkBufferImageCopy region = {};
|
||||
region.imageSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
region.imageSubresource.layerCount = 1;
|
||||
region.imageExtent.width = *image_width;
|
||||
region.imageExtent.height = *image_height;
|
||||
region.imageExtent.depth = 1;
|
||||
vkCmdCopyBufferToImage(command_buffer, upload_buffer, texture_image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, ®ion);
|
||||
|
||||
VkImageMemoryBarrier use_barrier[1] = {};
|
||||
use_barrier[0].sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
|
||||
use_barrier[0].srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
|
||||
use_barrier[0].dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
|
||||
use_barrier[0].oldLayout = VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
|
||||
use_barrier[0].newLayout = VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
|
||||
use_barrier[0].srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
use_barrier[0].dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
|
||||
use_barrier[0].image = texture_image;
|
||||
use_barrier[0].subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
|
||||
use_barrier[0].subresourceRange.levelCount = 1;
|
||||
use_barrier[0].subresourceRange.layerCount = 1;
|
||||
vkCmdPipelineBarrier(command_buffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0, 0, NULL, 0, NULL, 1, use_barrier);
|
||||
}
|
||||
|
||||
// End command buffer
|
||||
{
|
||||
VkSubmitInfo end_info = {};
|
||||
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
end_info.commandBufferCount = 1;
|
||||
end_info.pCommandBuffers = &command_buffer;
|
||||
err = vkEndCommandBuffer(command_buffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
|
||||
check_vk_result(err);
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan layers used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeLayers(
|
||||
@@ -723,7 +562,16 @@ namespace iree {
|
||||
|
||||
extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
fprintf(stdout, "starting yo\n");
|
||||
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
|
||||
if (argc > 1) {
|
||||
// Avoid iree-run-module spinning endlessly on stdin if the user uses single
|
||||
// dashes for flags.
|
||||
printf(
|
||||
"[ERROR] unexpected positional argument (expected none)."
|
||||
" Did you use pass a flag with a single dash ('-')?"
|
||||
" Use '--' instead.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create a window.
|
||||
@@ -835,8 +683,6 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
// Demo state.
|
||||
bool show_iree_window = true;
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Setup IREE.
|
||||
|
||||
@@ -900,69 +746,44 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
|
||||
// Load bytecode module
|
||||
iree_file_toc_t module_file_toc;
|
||||
const char network_model[] = "resnet50_tf.vmfb";
|
||||
fprintf(stdout, "Loading: %s\n", network_model);
|
||||
if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
{
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
static float input_res50[224*224*3];
|
||||
static float output_res50[1000];
|
||||
|
||||
char filename[] = "dog_imagenet.jpg";
|
||||
fprintf(stdout, "loading: %s\n", filename);
|
||||
int x,y,n;
|
||||
//unsigned char *image_raw = stbi_load(filename, &x, &y, &n, 3);
|
||||
stbi_load(filename, &x, &y, &n, 3);
|
||||
fprintf(stdout, "res: %i x %i x %i\n", x, y, n);
|
||||
|
||||
/* Preprocessing needs to go here. For now use a buffer preprocessed in python.
|
||||
|
||||
//convert image into floating point format
|
||||
for(int i=0;i<224*224*3;i++)
|
||||
{
|
||||
input_res50[i]= ((float)image_raw[i])/255.0f;
|
||||
}*/
|
||||
|
||||
std::ifstream fin("dog.bin", std::ifstream::in | std::ifstream::binary);
|
||||
fin.read((char*)input_res50, 224*224*3*sizeof(float));
|
||||
|
||||
// load image again so imgui can display it
|
||||
int my_image_width = 0;
|
||||
int my_image_height = 0;
|
||||
VkDescriptorSet my_image_texture = 0;
|
||||
bool ret = LoadTextureFromFile(filename, &my_image_texture, &my_image_width, &my_image_height);
|
||||
fprintf(stdout, "creating vulkan image: %s\n", ret ?"OK":"FAIL");
|
||||
IM_ASSERT(ret);
|
||||
//iree_file_toc_t module_file_toc;
|
||||
//const char network_model[] = "resnet50_tf.vmfb";
|
||||
//fprintf(stdout, "Loading: %s\n", network_model);
|
||||
//if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
//{
|
||||
// abort();
|
||||
// return 1;
|
||||
//}
|
||||
//fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
iree_vm_module_t* bytecode_module = nullptr;
|
||||
IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
iree_instance,
|
||||
iree_const_byte_span_t{
|
||||
reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
module_file_toc.size},
|
||||
iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
// Query for details about what is in the loaded module.
|
||||
iree_vm_module_signature_t bytecode_module_signature =
|
||||
iree_vm_module_signature(bytecode_module);
|
||||
fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
bytecode_module_signature.export_function_count);
|
||||
for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
iree_vm_function_t function;
|
||||
IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
auto function_name = iree_vm_function_name(&function);
|
||||
auto function_signature = iree_vm_function_signature(&function);
|
||||
iree_status_t module_status = iree_tooling_load_module_from_flags(
|
||||
iree_instance, iree_allocator_system(), &bytecode_module);
|
||||
if (!iree_status_is_ok(module_status))
|
||||
return -1;
|
||||
//IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
// iree_instance,
|
||||
// iree_const_byte_span_t{
|
||||
// reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
// module_file_toc.size},
|
||||
// iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
//// Query for details about what is in the loaded module.
|
||||
//iree_vm_module_signature_t bytecode_module_signature =
|
||||
// iree_vm_module_signature(bytecode_module);
|
||||
//fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
// bytecode_module_signature.export_function_count);
|
||||
//for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
// iree_vm_function_t function;
|
||||
// IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
// bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
// auto function_name = iree_vm_function_name(&function);
|
||||
// auto function_signature = iree_vm_function_signature(&function);
|
||||
|
||||
fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
(int)function_name.size, function_name.data,
|
||||
(int)function_signature.calling_convention.size,
|
||||
function_signature.calling_convention.data);
|
||||
}
|
||||
// fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
// (int)function_name.size, function_name.data,
|
||||
// (int)function_signature.calling_convention.size,
|
||||
// function_signature.calling_convention.data);
|
||||
//}
|
||||
|
||||
// Allocate a context that will hold the module state across invocations.
|
||||
iree_vm_context_t* iree_context = nullptr;
|
||||
@@ -988,33 +809,42 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
// Write inputs into mappable buffers.
|
||||
iree_hal_allocator_t* allocator =
|
||||
iree_hal_device_allocator(iree_vk_device);
|
||||
iree_hal_memory_type_t input_memory_type =
|
||||
static_cast<iree_hal_memory_type_t>(
|
||||
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
iree_hal_buffer_usage_t input_buffer_usage =
|
||||
static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
iree_hal_buffer_params_t buffer_params;
|
||||
buffer_params.type = input_memory_type;
|
||||
buffer_params.usage = input_buffer_usage;
|
||||
buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
//iree_hal_memory_type_t input_memory_type =
|
||||
// static_cast<iree_hal_memory_type_t>(
|
||||
// IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
//iree_hal_buffer_usage_t input_buffer_usage =
|
||||
// static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
//iree_hal_buffer_params_t buffer_params;
|
||||
//buffer_params.type = input_memory_type;
|
||||
//buffer_params.usage = input_buffer_usage;
|
||||
//buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
|
||||
// Wrap input buffers in buffer views.
|
||||
|
||||
iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
allocator,
|
||||
/*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
&input0_buffer_view));
|
||||
|
||||
vm::ref<iree_vm_list_t> inputs;
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
iree_status_t input_status = ParseToVariantList(
|
||||
allocator,
|
||||
iree::span<const std::string>{FLAG_function_inputs.data(),
|
||||
FLAG_function_inputs.size()},
|
||||
iree_allocator_system(), &inputs);
|
||||
if (!iree_status_is_ok(input_status))
|
||||
return -1;
|
||||
//vm::ref<iree_vm_list_t> inputs;
|
||||
//IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
|
||||
//iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
//constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
//IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
// allocator,
|
||||
// /*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
// IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
// iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
// &input0_buffer_view));
|
||||
|
||||
//auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
//IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
|
||||
// Prepare outputs list to accept results from the invocation.
|
||||
|
||||
@@ -1023,6 +853,7 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Main loop.
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
@@ -1076,46 +907,11 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
/*policy=*/nullptr, inputs.get(),
|
||||
outputs.get(), iree_allocator_system()));
|
||||
|
||||
// Read back the results.
|
||||
auto* output_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>(
|
||||
iree_vm_list_get_ref_deref(outputs.get(),
|
||||
0,
|
||||
iree_hal_buffer_view_get_descriptor()));
|
||||
IREE_CHECK_OK(iree_hal_device_transfer_d2h(
|
||||
iree_vk_device,
|
||||
iree_hal_buffer_view_buffer(output_buffer_view),
|
||||
0,
|
||||
output_res50, sizeof(output_res50),
|
||||
IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
|
||||
|
||||
// we want to run continuously so we can use tools like RenderDoc, RGP, etc...
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// find maxarg from results
|
||||
float max = 0.0f;
|
||||
int max_idx = -1;
|
||||
for(int i=0;i<1000;i++)
|
||||
{
|
||||
if (output_res50[i] > max)
|
||||
{
|
||||
max = output_res50[i];
|
||||
max_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
ImGui::Text("pointer = %p", my_image_texture);
|
||||
ImGui::Text("size = %d x %d", my_image_width, my_image_height);
|
||||
ImGui::Image((ImTextureID)my_image_texture, ImVec2(my_image_width, my_image_height));
|
||||
|
||||
// Display the latest computation output.
|
||||
ImGui::Text("Max idx = [%i]", max_idx);
|
||||
ImGui::Text("Max value = [%f]", max);
|
||||
|
||||
ImGui::Text("Resnet50 categories:");
|
||||
ImGui::PlotHistogram("Histogram", output_res50, IM_ARRAYSIZE(output_res50), 0, NULL, 0.0f, 1.0f, ImVec2(0,80));
|
||||
ImGui::Separator();
|
||||
|
||||
// Framerate counter.
|
||||
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
|
||||
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
|
||||
@@ -1137,6 +933,7 @@ extern "C" int iree_main(int argc, char** argv) {
|
||||
iree_vm_module_release(bytecode_module);
|
||||
iree_vm_context_release(iree_context);
|
||||
iree_hal_device_release(iree_vk_device);
|
||||
iree_hal_allocator_release(allocator);
|
||||
iree_hal_driver_release(iree_vk_driver);
|
||||
iree_hal_vulkan_syms_release(iree_vk_syms);
|
||||
iree_vm_instance_release(iree_instance);
|
||||
|
||||
1160
cpp/vulkan_gui/vulkan_resnet_inference_gui.cc
Normal file
1160
cpp/vulkan_gui/vulkan_resnet_inference_gui.cc
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user