From 87be73570d5494f4d479a6a902aec8b3e7220fa9 Mon Sep 17 00:00:00 2001 From: zach Date: Tue, 20 Sep 2022 13:53:15 -0700 Subject: [PATCH] Add ExtismContext to SDK + better errors for failed register/update (#19) - Adds `ExtismContext` instead of global `PLUGINS` registry - Adds `extism_context_new`, `extism_context_free` and `extism_context_reset` - Requires updating nearly every SDK function to add context parameter - Renames some SDK functions to follow better naming conventions - `extism_plugin_register` -> `extism_plugin_new` - `extism_output_get` -> `extism_plugin_output_data` - `extism_output_length` -> `extism_plugin_output_length` - `extism_call` -> `extism_plugin_call` - Updates `extism_error` to return the context error when -1 issued for the plug-in ID - Adds `extism_plugin_free` to remove an existing plugin - Updates SDKs to include these functions - Updates SDK examples and comments Co-authored-by: Steve Manuel --- c/main.c | 15 +- cpp/example.cpp | 17 +- cpp/extism.h | 1 - cpp/extism.hpp | 94 ++++++++--- dune | 1 + extism.go | 120 +++++++++++--- go/main.go | 5 +- haskell/Example.hs | 24 +-- haskell/src/Extism.hs | 154 ++++++++++++------ manifest/src/lib.rs | 2 +- node/example.js | 29 ++-- node/index.js | 129 +++++++++++---- ocaml/bin/main.ml | 6 +- ocaml/lib/extism.ml | 145 ++++++++++++----- ocaml/lib/extism.mli | 21 ++- php/example/index.php | 3 +- php/src/Context.php | 53 +++++++ php/src/Plugin.php | 51 +++--- php/src/generate.php | 19 +-- python/example.py | 20 +-- python/extism/__init__.py | 2 +- python/extism/extism.py | 90 +++++++++-- ruby/example.rb | 17 +- ruby/lib/extism.rb | 141 ++++++++++++++--- runtime/build.rs | 1 + runtime/extism.h | 98 ++++++++++-- runtime/src/context.rs | 74 +++++++++ runtime/src/export.rs | 320 +++++++++++++++++++++++--------------- runtime/src/lib.rs | 16 +- runtime/src/manifest.rs | 17 +- runtime/src/memory.rs | 26 +++- runtime/src/plugin.rs | 28 ++-- runtime/src/plugin_ref.rs | 60 +++---- runtime/src/sdk.rs | 204 ++++++++++++++++++------ rust/build.rs | 22 +-- rust/src/bindings.rs | 38 ++++- rust/src/context.rs | 34 ++++ rust/src/lib.rs | 160 +++++-------------- rust/src/plugin.rs | 144 +++++++++++++++++ scripts/header.py | 42 +++-- 40 files changed, 1764 insertions(+), 679 deletions(-) delete mode 120000 cpp/extism.h create mode 100644 dune create mode 100644 php/src/Context.php create mode 100644 runtime/src/context.rs create mode 100644 rust/src/context.rs create mode 100644 rust/src/plugin.rs diff --git a/c/main.c b/c/main.c index a6fb88a..633b560 100644 --- a/c/main.c +++ b/c/main.c @@ -36,20 +36,25 @@ int main(int argc, char *argv[]) { fputs("Not enough arguments\n", stderr); exit(1); } + + ExtismContext *ctx = extism_context_new(); + size_t len = 0; uint8_t *data = read_file("../wasm/code.wasm", &len); - ExtismPlugin plugin = extism_plugin_register(data, len, false); + ExtismPlugin plugin = extism_plugin_new(ctx, data, len, false); free(data); if (plugin < 0) { exit(1); } - assert(extism_call(plugin, "count_vowels", (uint8_t *)argv[1], - strlen(argv[1])) == 0); - ExtismSize out_len = extism_output_length(plugin); - const uint8_t *output = extism_output_get(plugin); + assert(extism_plugin_call(ctx, plugin, "count_vowels", (uint8_t *)argv[1], + strlen(argv[1])) == 0); + ExtismSize out_len = extism_plugin_output_length(ctx, plugin); + const uint8_t *output = extism_plugin_output_data(ctx, plugin); write(STDOUT_FILENO, output, out_len); write(STDOUT_FILENO, "\n", 1); + extism_plugin_free(ctx, plugin); + extism_context_free(ctx); return 0; } diff --git a/cpp/example.cpp b/cpp/example.cpp index 0c2d7c2..6432b1b 100644 --- a/cpp/example.cpp +++ b/cpp/example.cpp @@ -14,17 +14,14 @@ std::vector read(const char *filename) { int main(int argc, char *argv[]) { auto wasm = read("../wasm/code.wasm"); - Plugin plugin(wasm); + Context context = Context(); - if (argc < 2) { - std::cout << "Not enough arguments" << std::endl; - return 1; - } + Plugin plugin = context.plugin(wasm); - auto input = std::vector((uint8_t *)argv[1], - (uint8_t *)argv[1] + strlen(argv[1])); - auto output = plugin.call("count_vowels", input); - std::string str(output.begin(), output.end()); - std::cout << str << std::endl; + const char *input = argc > 1 ? argv[1] : "this is a test"; + ExtismSize length = strlen(input); + + extism::Buffer output = plugin.call("count_vowels", (uint8_t *)input, length); + std::cout << (char *)output.data << std::endl; return 0; } diff --git a/cpp/extism.h b/cpp/extism.h deleted file mode 120000 index c567ae3..0000000 --- a/cpp/extism.h +++ /dev/null @@ -1 +0,0 @@ -../core/extism.h \ No newline at end of file diff --git a/cpp/extism.hpp b/cpp/extism.hpp index 2cc5d53..f648afb 100644 --- a/cpp/extism.hpp +++ b/cpp/extism.hpp @@ -1,10 +1,11 @@ #pragma once +#include #include #include extern "C" { -#include "extism.h" +#include } namespace extism { @@ -17,29 +18,54 @@ public: const char *what() { return message.c_str(); } }; +class Buffer { +public: + Buffer(const uint8_t *ptr, ExtismSize len) : data(ptr), length(len) {} + const uint8_t *data; + ExtismSize length; + + operator std::string() { return std::string((const char *)data, length); } + operator std::vector() { + return std::vector(data, data + length); + } +}; + class Plugin { + std::shared_ptr context; ExtismPlugin plugin; public: - Plugin(const uint8_t *wasm, size_t length, bool with_wasi = false) { - this->plugin = extism_plugin_register(wasm, length, with_wasi); + Plugin(std::shared_ptr ctx, const uint8_t *wasm, + ExtismSize length, bool with_wasi = false) { + this->plugin = extism_plugin_new(ctx.get(), wasm, length, with_wasi); if (this->plugin < 0) { - throw Error("Unable to load plugin"); + const char *err = extism_error(ctx.get(), -1); + throw Error(err == nullptr ? "Unable to load plugin" : err); + } + this->context = ctx; + } + + ~Plugin() { + extism_plugin_free(this->context.get(), this->plugin); + this->plugin = -1; + } + + void update(const uint8_t *wasm, size_t length, bool with_wasi = false) { + bool b = extism_plugin_update(this->context.get(), this->plugin, wasm, + length, with_wasi); + if (!b) { + const char *err = extism_error(this->context.get(), -1); + throw Error(err == nullptr ? "Unable to update plugin" : err); } } - Plugin(const std::string &s, bool with_wasi = false) - : Plugin((const uint8_t *)s.c_str(), s.size(), with_wasi) {} - Plugin(const std::vector &s, bool with_wasi = false) - : Plugin(s.data(), s.size(), with_wasi) {} + Buffer call(const std::string &func, const uint8_t *input, + ExtismSize input_length) { - std::vector call(const std::string &func, - std::vector input) { - - int32_t rc = - extism_call(this->plugin, func.c_str(), input.data(), input.size()); + int32_t rc = extism_plugin_call(this->context.get(), this->plugin, + func.c_str(), input, input_length); if (rc != 0) { - const char *error = extism_error(this->plugin); + const char *error = extism_error(this->context.get(), this->plugin); if (error == nullptr) { throw Error("extism_call failed"); } @@ -47,10 +73,42 @@ public: throw Error(error); } - ExtismSize length = extism_output_length(this->plugin); - const uint8_t *ptr = extism_output_get(this->plugin); - std::vector out = std::vector(ptr, ptr + length); - return out; + ExtismSize length = + extism_plugin_output_length(this->context.get(), this->plugin); + const uint8_t *ptr = + extism_plugin_output_data(this->context.get(), this->plugin); + return Buffer(ptr, length); + } + + Buffer call(const std::string &func, const std::vector &input) { + return this->call(func, input.data(), input.size()); + } + + Buffer call(const std::string &func, const std::string &input) { + return this->call(func, (const uint8_t *)input.c_str(), input.size()); } }; + +class Context { +public: + std::shared_ptr pointer; + Context() { + this->pointer = std::shared_ptr(extism_context_new(), + extism_context_free); + } + + Plugin plugin(const uint8_t *wasm, size_t length, bool with_wasi = false) { + return Plugin(this->pointer, wasm, length, with_wasi); + } + + Plugin plugin(const std::string &str, bool with_wasi = false) { + return Plugin(this->pointer, (const uint8_t *)str.c_str(), str.size(), + with_wasi); + } + + Plugin plugin(const std::vector &data, bool with_wasi = false) { + return Plugin(this->pointer, data.data(), data.size(), with_wasi); + } +}; + } // namespace extism diff --git a/dune b/dune new file mode 100644 index 0000000..e61b0a6 --- /dev/null +++ b/dune @@ -0,0 +1 @@ +(dirs ocaml) diff --git a/extism.go b/extism.go index 29be415..4792003 100644 --- a/extism.go +++ b/extism.go @@ -15,8 +15,29 @@ import ( */ import "C" +// Context is used to manage Plugins +type Context struct { + pointer *C.ExtismContext +} + +// NewContext creates a new context, it should be freed using the `Free` method +func NewContext() Context { + p := C.extism_context_new() + return Context{ + pointer: p, + } +} + +// Free a context +func (ctx *Context) Free() { + C.extism_context_free(ctx.pointer) + ctx.pointer = nil +} + +// Plugin is used to call WASM functions type Plugin struct { - id int32 + ctx *Context + id int32 } type WasmData struct { @@ -57,6 +78,7 @@ func makePointer(data []byte) unsafe.Pointer { return ptr } +// SetLogFile sets the log file and level, this is a global setting func SetLogFile(filename string, level string) bool { name := C.CString(filename) l := C.CString(level) @@ -66,81 +88,120 @@ func SetLogFile(filename string, level string) bool { return bool(r) } -func register(data []byte, wasi bool) (Plugin, error) { +func register(ctx *Context, data []byte, wasi bool) (Plugin, error) { ptr := makePointer(data) - plugin := C.extism_plugin_register( + plugin := C.extism_plugin_new( + ctx.pointer, (*C.uchar)(ptr), C.uint64_t(len(data)), C._Bool(wasi), ) if plugin < 0 { - return Plugin{id: -1}, errors.New("Unable to load plugin") + err := C.extism_error(ctx.pointer, C.int32_t(-1)) + msg := "Unknown" + if err != nil { + msg = C.GoString(err) + } + + return Plugin{id: -1}, errors.New( + fmt.Sprintf("Unable to load plugin: %s", msg), + ) } - return Plugin{id: int32(plugin)}, nil + return Plugin{id: int32(plugin), ctx: ctx}, nil } -func update(plugin int32, data []byte, wasi bool) bool { +func update(ctx *Context, plugin int32, data []byte, wasi bool) error { ptr := makePointer(data) - return bool(C.extism_plugin_update( + b := bool(C.extism_plugin_update( + ctx.pointer, C.int32_t(plugin), (*C.uchar)(ptr), C.uint64_t(len(data)), C._Bool(wasi), )) + + if b { + return nil + } + + err := C.extism_error(ctx.pointer, C.int32_t(-1)) + msg := "Unknown" + if err != nil { + msg = C.GoString(err) + } + + return errors.New( + fmt.Sprintf("Unable to load plugin: %s", msg), + ) } -func LoadManifest(manifest Manifest, wasi bool) (Plugin, error) { +// PluginFromManifest creates a plugin from a `Manifest` +func (ctx *Context) PluginFromManifest(manifest Manifest, wasi bool) (Plugin, error) { data, err := json.Marshal(manifest) if err != nil { return Plugin{id: -1}, err } - return register(data, wasi) + return register(ctx, data, wasi) } -func LoadPlugin(module io.Reader, wasi bool) (Plugin, error) { +// Plugin creates a plugin from a WASM module +func (ctx *Context) Plugin(module io.Reader, wasi bool) (Plugin, error) { wasm, err := io.ReadAll(module) if err != nil { return Plugin{id: -1}, err } - return register(wasm, wasi) + return register(ctx, wasm, wasi) } -func (p *Plugin) Update(module io.Reader, wasi bool) (bool, error) { +// Update a plugin with a new WASM module +func (p *Plugin) Update(module io.Reader, wasi bool) error { wasm, err := io.ReadAll(module) if err != nil { - return false, err + return err } - return update(p.id, wasm, wasi), nil + return update(p.ctx, p.id, wasm, wasi) } -func (p *Plugin) UpdateManifest(manifest Manifest, wasi bool) (bool, error) { +// Update a plugin with a new Manifest +func (p *Plugin) UpdateManifest(manifest Manifest, wasi bool) error { data, err := json.Marshal(manifest) if err != nil { - return false, err + return err } - return update(p.id, data, wasi), nil + return update(p.ctx, p.id, data, wasi) } +// Set configuration values func (plugin Plugin) SetConfig(data map[string][]byte) error { s, err := json.Marshal(data) if err != nil { return err } ptr := makePointer(s) - C.extism_plugin_config(C.int(plugin.id), (*C.uchar)(ptr), C.uint64_t(len(s))) + C.extism_plugin_config(plugin.ctx.pointer, C.int(plugin.id), (*C.uchar)(ptr), C.uint64_t(len(s))) return nil } +/// FunctionExists returns true when the name function is present in the plugin +func (plugin Plugin) FunctionExists(functionName string) bool { + name := C.CString(functionName) + b := C.extism_plugin_function_exists(plugin.ctx.pointer, C.int(plugin.id), name) + C.free(unsafe.Pointer(name)) + return bool(b) +} + +/// Call a function by name with the given input, returning the output func (plugin Plugin) Call(functionName string, input []byte) ([]byte, error) { ptr := makePointer(input) name := C.CString(functionName) - rc := C.extism_call( + rc := C.extism_plugin_call( + plugin.ctx.pointer, C.int32_t(plugin.id), name, (*C.uchar)(ptr), @@ -149,7 +210,7 @@ func (plugin Plugin) Call(functionName string, input []byte) ([]byte, error) { C.free(unsafe.Pointer(name)) if rc != 0 { - err := C.extism_error(C.int32_t(plugin.id)) + err := C.extism_error(plugin.ctx.pointer, C.int32_t(plugin.id)) msg := "" if err != nil { msg = C.GoString(err) @@ -160,14 +221,27 @@ func (plugin Plugin) Call(functionName string, input []byte) ([]byte, error) { ) } - length := C.extism_output_length(C.int32_t(plugin.id)) + length := C.extism_plugin_output_length(plugin.ctx.pointer, C.int32_t(plugin.id)) if length > 0 { - x := C.extism_output_get(C.int32_t(plugin.id)) + x := C.extism_plugin_output_data(plugin.ctx.pointer, C.int32_t(plugin.id)) y := (*[]byte)(unsafe.Pointer(&x)) return []byte((*y)[0:length]), nil - } return []byte{}, nil } + +// Free a plugin +func (plugin *Plugin) Free() { + if plugin.ctx.pointer == nil { + return + } + C.extism_plugin_free(plugin.ctx.pointer, C.int32_t(plugin.id)) + plugin.id = -1 +} + +// Reset removes all registered plugins in a Context +func (ctx Context) Reset() { + C.extism_context_reset(ctx.pointer) +} diff --git a/go/main.go b/go/main.go index 1f4b6c4..d50887c 100644 --- a/go/main.go +++ b/go/main.go @@ -9,6 +9,9 @@ import ( ) func main() { + ctx := extism.NewContext() + defer ctx.Free() // this will free the context and all associated plugins + // set some input data to provide to the plugin module var data []byte if len(os.Args) > 1 { @@ -18,7 +21,7 @@ func main() { } manifest := extism.Manifest{Wasm: []extism.Wasm{extism.WasmFile{Path: "../wasm/code.wasm"}}} - plugin, err := extism.LoadManifest(manifest, false) + plugin, err := ctx.PluginFromManifest(manifest, false) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/haskell/Example.hs b/haskell/Example.hs index 9532671..c258a08 100644 --- a/haskell/Example.hs +++ b/haskell/Example.hs @@ -5,13 +5,19 @@ import qualified Data.ByteString as B import Extism import Extism.Manifest -main = do - plugin <- Extism.registerManifest (manifest [wasmFile "../wasm/code.wasm"]) False +try f (Right x) = f x +try f (Left (Error msg)) = do + _ <- putStrLn msg + exitFailure + +handlePlugin plugin = do res <- Extism.call plugin "count_vowels" (Extism.toByteString "this is a test") - case res of - Right (Error msg) -> do - _ <- putStrLn msg - exitFailure - Left bs -> do - _ <- putStrLn (Extism.fromByteString bs) - exitSuccess + try (\bs -> do + _ <- putStrLn (Extism.fromByteString bs) + _ <- Extism.free plugin + exitSuccess) res + +main = do + context <- Extism.newContext () + plugin <- Extism.pluginFromManifest context (manifest [wasmFile "../wasm/code.wasm"]) False + try handlePlugin plugin \ No newline at end of file diff --git a/haskell/src/Extism.hs b/haskell/src/Extism.hs index 33859ab..65b522b 100644 --- a/haskell/src/Extism.hs +++ b/haskell/src/Extism.hs @@ -5,6 +5,7 @@ import GHC.Int import GHC.Word import Foreign.C.Types import Foreign.Ptr +import Foreign.ForeignPtr import Foreign.C.String import Control.Monad (void) import Data.ByteString as B @@ -13,58 +14,105 @@ import Data.ByteString.Unsafe (unsafeUseAsCString) import Text.JSON (JSON, toJSObject, encode) import Extism.Manifest (Manifest, toString) -foreign import ccall unsafe "extism.h extism_plugin_register" extism_plugin_register :: Ptr Word8 -> Word64 -> CBool -> IO Int32 -foreign import ccall unsafe "extism.h extism_plugin_update" extism_plugin_update :: Int32 -> Ptr Word8 -> Word64 -> CBool -> IO CBool -foreign import ccall unsafe "extism.h extism_call" extism_call :: Int32 -> CString -> Ptr Word8 -> Word64 -> IO Int32 -foreign import ccall unsafe "extism.h extism_function_exists" extism_function_exists :: Int32 -> CString -> IO CBool -foreign import ccall unsafe "extism.h extism_error" extism_error :: Int32 -> IO CString -foreign import ccall unsafe "extism.h extism_output_length" extism_output_length :: Int32 -> IO Word64 -foreign import ccall unsafe "extism.h extism_output_get" extism_output_get :: Int32 -> IO (Ptr Word8) -foreign import ccall unsafe "extism.h extism_log_file" extism_log_file :: CString -> CString -> IO CBool -foreign import ccall unsafe "extism.h extism_plugin_config" extism_plugin_config :: Int32 -> Ptr Word8 -> Int64 -> IO CBool +newtype ExtismContext = ExtismContext () deriving Show -newtype Plugin = Plugin Int32 deriving Show +foreign import ccall unsafe "extism.h extism_context_new" extism_context_new :: IO (Ptr ExtismContext) +foreign import ccall unsafe "extism.h &extism_context_free" extism_context_free :: FunPtr (Ptr ExtismContext -> IO ()) +foreign import ccall unsafe "extism.h extism_plugin_new" extism_plugin_new :: Ptr ExtismContext -> Ptr Word8 -> Word64 -> CBool -> IO Int32 +foreign import ccall unsafe "extism.h extism_plugin_update" extism_plugin_update :: Ptr ExtismContext -> Int32 -> Ptr Word8 -> Word64 -> CBool -> IO CBool +foreign import ccall unsafe "extism.h extism_plugin_call" extism_plugin_call :: Ptr ExtismContext -> Int32 -> CString -> Ptr Word8 -> Word64 -> IO Int32 +foreign import ccall unsafe "extism.h extism_plugin_function_exists" extism_plugin_function_exists :: Ptr ExtismContext -> Int32 -> CString -> IO CBool +foreign import ccall unsafe "extism.h extism_error" extism_error :: Ptr ExtismContext -> Int32 -> IO CString +foreign import ccall unsafe "extism.h extism_plugin_output_length" extism_plugin_output_length :: Ptr ExtismContext -> Int32 -> IO Word64 +foreign import ccall unsafe "extism.h extism_plugin_output_data" extism_plugin_output_data :: Ptr ExtismContext -> Int32 -> IO (Ptr Word8) +foreign import ccall unsafe "extism.h extism_log_file" extism_log_file :: CString -> CString -> IO CBool +foreign import ccall unsafe "extism.h extism_plugin_config" extism_plugin_config :: Ptr ExtismContext -> Int32 -> Ptr Word8 -> Int64 -> IO CBool +foreign import ccall unsafe "extism.h extism_plugin_free" extism_plugin_free :: Ptr ExtismContext -> Int32 -> IO () +foreign import ccall unsafe "extism.h extism_context_reset" extism_context_reset :: Ptr ExtismContext -> IO () + +-- Context manages plugins +newtype Context = Context (ForeignPtr ExtismContext) + +-- Plugins can be used to call WASM function +data Plugin = Plugin Context Int32 + +-- Extism error newtype Error = Error String deriving Show +-- Helper function to convert a string to a bytestring toByteString :: String -> ByteString toByteString x = B.pack (Prelude.map c2w x) +-- Helper function to convert a bytestring to a string fromByteString :: ByteString -> String fromByteString bs = Prelude.map w2c $ B.unpack bs -register :: B.ByteString -> Bool -> IO Plugin -register wasm useWasi = +-- Remove all registered plugins in a Context +reset :: Context -> IO () +reset (Context ctx) = + withForeignPtr ctx (\ctx -> + extism_context_reset ctx) + +-- Create a new context +newContext :: () -> IO Context +newContext () = do + ptr <- extism_context_new + fptr <- newForeignPtr extism_context_free ptr + return (Context fptr) + +-- Create a plugin from a WASM module, `useWasi` determines if WASI should +-- be linked +plugin :: Context -> B.ByteString -> Bool -> IO (Either Error Plugin) +plugin c wasm useWasi = let length = fromIntegral (B.length wasm) in let wasi = fromInteger (if useWasi then 1 else 0) in + let Context ctx = c in do - p <- unsafeUseAsCString wasm (\s -> - extism_plugin_register (castPtr s) length wasi) - return $ Plugin p + withForeignPtr ctx (\ctx -> do + p <- unsafeUseAsCString wasm (\s -> + extism_plugin_new ctx (castPtr s) length wasi) + if p < 0 then do + err <- extism_error ctx (-1) + e <- peekCString err + return $ Left (Error e) + else + return $ Right (Plugin c p)) -registerManifest :: Manifest -> Bool -> IO Plugin -registerManifest manifest useWasi = +-- Create a plugin from a Manifest +pluginFromManifest :: Context -> Manifest -> Bool -> IO (Either Error Plugin) +pluginFromManifest ctx manifest useWasi = let wasm = toByteString $ toString manifest in - register wasm useWasi + plugin ctx wasm useWasi -update :: Plugin -> B.ByteString -> Bool -> IO Bool -update (Plugin id) wasm useWasi = +-- Update a plugin with a new WASM module +update :: Plugin -> B.ByteString -> Bool -> IO (Either Error ()) +update (Plugin (Context ctx) id) wasm useWasi = let length = fromIntegral (B.length wasm) in let wasi = fromInteger (if useWasi then 1 else 0) in do - b <- unsafeUseAsCString wasm (\s -> - extism_plugin_update id (castPtr s) length wasi) - return (b > 0) + withForeignPtr ctx (\ctx -> do + b <- unsafeUseAsCString wasm (\s -> + extism_plugin_update ctx id (castPtr s) length wasi) + if b <= 0 then do + err <- extism_error ctx (-1) + e <- peekCString err + return $ Left (Error e) + else + return (Right ())) -updateManifest :: Plugin -> Manifest -> Bool -> IO Bool +-- Update a plugin with a new Manifest +updateManifest :: Plugin -> Manifest -> Bool -> IO (Either Error ()) updateManifest plugin manifest useWasi = let wasm = toByteString $ toString manifest in update plugin wasm useWasi +-- Check if a plugin is value isValid :: Plugin -> Bool -isValid (Plugin p) = p >= 0 +isValid (Plugin _ p) = p >= 0 -setConfig :: Plugin -> [(String, String)] -> IO () -setConfig (Plugin plugin) x = +-- Set configuration values for a plugin +setConfig :: Plugin -> [(String, Maybe String)] -> IO () +setConfig (Plugin (Context ctx) plugin) x = if plugin < 0 then return () else @@ -72,34 +120,46 @@ setConfig (Plugin plugin) x = let bs = toByteString (encode obj) in let length = fromIntegral (B.length bs) in unsafeUseAsCString bs (\s -> do - void $ extism_plugin_config plugin (castPtr s) length) + withForeignPtr ctx (\ctx -> + void $ extism_plugin_config ctx plugin (castPtr s) length)) +-- Set the log file and level, this is a global configuration setLogFile :: String -> String -> IO () setLogFile filename level = withCString filename (\f -> withCString level (\l -> do void $ extism_log_file f l)) +-- Check if a function exists in the given plugin functionExists :: Plugin -> String -> IO Bool -functionExists (Plugin plugin) name = do - b <- withCString name (extism_function_exists plugin) - if b == 1 then return True else return False +functionExists (Plugin (Context ctx) plugin) name = do + withForeignPtr ctx (\ctx -> do + b <- withCString name (extism_plugin_function_exists ctx plugin) + if b == 1 then return True else return False) -call :: Plugin -> String -> B.ByteString -> IO (Either B.ByteString Error) -call (Plugin plugin) name input = +--- Call a function provided by the given plugin +call :: Plugin -> String -> B.ByteString -> IO (Either Error B.ByteString) +call (Plugin (Context ctx) plugin) name input = let length = fromIntegral (B.length input) in do - rc <- withCString name (\name -> - unsafeUseAsCString input (\input -> - extism_call plugin name (castPtr input) length)) - err <- extism_error plugin - if err /= nullPtr - then do e <- peekCString err - return $ Right (Error e) - else if rc == 0 - then do - length <- extism_output_length plugin - ptr <- extism_output_get plugin - buf <- packCStringLen (castPtr ptr, fromIntegral length) - return $ Left buf - else return $ Right (Error "Call failed") + withForeignPtr ctx (\ctx -> do + rc <- withCString name (\name -> + unsafeUseAsCString input (\input -> + extism_plugin_call ctx plugin name (castPtr input) length)) + err <- extism_error ctx plugin + if err /= nullPtr + then do e <- peekCString err + return $ Left (Error e) + else if rc == 0 + then do + length <- extism_plugin_output_length ctx plugin + ptr <- extism_plugin_output_data ctx plugin + buf <- packCStringLen (castPtr ptr, fromIntegral length) + return $ Right buf + else return $ Left (Error "Call failed")) + +-- Free a plugin +free :: Plugin -> IO () +free (Plugin (Context ctx) plugin) = + withForeignPtr ctx (\ctx -> + extism_plugin_free ctx plugin) diff --git a/manifest/src/lib.rs b/manifest/src/lib.rs index d309ec1..b6b4432 100644 --- a/manifest/src/lib.rs +++ b/manifest/src/lib.rs @@ -89,6 +89,6 @@ mod base64 { pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { let base64 = String::deserialize(d)?; - base64::decode(base64.as_bytes()).map_err(|e| serde::de::Error::custom(e)) + base64::decode(base64.as_bytes()).map_err(serde::de::Error::custom) } } diff --git a/node/example.js b/node/example.js index 0d951dd..cf9add1 100644 --- a/node/example.js +++ b/node/example.js @@ -1,13 +1,22 @@ -import { Plugin } from './index.js'; +import { withContext, Context } from './index.js'; import { readFileSync } from 'fs'; +withContext(async function (context) { + let wasm = readFileSync('../wasm/code.wasm'); + let p = context.plugin(wasm); + + if (!p.functionExists('count_vowels')) { + console.log("no function 'count_vowels' in wasm"); + process.exit(1); + } + + let buf = await p.call('count_vowels', process.argv[2] || 'this is a test'); + console.log(JSON.parse(buf.toString())['count']); + p.free(); +}); + +// or, use a context like this: +let ctx = new Context(); let wasm = readFileSync('../wasm/code.wasm'); -let p = new Plugin(wasm); - -if (!p.function_exists('count_vowels')) { - console.log("no function 'count_vowels' in wasm"); - process.exit(1); -} - -let buf = await p.call('count_vowels', process.argv[2] || 'this is a test'); -console.log(JSON.parse(buf.toString())['count']); +let p = ctx.plugin(wasm); +// ... where the context can be passed around to various functions etc. \ No newline at end of file diff --git a/node/index.js b/node/index.js index 4db2dcc..387bcf0 100644 --- a/node/index.js +++ b/node/index.js @@ -3,17 +3,21 @@ import path from 'path'; import url from 'url'; const __dirname = url.fileURLToPath(new URL('.', import.meta.url)); - +let context = 'void*'; let _functions = { - extism_plugin_register: ['int32', ['string', 'uint64', 'bool']], - extism_plugin_update: ['bool', ['int32', 'string', 'uint64', 'bool']], - extism_error: ['char*', ['int32']], - extism_call: ['int32', ['int32', 'string', 'string', 'uint64']], - extism_output_length: ['uint64', ['int32']], - extism_output_get: ['uint8*', ['int32']], + extism_context_new: [context, []], + extism_context_free: ['void', [context]], + extism_plugin_new: ['int32', [context, 'string', 'uint64', 'bool']], + extism_plugin_update: ['bool', [context, 'int32', 'string', 'uint64', 'bool']], + extism_error: ['char*', [context, 'int32']], + extism_plugin_call: ['int32', [context, 'int32', 'string', 'string', 'uint64']], + extism_plugin_output_length: ['uint64', [context, 'int32']], + extism_plugin_output_data: ['uint8*', [context, 'int32']], extism_log_file: ['bool', ['string', 'char*']], - extism_function_exists: ['bool', ['int32', 'string']], - extism_plugin_config: ['void', ['int32', 'char*', 'uint64']], + extism_plugin_function_exists: ['bool', [context, 'int32', 'string']], + extism_plugin_config: ['void', [context, 'int32', 'char*', 'uint64']], + extism_plugin_free: ['void', [context, 'int32']], + extism_context_reset: ['void', [context]], }; function locate(paths) { @@ -41,62 +45,135 @@ if (process.env.EXTISM_PATH) { var lib = locate(searchPath); -export function set_log_file(filename, level = null) { +// Set the log file and level +export function setLogFile(filename, level = null) { lib.extism_log_file(filename, level); } +const pluginRegistry = new FinalizationRegistry(({ id, pointer }) => { + lib.extism_plugin_free(pointer, id); +}); + + +const contextRegistry = new FinalizationRegistry((pointer) => { + lib.extism_context_free(pointer); +}); + +// Context manages plugins +export class Context { + constructor() { + this.pointer = lib.extism_context_new(); + + contextRegistry.register(this, this.pointer, this); + } + + // Create a new plugin, optionally enabling WASI + plugin(data, wasi = false, config = null) { + return new Plugin(this, data, wasi, config); + } + + // Free a context, this should be called when it is + // no longer needed + free() { + contextRegistry.unregister(this); + + lib.extism_context_free(this.pointer); + this.pointer = null; + } + + // Remove all registered plugins + reset() { + lib.extism_context_reset(this.pointer); + } +} + +export async function withContext(f) { + let ctx = new Context(); + + try { + let x = await f(ctx); + ctx.free(); + return x; + } catch (err) { + ctx.free(); + throw err; + } + +} + +// Plugin provides an interface for calling WASM functions export class Plugin { - constructor(data, wasi = false, config = null) { + constructor(ctx, data, wasi = false, config = null) { if (typeof data === 'object' && data.wasm) { data = JSON.stringify(data); } - let plugin = lib.extism_plugin_register(data, data.length, wasi); + let plugin = lib.extism_plugin_new(ctx.pointer, data, data.length, wasi); if (plugin < 0) { - throw 'Unable to load plugin'; + var err = lib.extism_error(ctx.pointer, -1); + if (err.length == 0) { + throw "extism_context_plugin failed"; + } + throw `Unable to load plugin: ${err.toString()}`; } this.id = plugin; + this.ctx = ctx; + pluginRegistry.register(this, { id: this.id, pointer: this.ctx.pointer }, this); if (config != null) { let s = JSON.stringify(config); - lib.extism_plugin_config(this.id, s, s.length); + lib.extism_plugin_config(this.ctx.pointer, this.id, s, s.length); } } + // Update an existing plugin with new WASM or manifest update(data, wasi = false, config = null) { if (typeof data === 'object' && data.wasm) { data = JSON.stringify(data); } - const ok = lib.extism_plugin_update(this.id, data, data.length, wasi); + const ok = lib.extism_plugin_update(this.ctx.pointer, this.id, data, data.length, wasi); if (!ok) { - return false; + var err = lib.extism_error(this.ctx.pointer, -1); + if (err.length == 0) { + throw "extism_plugin_update failed"; + } + throw `Unable to update plugin: ${err.toString()}`; } if (config != null) { let s = JSON.stringify(config); - lib.extism_plugin_config(this.id, s, s.length); + lib.extism_plugin_config(this.ctx.pointer, this.id, s, s.length); } - - return true; } - function_exists(name) { - return lib.extism_function_exists(this.id, name); + // Check if a function exists + functionExists(name) { + return lib.extism_plugin_function_exists(this.ctx.pointer, this.id, name); } + // Call a function by name with the given input async call(name, input) { return new Promise((resolve, reject) => { - var rc = lib.extism_call(this.id, name, input, input.length); + var rc = lib.extism_plugin_call(this.ctx.pointer, this.id, name, input, input.length); if (rc != 0) { - var err = lib.extism_error(this.id); + var err = lib.extism_error(this.ctx.pointer, this.id); if (err.length == 0) { - reject(`extism_call: "${name}" failed`); + reject(`extism_plugin_call: "${name}" failed`); } reject(`Plugin error: ${err.toString()}, code: ${rc}`); } - var out_len = lib.extism_output_length(this.id); - var buf = Buffer.from(lib.extism_output_get(this.id).buffer, 0, out_len); + var out_len = lib.extism_plugin_output_length(this.ctx.pointer, this.id); + var buf = Buffer.from(lib.extism_plugin_output_data(this.ctx.pointer, this.id).buffer, 0, out_len); resolve(buf); }); } + + // Free a plugin, this should be called when the plugin is no longer needed + free() { + pluginRegistry.unregister(this); + lib.extism_plugin_free(this.ctx.pointer, this.id); + this.id = -1; + } } + + diff --git a/ocaml/bin/main.ml b/ocaml/bin/main.ml index 84875a9..affe308 100644 --- a/ocaml/bin/main.ml +++ b/ocaml/bin/main.ml @@ -4,7 +4,9 @@ let () = let input = if Array.length Sys.argv > 1 then Sys.argv.(1) else "this is a test" in + let ctx = Context.create () in let manifest = Manifest.v [ Manifest.file "../wasm/code.wasm" ] in - let plugin = Extism.register_manifest manifest in + let plugin = Extism.of_manifest ctx manifest |> Result.get_ok in let res = Extism.call plugin ~name:"count_vowels" input |> Result.get_ok in - print_endline res + print_endline res; + Context.free ctx diff --git a/ocaml/lib/extism.ml b/ocaml/lib/extism.ml index 426961e..e01c823 100644 --- a/ocaml/lib/extism.ml +++ b/ocaml/lib/extism.ml @@ -39,41 +39,56 @@ module Bindings = struct open Ctypes let fn = Foreign.foreign ~from ~release_runtime_lock:true + let context = ptr void + let extism_context_new = fn "extism_context_new" (void @-> returning context) + let extism_context_free = fn "extism_context_free" (context @-> returning void) - let extism_plugin_register = - fn "extism_plugin_register" - (string @-> uint64_t @-> bool @-> returning int32_t) + let extism_plugin_new = + fn "extism_plugin_new" + (context @-> string @-> uint64_t @-> bool @-> returning int32_t) let extism_plugin_update = fn "extism_plugin_update" - (int32_t @-> string @-> uint64_t @-> bool @-> returning bool) + (context @-> int32_t @-> string @-> uint64_t @-> bool @-> returning bool) let extism_plugin_config = fn "extism_plugin_config" - (int32_t @-> string @-> uint64_t @-> returning bool) + (context @-> int32_t @-> string @-> uint64_t @-> returning bool) - let extism_call = - fn "extism_call" - (int32_t @-> string @-> ptr char @-> uint64_t @-> returning int32_t) + let extism_plugin_call = + fn "extism_plugin_call" + (context @-> int32_t @-> string @-> ptr char @-> uint64_t + @-> returning int32_t) - let extism_call_s = - fn "extism_call" - (int32_t @-> string @-> string @-> uint64_t @-> returning int32_t) + let extism_plugin_call_s = + fn "extism_plugin_call" + (context @-> int32_t @-> string @-> string @-> uint64_t + @-> returning int32_t) - let extism_error = fn "extism_error" (int32_t @-> returning string_opt) + let extism_error = + fn "extism_error" (context @-> int32_t @-> returning string_opt) - let extism_output_length = - fn "extism_output_length" (int32_t @-> returning uint64_t) + let extism_plugin_output_length = + fn "extism_plugin_output_length" (context @-> int32_t @-> returning uint64_t) - let extism_output_get = - fn "extism_output_get" (int32_t @-> returning (ptr char)) + let extism_plugin_output_data = + fn "extism_plugin_output_data" (context @-> int32_t @-> returning (ptr char)) let extism_log_file = fn "extism_log_file" (string @-> string_opt @-> returning bool) + + let extism_plugin_free = + fn "extism_plugin_free" (context @-> int32_t @-> returning void) + + let extism_context_reset = + fn "extism_context_reset" (context @-> returning void) + + let extism_plugin_function_exists = + fn "extism_plugin_function_exists" + (context @-> int32_t @-> string @-> returning bool) end type error = [ `Msg of string ] -type t = { id : int32 } module Manifest = struct type memory = { max : int option [@yojson.option] } [@@deriving yojson] @@ -139,59 +154,98 @@ module Manifest = struct let json t = yojson_of_t t |> Yojson.Safe.to_string end -exception Failed_to_load_plugin +module Context = struct + type t = { mutable pointer : unit Ctypes.ptr } -let set_log_file ?level filename = Bindings.extism_log_file filename level + let create () = + let ptr = Bindings.extism_context_new () in + let t = { pointer = ptr } in + Gc.finalise (fun { pointer } -> Bindings.extism_context_free pointer) t; + t + + let free ctx = + let () = Bindings.extism_context_free ctx.pointer in + ctx.pointer <- Ctypes.null + + let reset ctx = Bindings.extism_context_reset ctx.pointer +end + +type t = { id : int32; ctx : Context.t } + +let with_context f = + let ctx = Context.create () in + let x = + try f ctx + with exc -> + Context.free ctx; + raise exc + in + Context.free ctx; + x let set_config plugin config = match config with | Some config -> let config = Manifest.yojson_of_config config |> Yojson.Safe.to_string in let _ = - Bindings.extism_plugin_config plugin config + Bindings.extism_plugin_config plugin.ctx.pointer plugin.id config (Unsigned.UInt64.of_int (String.length config)) in () | None -> () -let register ?config ?(wasi = false) wasm = +let free t = + if not (Ctypes.is_null t.ctx.pointer) then + Bindings.extism_plugin_free t.ctx.pointer t.id + +let plugin ?config ?(wasi = false) ctx wasm = let id = - Bindings.extism_plugin_register wasm + Bindings.extism_plugin_new ctx.Context.pointer wasm (Unsigned.UInt64.of_int (String.length wasm)) wasi in - if id < 0l then raise Failed_to_load_plugin; - set_config id config; - { id } + if id < 0l then + match Bindings.extism_error ctx.pointer (-1l) with + | None -> Error (`Msg "extism_plugin_call failed") + | Some msg -> Error (`Msg msg) + else + let t = { id; ctx } in + let () = set_config t config in + let () = Gc.finalise free t in + Ok t -let register_manifest ?config ?wasi manifest = +let of_manifest ?config ?wasi ctx manifest = let data = Manifest.json manifest in - register ?config ?wasi data + plugin ctx ?config ?wasi data -let update { id } ?config ?(wasi = false) wasm = +let update plugin ?config ?(wasi = false) wasm = + let { id; ctx } = plugin in let ok = - Bindings.extism_plugin_update id wasm + Bindings.extism_plugin_update ctx.pointer id wasm (Unsigned.UInt64.of_int (String.length wasm)) wasi in - if ok then - let () = set_config id config in - true - else false + if not ok then + match Bindings.extism_error ctx.pointer (-1l) with + | None -> Error (`Msg "extism_plugin_update failed") + | Some msg -> Error (`Msg msg) + else + let () = set_config plugin config in + Ok () let update_manifest plugin ?config ?wasi manifest = let data = Manifest.json manifest in update plugin ?config ?wasi data -let call' f { id } ~name input len = - let rc = f id name input len in +let call' f { id; ctx } ~name input len = + let rc = f ctx.pointer id name input len in if rc <> 0l then - match Bindings.extism_error id with - | None -> Error (`Msg "extism_call failed") + match Bindings.extism_error ctx.pointer id with + | None -> Error (`Msg "extism_plugin_call failed") | Some msg -> Error (`Msg msg) else - let out_len = Bindings.extism_output_length id in - let ptr = Bindings.extism_output_get id in + let out_len = Bindings.extism_plugin_output_length ctx.pointer id in + let ptr = Bindings.extism_plugin_output_data ctx.pointer id in let buf = Ctypes.bigarray_of_ptr Ctypes.array1 (Unsigned.UInt64.to_int out_len) @@ -199,12 +253,17 @@ let call' f { id } ~name input len = in Ok buf -let call_bigstring t ~name input = +let call_bigstring (t : t) ~name input = let len = Unsigned.UInt64.of_int (Bigstringaf.length input) in let ptr = Ctypes.bigarray_start Ctypes.array1 input in - call' Bindings.extism_call t ~name ptr len + call' Bindings.extism_plugin_call t ~name ptr len -let call t ~name input = +let call (t : t) ~name input = let len = String.length input in - call' Bindings.extism_call_s t ~name input (Unsigned.UInt64.of_int len) + call' Bindings.extism_plugin_call_s t ~name input (Unsigned.UInt64.of_int len) |> Result.map Bigstringaf.to_string + +let function_exists { id; ctx } name = + Bindings.extism_plugin_function_exists ctx.pointer id name + +let set_log_file ?level filename = Bindings.extism_log_file filename level diff --git a/ocaml/lib/extism.mli b/ocaml/lib/extism.mli index af38b07..522ca27 100644 --- a/ocaml/lib/extism.mli +++ b/ocaml/lib/extism.mli @@ -1,8 +1,6 @@ type t type error = [`Msg of string] -exception Failed_to_load_plugin - module Manifest : sig type memory = { max : int option } [@@deriving yojson] type wasm_file = { @@ -42,10 +40,21 @@ module Manifest : sig val json: t -> string end +module Context : sig + type t + + val create: unit -> t + val free: t -> unit + val reset: t -> unit +end + +val with_context : (Context.t -> 'a) -> 'a val set_log_file: ?level:string -> string -> bool -val register: ?config:(string * string) list -> ?wasi:bool -> string -> t -val register_manifest: ?config:(string * string) list -> ?wasi:bool -> Manifest.t -> t -val update: t -> ?config:(string * string) list -> ?wasi:bool -> string -> bool -val update_manifest: t -> ?config:(string * string) list -> ?wasi:bool -> Manifest.t -> bool +val plugin: ?config:(string * string) list -> ?wasi:bool -> Context.t -> string -> (t, [`Msg of string]) result +val of_manifest: ?config:(string * string) list -> ?wasi:bool -> Context.t -> Manifest.t -> (t, [`Msg of string]) result +val update: t -> ?config:(string * string) list -> ?wasi:bool -> string -> (unit, [`Msg of string]) result +val update_manifest: t -> ?config:(string * string) list -> ?wasi:bool -> Manifest.t -> (unit, [`Msg of string]) result val call_bigstring: t -> name:string -> Bigstringaf.t -> (Bigstringaf.t, error) result val call: t -> name:string -> string -> (string, error) result +val free: t -> unit +val function_exists: t -> string -> bool diff --git a/php/example/index.php b/php/example/index.php index 95aa9c7..84e6c82 100644 --- a/php/example/index.php +++ b/php/example/index.php @@ -2,8 +2,9 @@ require_once __DIR__ . '/vendor/autoload.php'; +$ctx = new \Extism\Context(); $wasm = file_get_contents("../../wasm/code.wasm"); -$plugin = new \Extism\Plugin($wasm); +$plugin = new \Extism\Plugin($ctx, $wasm); $output = $plugin->call("count_vowels", "this is an example"); $json = json_decode(pack('C*', ...$output)); diff --git a/php/src/Context.php b/php/src/Context.php new file mode 100644 index 0000000..4eea06c --- /dev/null +++ b/php/src/Context.php @@ -0,0 +1,53 @@ +pointer = $lib->extism_context_new(); + $this->lib = $lib; + } + + public function __destruct() + { + global $lib; + + $lib->extism_context_free($this->pointer); + } + + + public function reset() + { + global $lib; + + $lib->extism_context_reset($this->pointer); + } +} + + +function set_log_file($filename, $level) +{ + global $lib; + + $lib->extism_log_file($filename, $level); +} diff --git a/php/src/Plugin.php b/php/src/Plugin.php index da12d3c..49d4389 100644 --- a/php/src/Plugin.php +++ b/php/src/Plugin.php @@ -6,28 +6,19 @@ require_once "vendor/autoload.php"; require_once "generate.php"; require_once "ExtismLib.php"; -$lib = new \ExtismLib(\ExtismLib::SOFILE); -if ($lib == null) { - throw new Exception("Extism: failed to create new runtime instance"); -} - class Plugin { private $lib; + private $context; private $wasi; private $config; private $id; - public function __construct($data, $wasi = false, $config = null) + public function __construct($ctx, $data, $wasi = false, $config = null) { - global $lib; - - if ($lib == null) { - $lib = new \ExtismLib(\ExtismLib::SOFILE); - } - $this->lib = $lib; + $this->lib = $ctx->lib; $this->wasi = $wasi; $this->config = $config; @@ -40,21 +31,34 @@ class Plugin $data = string_to_bytes($data); } - $id = $this->lib->extism_plugin_register($data, count($data), (int)$wasi); + $id = $this->lib->extism_plugin_new($ctx->pointer, $data, count($data), (int)$wasi); if ($id < 0) { - throw new Exception("Extism: unable to load plugin"); + $err = $this->lib->extism_error($ctx->pointer, -1); + throw new Exception("Extism: unable to load plugin: " . $err); } $this->id = $id; + $this->context = $ctx; if ($config != null) { $cfg = string_to_bytes(json_encode(config)); - $this->lib->extism_plugin_config($this->id, $cfg, count($cfg)); + $this->lib->extism_plugin_config($ctx->pointer, $this->id, $cfg, count($cfg)); } } + + public function __destruct() { + $this->lib->extism_plugin_free($this->context->pointer, $this->id); + $this->id = -1; + } public function getId() { return $this->id; } + + + public function functionExists($name) + { + return $this->lib->extism_plugin_function_exists($this->context->pointer, $this->id, $name); + } public function call($name, $input = null) { @@ -62,19 +66,19 @@ class Plugin $input = string_to_bytes($input); } - $rc = $this->lib->extism_call($this->id, $name, $input, count($input)); + $rc = $this->lib->extism_plugin_call($this->context->pointer, $this->id, $name, $input, count($input)); if ($rc != 0) { $msg = "code = " . $rc; - $err = $this->lib->extism_error($this->id); + $err = $this->lib->extism_error($this->context->pointer, $this->id); if ($err) { $msg = $msg . ", error = " . $err; } throw new Execption("Extism: call to '".$name."' failed with " . $msg); } - $length = $this->lib->extism_output_length($this->id); + $length = $this->lib->extism_plugin_output_length($this->context->pointer, $this->id); - $buf = $this->lib->extism_output_get($this->id); + $buf = $this->lib->extism_plugin_output_data($this->context->pointer, $this->id); $ouput = []; $data = $buf->getData(); @@ -94,17 +98,16 @@ class Plugin $data = string_to_bytes($data); } - $ok = $this->lib->extism_plugin_update($this->id, $data, count($data), (int)$wasi); + $ok = $this->lib->extism_plugin_update($this->context->pointer, $this->id, $data, count($data), (int)$wasi); if (!$ok) { - return false; + $err = $this->lib->extism_error($this->context->pointer, -1); + throw new Exception("Extism: unable to update plugin: " . $err); } if ($config != null) { $config = json_encode($config); - $this->lib->extism_plugin_config($this->id, $config, strlen($config)); + $this->lib->extism_plugin_config($this->context->pointer, $this->id, $config, strlen($config)); } - - return true; } } diff --git a/php/src/generate.php b/php/src/generate.php index a454415..075b501 100644 --- a/php/src/generate.php +++ b/php/src/generate.php @@ -1,19 +1,12 @@ include("extism.h") - ->showWarnings(false) - ->codeGen('ExtismLib', __DIR__.'/ExtismLib.php'); - } catch (Exception $e) { - continue; - } - } +function generate() { + return (new FFIMe\FFIMe("libextism.".soext())) + ->include("extism.h") + ->showWarnings(false) + ->codeGen('ExtismLib', __DIR__.'/ExtismLib.php'); } function soext() { @@ -31,6 +24,6 @@ function soext() { } if (!file_exists(__DIR__."/ExtismLib.php")) { - generate($search_path); + generate(); } diff --git a/python/example.py b/python/example.py index 29e1947..7cbee99 100644 --- a/python/example.py +++ b/python/example.py @@ -4,22 +4,24 @@ import json import hashlib sys.path.append(".") -from extism import Plugin +from extism import Plugin, Context if len(sys.argv) > 1: data = sys.argv[1].encode() else: data = b"some data from python!" -wasm = open("../wasm/code.wasm", 'rb').read() -hash = hashlib.sha256(wasm).hexdigest() -config = {"wasm": [{"data": wasm, "hash": hash}], "memory": {"max": 5}} +# a Context provides a scope for plugins to be managed within. creating multiple contexts +# is expected and groups plugins based on source/tenant/lifetime etc. +with Context() as context: + wasm = open("../wasm/code.wasm", 'rb').read() + hash = hashlib.sha256(wasm).hexdigest() + config = {"wasm": [{"data": wasm, "hash": hash}], "memory": {"max": 5}} -plugin = Plugin(config) - -# Call `count_vowels` -j = json.loads(plugin.call("count_vowels", data)) -print("Number of vowels:", j["count"]) + plugin = context.plugin(config) + # Call `count_vowels` + j = json.loads(plugin.call("count_vowels", data)) + print("Number of vowels:", j["count"]) # Compare against Python implementation diff --git a/python/extism/__init__.py b/python/extism/__init__.py index a122391..c355fb9 100644 --- a/python/extism/__init__.py +++ b/python/extism/__init__.py @@ -1 +1 @@ -from .extism import Error, Plugin, set_log_file +from .extism import Error, Plugin, set_log_file, Context diff --git a/python/extism/extism.py b/python/extism/extism.py index 6131957..1d28e08 100644 --- a/python/extism/extism.py +++ b/python/extism/extism.py @@ -9,6 +9,7 @@ from typing import Union class Error(Exception): + '''Extism error type''' pass @@ -87,6 +88,7 @@ class Base64Encoder(json.JSONEncoder): def set_log_file(file, level=ffi.NULL): + '''Sets the log file and level, this is a global configuration''' if isinstance(level, str): level = level.encode() lib.extism_log_file(file.encode(), level) @@ -105,47 +107,115 @@ def _wasm(plugin): return wasm +class Context: + '''Context is used to store and manage plugins''' + + def __init__(self): + self.pointer = lib.extism_context_new() + + def __del__(self): + lib.extism_context_free(self.pointer) + self.pointer = ffi.NULL + + def __enter__(self): + return self + + def __exit__(self, type, exc, traceback): + self.__del__() + + def reset(self): + '''Remove all registered plugins''' + lib.extism_context_reset(self.pointer) + + def plugin(self, plugin: Union[str, bytes, dict], wasi=False, config=None): + '''Register a new plugin from a WASM module or JSON encoded manifest''' + return Plugin(self, plugin, wasi, config) + + class Plugin: + '''Plugin is used to call WASM functions''' def __init__(self, + context: Context, plugin: Union[str, bytes, dict], wasi=False, config=None): wasm = _wasm(plugin) # Register plugin - self.plugin = lib.extism_plugin_register(wasm, len(wasm), wasi) + self.plugin = lib.extism_plugin_new(context.pointer, wasm, len(wasm), + wasi) + + self.ctx = context + + if self.plugin < 0: + error = lib.extism_error(-1) + if error != ffi.NULL: + raise Error(ffi.string(error).decode()) + raise Error("Unable to register plugin") if config is not None: s = json.dumps(config).encode() lib.extism_plugin_config(self.plugin, s, len(s)) def update(self, plugin: Union[str, bytes, dict], wasi=False, config=None): + '''Update a plugin with a new WASM module or manifest''' wasm = _wasm(plugin) - ok = lib.extism_plugin_update(self.plugin, wasm, len(wasm), wasi) + ok = lib.extism_plugin_update(self.ctx.pointer, self.plugin, wasm, + len(wasm), wasi) if not ok: - return False + error = lib.extism_error(self.ctx.pointer, -1) + if error != ffi.NULL: + raise Error(ffi.string(error).decode()) + raise Error("Unable to update plugin") if config is not None: s = json.dumps(config).encode() - lib.extism_plugin_config(self.plugin, s, len(s)) - return True + lib.extism_plugin_config(self.ctx.pointer, self.plugin, s, len(s)) def _check_error(self, rc): if rc != 0: - error = lib.extism_error(self.plugin) + error = lib.extism_error(self.ctx.pointer, self.plugin) if error != ffi.NULL: raise Error(ffi.string(error).decode()) raise Error(f"Error code: {rc}") - def call(self, name: str, data: Union[str, bytes], parse=bytes) -> bytes: + def function_exists(self, name: str) -> bool: + '''Returns true if the given function exists''' + return lib.extism_plugin_function_exists(self.ctx.pointer, self.plugin, + name.encode()) + + def call(self, name: str, data: Union[str, bytes], parse=bytes): + ''' + Call a function by name with the provided input data + + The `parse` argument can be used to transform the output buffer into + your expected type. It expects a function that takes a buffer as the + only argument + ''' if isinstance(data, str): data = data.encode() self._check_error( - lib.extism_call(self.plugin, name.encode(), data, len(data))) - out_len = lib.extism_output_length(self.plugin) - out_buf = lib.extism_output_get(self.plugin) + lib.extism_plugin_call(self.ctx.pointer, self.plugin, + name.encode(), data, len(data))) + out_len = lib.extism_plugin_output_length(self.ctx.pointer, + self.plugin) + out_buf = lib.extism_plugin_output_data(self.ctx.pointer, self.plugin) buf = ffi.buffer(out_buf, out_len) if parse is None: return buf return parse(buf) + + def __del__(self): + if not hasattr(self, 'ctx'): + return + if self.ctx.pointer == ffi.NULL: + return + lib.extism_plugin_free(self.ctx.pointer, self.plugin) + self.plugin = -1 + + def __enter__(self): + return self + + def __exit__(self, type, exc, traceback): + self.__del__() diff --git a/ruby/example.rb b/ruby/example.rb index 2b509b9..5d36077 100644 --- a/ruby/example.rb +++ b/ruby/example.rb @@ -1,9 +1,16 @@ require './lib/extism' require 'json' -manifest = { - :wasm => [{:path => "../wasm/code.wasm"}] +# a Context provides a scope for plugins to be managed within. creating multiple contexts +# is expected and groups plugins based on source/tenant/lifetime etc. +ctx = Extism::Context.new +Extism::with_context {|ctx| + manifest = { + :wasm => [{:path => "../wasm/code.wasm"}] + } + + plugin = ctx.plugin(manifest) + res = JSON.parse(plugin.call("count_vowels", ARGV[0] || "this is a test")) + + puts res['count'] } -plugin = Extism::Plugin.new(manifest) -res = JSON.parse(plugin.call("count_vowels", ARGV[0] || "this is a test")) -puts res['count'] diff --git a/ruby/lib/extism.rb b/ruby/lib/extism.rb index dd88e8b..550461d 100644 --- a/ruby/lib/extism.rb +++ b/ruby/lib/extism.rb @@ -5,19 +5,25 @@ module Extism module C extend FFI::Library ffi_lib "extism" - attach_function :extism_plugin_register, [:pointer, :uint64, :bool], :int32 - attach_function :extism_plugin_update, [:int32, :pointer, :uint64, :bool], :bool - attach_function :extism_error, [:int32], :string - attach_function :extism_call, [:int32, :string, :pointer, :uint64], :int32 - attach_function :extism_output_length, [:int32], :uint64 - attach_function :extism_output_get, [:int32], :pointer + attach_function :extism_context_new, [], :pointer + attach_function :extism_context_free, [:pointer], :void + attach_function :extism_plugin_new, [:pointer, :pointer, :uint64, :bool], :int32 + attach_function :extism_plugin_update, [:pointer, :int32, :pointer, :uint64, :bool], :bool + attach_function :extism_error, [:pointer, :int32], :string + attach_function :extism_plugin_call, [:pointer, :int32, :string, :pointer, :uint64], :int32 + attach_function :extism_plugin_function_exists, [:pointer, :int32, :string], :bool + attach_function :extism_plugin_output_length, [:pointer, :int32], :uint64 + attach_function :extism_plugin_output_data, [:pointer, :int32], :pointer attach_function :extism_log_file, [:string, :pointer], :void + attach_function :extism_plugin_free, [:pointer, :int32], :void + attach_function :extism_context_reset, [:pointer], :void end class Error < StandardError end + # Set log file and level, this is a global configuration def self.set_log_file(name, level=nil) if level then level = FFI::MemoryPointer::from_string(level) @@ -25,54 +31,147 @@ module Extism C.extism_log_file(name, level) end + $PLUGINS = {} + $FREE_PLUGIN = proc { |id| + x = $PLUGINS[id] + if !x.nil? then + C.extism_plugin_free(x[:context].pointer, x[:plugin]) + $PLUGINS.delete(id) + end + } + + $CONTEXTS = {} + $FREE_CONTEXT = proc { |id| + x = $CONTEXTS[id] + if !x.nil? then + C.extism_context_free($CONTEXTS[id]) + $CONTEXTS.delete(id) + end + } + + # Context is used to manage plugins + class Context + attr_accessor :pointer + + def initialize + @pointer = C.extism_context_new() + $CONTEXTS[self.object_id] = @pointer + ObjectSpace.define_finalizer(self, $FREE_CONTEXT) + end + + # Remove all registered plugins + def reset + C.extism_context_reset(@pointer) + end + + # Free the context, this should be called when it is no longer needed + def free + if @pointer.nil? then + return + end + $CONTEXTS.delete(self.object_id) + C.extism_context_free(@pointer) + @pointer = nil + end + + # Create a new plugin from a WASM module or JSON encoded manifest + def plugin(wasm, wasi=false, config=nil) + return Plugin.new(self, wasm, wasi, config) + end + end + + + def self.with_context(&block) + ctx = Context.new + begin + x = block.call(ctx) + return x + ensure + ctx.free + end + end + class Plugin - def initialize(wasm, wasi=false, config=nil) + def initialize(context, wasm, wasi=false, config=nil) if wasm.class == Hash then wasm = JSON.generate(wasm) end code = FFI::MemoryPointer.new(:char, wasm.bytesize) code.put_bytes(0, wasm) - @plugin = C.extism_plugin_register(code, wasm.bytesize, wasi) - + @plugin = C.extism_plugin_new(context.pointer, code, wasm.bytesize, wasi) + if @plugin < 0 then + err = C.extism_error(-1) + if err&.empty? then + raise Error.new "extism_plugin_new failed" + else raise Error.new err + end + end + @context = context + $PLUGINS[self.object_id] = {:plugin => @plugin, :context => context} + ObjectSpace.define_finalizer(self, $FREE_PLUGIN) if config != nil and @plugin >= 0 then s = JSON.generate(config) ptr = FFI::MemoryPointer::from_string(s) - C.extism_plugin_config(@plugin, ptr, s.bytesize) + C.extism_plugin_config(@context.pointer, @plugin, ptr, s.bytesize) end end + # Update a plugin with new WASM module or manifest def update(wasm, wasi=false, config=nil) if wasm.class == Hash then wasm = JSON.generate(wasm) end code = FFI::MemoryPointer.new(:char, wasm.bytesize) code.put_bytes(0, wasm) - ok = C.extism_plugin_update(@plugin, code, wasm.bytesize, wasi) - if ok then - if config != nil then - s = JSON.generate(config) - ptr = FFI::MemoryPointer::from_string(s) - C.extism_plugin_config(@plugin, ptr, s.bytesize) + ok = C.extism_plugin_update(@context.pointer, @plugin, code, wasm.bytesize, wasi) + if !ok then + err = C.extism_error(-1) + if err&.empty? then + raise Error.new "extism_plugin_update failed" + else raise Error.new err end end - return ok + + if config != nil then + s = JSON.generate(config) + ptr = FFI::MemoryPointer::from_string(s) + C.extism_plugin_config(@context.pointer, @plugin, ptr, s.bytesize) + end end + # Check if a function exists + def function_exists(name) + return C.extism_function_exists(@context.pointer, @plugin, name) + end + + # Call a function by name def call(name, data, &block) # If no block was passed then use Pointer::read_string block ||= ->(buf, len){ buf.read_string(len) } input = FFI::MemoryPointer::from_string(data) - rc = C.extism_call(@plugin, name, input, data.bytesize) + rc = C.extism_plugin_call(@context.pointer, @plugin, name, input, data.bytesize) if rc != 0 then - err = C.extism_error(@plugin) + err = C.extism_error(@context.pointer, @plugin) if err&.empty? then raise Error.new "extism_call failed" else raise Error.new err end end - out_len = C.extism_output_length(@plugin) - buf = C.extism_output_get(@plugin) + out_len = C.extism_plugin_output_length(@context.pointer, @plugin) + buf = C.extism_plugin_output_data(@context.pointer, @plugin) return block.call(buf, out_len) end + + # Free a plugin, this should be called when the plugin is no longer needed + def free + if @context.pointer.nil? then + return + end + + $PLUGINS.delete(self.object_id) + C.extism_plugin_free(@context.pointer, @plugin) + @plugin = -1 + end + end end diff --git a/runtime/build.rs b/runtime/build.rs index e798e8c..1846f4d 100644 --- a/runtime/build.rs +++ b/runtime/build.rs @@ -12,6 +12,7 @@ fn main() { .with_pragma_once(true) .rename_item("Size", "ExtismSize") .rename_item("PluginIndex", "ExtismPlugin") + .rename_item("Context", "ExtismContext") .generate() { bindings.write_to_file("extism.h"); diff --git a/runtime/extism.h b/runtime/extism.h index b31f798..82fbcf5 100644 --- a/runtime/extism.h +++ b/runtime/extism.h @@ -3,30 +3,106 @@ #include #include +/** + * A `Context` is used to store and manage plugins + */ +typedef struct ExtismContext ExtismContext; + typedef int32_t ExtismPlugin; typedef uint64_t ExtismSize; -ExtismPlugin extism_plugin_register(const uint8_t *wasm, ExtismSize wasm_size, bool with_wasi); +/** + * Create a new context + */ +struct ExtismContext *extism_context_new(void); -bool extism_plugin_update(ExtismPlugin index, +/** + * Free a context + */ +void extism_context_free(struct ExtismContext *ctx); + +/** + * Create a new plugin + * + * `wasm`: is a WASM module (wat or wasm) or a JSON encoded manifest + * `wasm_size`: the length of the `wasm` parameter + * `with_wasi`: enables/disables WASI + */ +ExtismPlugin extism_plugin_new(struct ExtismContext *ctx, + const uint8_t *wasm, + ExtismSize wasm_size, + bool with_wasi); + +/** + * Update a plugin, keeping the existing ID + * + * Similar to `extism_plugin_new` but takes an `index` argument to specify + * which plugin to update + * + * Memory for this plugin will be reset upon update + */ +bool extism_plugin_update(struct ExtismContext *ctx, + ExtismPlugin index, const uint8_t *wasm, ExtismSize wasm_size, bool with_wasi); -bool extism_plugin_config(ExtismPlugin plugin, const uint8_t *json, ExtismSize json_size); +/** + * Remove a plugin from the registry and free associated memory + */ +void extism_plugin_free(struct ExtismContext *ctx, ExtismPlugin plugin); -bool extism_function_exists(ExtismPlugin plugin, const char *func_name); +/** + * Remove all plugins from the registry + */ +void extism_context_reset(struct ExtismContext *ctx); -int32_t extism_call(ExtismPlugin plugin_id, - const char *func_name, - const uint8_t *data, - ExtismSize data_len); +/** + * Update plugin config values, this will merge with the existing values + */ +bool extism_plugin_config(struct ExtismContext *ctx, + ExtismPlugin plugin, + const uint8_t *json, + ExtismSize json_size); -const char *extism_error(ExtismPlugin plugin); +/** + * Returns true if `func_name` exists + */ +bool extism_plugin_function_exists(struct ExtismContext *ctx, + ExtismPlugin plugin, + const char *func_name); -ExtismSize extism_output_length(ExtismPlugin plugin); +/** + * Call a function + * + * `func_name`: is the function to call + * `data`: is the input data + * `data_len`: is the length of `data` + */ +int32_t extism_plugin_call(struct ExtismContext *ctx, + ExtismPlugin plugin_id, + const char *func_name, + const uint8_t *data, + ExtismSize data_len); -const uint8_t *extism_output_get(ExtismPlugin plugin); +/** + * Get the error associated with a `Context` or `Plugin`, if `plugin` is `-1` then the context + * error will be returned + */ +const char *extism_error(struct ExtismContext *ctx, ExtismPlugin plugin); +/** + * Get the length of a plugin's output data + */ +ExtismSize extism_plugin_output_length(struct ExtismContext *ctx, ExtismPlugin plugin); + +/** + * Get the length of a plugin's output data + */ +const uint8_t *extism_plugin_output_data(struct ExtismContext *ctx, ExtismPlugin plugin); + +/** + * Set log file and level + */ bool extism_log_file(const char *filename, const char *log_level); diff --git a/runtime/src/context.rs b/runtime/src/context.rs new file mode 100644 index 0000000..2f97a78 --- /dev/null +++ b/runtime/src/context.rs @@ -0,0 +1,74 @@ +use std::collections::BTreeMap; + +use crate::*; + +/// A `Context` is used to store and manage plugins +#[derive(Default)] +pub struct Context { + /// Plugin registry + pub plugins: BTreeMap, + + /// Error message + pub error: Option, + next_id: std::sync::atomic::AtomicI32, + reclaimed_ids: Vec, +} + +impl Context { + /// Create a new context + pub fn new() -> Context { + Context { + plugins: BTreeMap::new(), + error: None, + next_id: std::sync::atomic::AtomicI32::new(0), + reclaimed_ids: Vec::new(), + } + } + + /// Get the next valid plugin ID + pub fn next_id(&mut self) -> Result { + // Make sure we haven't exhausted all plugin IDs, it reach this it would require the machine + // running this code to have a lot of memory - no computer I tested on was able to allocate + // this many plugins. + if self.next_id.load(std::sync::atomic::Ordering::SeqCst) == PluginIndex::MAX { + // Since `Context::remove` collects IDs that have been removed we will + // try to use one of those before returning an error + match self.reclaimed_ids.pop() { + None => { + return Err(anyhow::format_err!( + "All plugin descriptors are in use, unable to allocate a new plugin" + )) + } + Some(x) => return Ok(x), + } + } + + Ok(self + .next_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst)) + } + + /// Set the context error + pub fn set_error(&mut self, e: impl std::fmt::Debug) { + self.error = Some(error_string(e)); + } + + /// Convenience function to set error and return the value passed as the final parameter + pub fn error(&mut self, e: impl std::fmt::Debug, x: T) -> T { + self.set_error(e); + x + } + + /// Get a plugin from the context + pub fn plugin(&mut self, id: PluginIndex) -> Option<&mut Plugin> { + self.plugins.get_mut(&id) + } + + /// Remove a plugin from the context + pub fn remove(&mut self, id: PluginIndex) { + self.plugins.remove(&id); + + // Collect old IDs in case we need to re-use them + self.reclaimed_ids.push(id); + } +} diff --git a/runtime/src/export.rs b/runtime/src/export.rs index b31490e..df9ae47 100644 --- a/runtime/src/export.rs +++ b/runtime/src/export.rs @@ -1,25 +1,9 @@ +/// All the functions in the file are exposed from inside WASM plugins use crate::*; -macro_rules! plugin { - (mut $a:expr) => { - unsafe { (&mut *$a.plugin) } - }; - - ($a:expr) => { - unsafe { (&*$a.plugin) } - }; -} - -macro_rules! memory { - (mut $a:expr) => { - &mut plugin!(mut $a).memory - }; - - ($a:expr) => { - &plugin!($a).memory - }; -} - +/// Get the input length +/// Params: none +/// Returns: i64 (length) pub(crate) fn input_length( caller: Caller, _input: &[Val], @@ -27,9 +11,12 @@ pub(crate) fn input_length( ) -> Result<(), Trap> { let data: &Internal = caller.data(); output[0] = Val::I64(data.input_length as i64); - return Ok(()); + Ok(()) } +/// Load a byte from input +/// Params: i64 (offset) +/// Returns: i32 (byte) pub(crate) fn input_load_u8( caller: Caller, input: &[Val], @@ -43,6 +30,9 @@ pub(crate) fn input_load_u8( Ok(()) } +/// Load an unsigned 64 bit integer from input +/// Params: i64 (offset) +/// Returns: i64 (int) pub(crate) fn input_load_u64( caller: Caller, input: &[Val], @@ -59,6 +49,108 @@ pub(crate) fn input_load_u64( Ok(()) } +/// Store a byte in memory +/// Params: i64 (offset), i32 (byte) +/// Returns: none +pub(crate) fn store_u8( + mut caller: Caller, + input: &[Val], + _output: &mut [Val], +) -> Result<(), Trap> { + let data: &mut Internal = caller.data_mut(); + let byte = input[1].unwrap_i32() as u8; + data.memory_mut() + .store_u8(input[0].unwrap_i64() as usize, byte) + .map_err(|_| Trap::new("Write error"))?; + Ok(()) +} + +/// Load a byte from memory +/// Params: i64 (offset) +/// Returns: i32 (byte) +pub(crate) fn load_u8( + caller: Caller, + input: &[Val], + output: &mut [Val], +) -> Result<(), Trap> { + let data: &Internal = caller.data(); + let byte = data + .memory() + .load_u8(input[0].unwrap_i64() as usize) + .map_err(|_| Trap::new("Read error"))?; + output[0] = Val::I32(byte as i32); + Ok(()) +} + +/// Store an unsigned 32 bit integer in memory +/// Params: i64 (offset), i32 (int) +/// Returns: none +pub(crate) fn store_u32( + mut caller: Caller, + input: &[Val], + _output: &mut [Val], +) -> Result<(), Trap> { + let data: &mut Internal = caller.data_mut(); + let b = input[1].unwrap_i32() as u32; + data.memory_mut() + .store_u32(input[0].unwrap_i64() as usize, b) + .map_err(|_| Trap::new("Write error"))?; + Ok(()) +} + +/// Load an unsigned 32 bit integer from memory +/// Params: i64 (offset) +/// Returns: i32 (int) +pub(crate) fn load_u32( + caller: Caller, + input: &[Val], + output: &mut [Val], +) -> Result<(), Trap> { + let data: &Internal = caller.data(); + let b = data + .memory() + .load_u32(input[0].unwrap_i64() as usize) + .map_err(|_| Trap::new("Read error"))?; + output[0] = Val::I32(b as i32); + Ok(()) +} + +/// Store an unsigned 64 bit integer in memory +/// Params: i64 (offset), i64 (int) +/// Returns: none +pub(crate) fn store_u64( + mut caller: Caller, + input: &[Val], + _output: &mut [Val], +) -> Result<(), Trap> { + let data: &mut Internal = caller.data_mut(); + let b = input[1].unwrap_i64() as u64; + data.memory_mut() + .store_u64(input[0].unwrap_i64() as usize, b) + .map_err(|_| Trap::new("Write error"))?; + Ok(()) +} + +/// Load an unsigned 64 bit integer from memory +/// Params: i64 (offset) +/// Returns: i64 (int) +pub(crate) fn load_u64( + caller: Caller, + input: &[Val], + output: &mut [Val], +) -> Result<(), Trap> { + let data: &Internal = caller.data(); + let byte = data + .memory() + .load_u64(input[0].unwrap_i64() as usize) + .map_err(|_| Trap::new("Read error"))?; + output[0] = Val::I64(byte as i64); + Ok(()) +} + +/// Set output offset and length +/// Params: i64 (offset), i64 (length) +/// Returns: none pub(crate) fn output_set( mut caller: Caller, input: &[Val], @@ -70,18 +162,24 @@ pub(crate) fn output_set( Ok(()) } +/// Allocate bytes +/// Params: i64 (length) +/// Returns: i64 (offset) pub(crate) fn alloc( mut caller: Caller, input: &[Val], output: &mut [Val], ) -> Result<(), Trap> { let data: &mut Internal = caller.data_mut(); - let offs = memory!(mut data).alloc(input[0].unwrap_i64() as _)?; + let offs = data.memory_mut().alloc(input[0].unwrap_i64() as _)?; output[0] = Val::I64(offs.offset as i64); Ok(()) } +/// Free memory +/// Params: i64 (offset) +/// Returns: none pub(crate) fn free( mut caller: Caller, input: &[Val], @@ -89,88 +187,13 @@ pub(crate) fn free( ) -> Result<(), Trap> { let data: &mut Internal = caller.data_mut(); let offset = input[0].unwrap_i64() as usize; - memory!(mut data).free(offset); - Ok(()) -} - -pub(crate) fn store_u8( - mut caller: Caller, - input: &[Val], - _output: &mut [Val], -) -> Result<(), Trap> { - let data: &mut Internal = caller.data_mut(); - let byte = input[1].unwrap_i32() as u8; - memory!(mut data) - .store_u8(input[0].unwrap_i64() as usize, byte) - .map_err(|_| Trap::new("Write error"))?; - Ok(()) -} - -pub(crate) fn load_u8( - mut caller: Caller, - input: &[Val], - output: &mut [Val], -) -> Result<(), Trap> { - let data: &mut Internal = caller.data_mut(); - let byte = memory!(data) - .load_u8(input[0].unwrap_i64() as usize) - .map_err(|_| Trap::new("Read error"))?; - output[0] = Val::I32(byte as i32); - Ok(()) -} - -pub(crate) fn store_u32( - mut caller: Caller, - input: &[Val], - _output: &mut [Val], -) -> Result<(), Trap> { - let data: &mut Internal = caller.data_mut(); - let b = input[1].unwrap_i32() as u32; - memory!(mut data) - .store_u32(input[0].unwrap_i64() as usize, b) - .map_err(|_| Trap::new("Write error"))?; - Ok(()) -} - -pub(crate) fn load_u32( - mut caller: Caller, - input: &[Val], - output: &mut [Val], -) -> Result<(), Trap> { - let data: &mut Internal = caller.data_mut(); - let b = memory!(data) - .load_u32(input[0].unwrap_i64() as usize) - .map_err(|_| Trap::new("Read error"))?; - output[0] = Val::I32(b as i32); - Ok(()) -} - -pub(crate) fn store_u64( - mut caller: Caller, - input: &[Val], - _output: &mut [Val], -) -> Result<(), Trap> { - let data: &mut Internal = caller.data_mut(); - let b = input[1].unwrap_i64() as u64; - memory!(mut data) - .store_u64(input[0].unwrap_i64() as usize, b) - .map_err(|_| Trap::new("Write error"))?; - Ok(()) -} - -pub(crate) fn load_u64( - mut caller: Caller, - input: &[Val], - output: &mut [Val], -) -> Result<(), Trap> { - let data: &mut Internal = caller.data_mut(); - let byte = memory!(data) - .load_u64(input[0].unwrap_i64() as usize) - .map_err(|_| Trap::new("Read error"))?; - output[0] = Val::I64(byte as i64); + data.memory_mut().free(offset); Ok(()) } +/// Set the error message, this can be checked by the host program +/// Params: i64 (offset) +/// Returns: none pub(crate) fn error_set( mut caller: Caller, input: &[Val], @@ -178,23 +201,27 @@ pub(crate) fn error_set( ) -> Result<(), Trap> { let data: &mut Internal = caller.data_mut(); let offset = input[0].unwrap_i64() as usize; - let length = match memory!(data).block_length(offset) { + let length = match data.memory().block_length(offset) { Some(x) => x, None => return Err(Trap::new("Invalid offset in call to error_set")), }; let handle = MemoryBlock { offset, length }; if handle.offset == 0 { - plugin!(mut data).clear_error(); + data.plugin_mut().clear_error(); return Ok(()); } - let buf = memory!(data).get(handle); + let buf = data.memory().ptr(handle); + let buf = unsafe { std::slice::from_raw_parts(buf, length) }; let s = unsafe { std::str::from_utf8_unchecked(buf) }; - plugin!(mut data).set_error(s); + data.plugin_mut().set_error(s); Ok(()) } +/// Get a configuration value +/// Params: i64 (offset) +/// Returns: i64 (offset) pub(crate) fn config_get( mut caller: Caller, input: &[Val], @@ -202,16 +229,24 @@ pub(crate) fn config_get( ) -> Result<(), Trap> { let data: &mut Internal = caller.data_mut(); let offset = input[0].unwrap_i64() as usize; - let length = match memory!(data).block_length(offset) { + let length = match data.memory().block_length(offset) { Some(x) => x, None => return Err(Trap::new("Invalid offset in call to config_get")), }; - let buf = memory!(data).get((offset, length)); + let buf = data.memory().get((offset, length)); let str = unsafe { std::str::from_utf8_unchecked(buf) }; - let val = plugin!(data).manifest.as_ref().config.get(str); + let val = data + .plugin() + .manifest + .as_ref() + .config + .get(str) + .map(|x| x.as_ptr()); let mem = match val { - Some(f) => memory!(mut data).alloc_bytes(f.as_bytes())?, + Some(f) => data + .memory_mut() + .alloc_bytes(unsafe { std::slice::from_raw_parts(f, length) })?, None => { output[0] = Val::I64(0); return Ok(()); @@ -222,6 +257,9 @@ pub(crate) fn config_get( Ok(()) } +/// Get a variable +/// Params: i64 (offset) +/// Returns: i64 (offset) pub(crate) fn var_get( mut caller: Caller, input: &[Val], @@ -229,16 +267,21 @@ pub(crate) fn var_get( ) -> Result<(), Trap> { let data: &mut Internal = caller.data_mut(); let offset = input[0].unwrap_i64() as usize; - let length = match memory!(data).block_length(offset) { + let length = match data.memory().block_length(offset) { Some(x) => x, None => return Err(Trap::new("Invalid offset in call to var_get")), }; - let buf = memory!(data).get((offset, length)); - let str = unsafe { std::str::from_utf8_unchecked(buf) }; - let val = data.vars.get(str); + let val = { + let buf = data.memory().ptr((offset, length)); + let buf = unsafe { std::slice::from_raw_parts(buf, length) }; + let str = unsafe { std::str::from_utf8_unchecked(buf) }; + data.vars.get(str).map(|x| x.as_ptr()) + }; let mem = match val { - Some(f) => memory!(mut data).alloc_bytes(&f)?, + Some(f) => data + .memory_mut() + .alloc_bytes(unsafe { std::slice::from_raw_parts(f, length) })?, None => { output[0] = Val::I64(0); return Ok(()); @@ -249,6 +292,9 @@ pub(crate) fn var_get( Ok(()) } +/// Set a variable, if the value offset is 0 then the provided key will be removed +/// Params: i64 (key offset), i64 (value offset) +/// Returns: none pub(crate) fn var_set( mut caller: Caller, input: &[Val], @@ -269,30 +315,35 @@ pub(crate) fn var_set( } let koffset = input[0].unwrap_i64() as usize; - let klength = match memory!(data).block_length(koffset) { + let klength = match data.memory().block_length(koffset) { Some(x) => x, None => return Err(Trap::new("Invalid offset for key in call to var_set")), }; - let kbuf = memory!(data).get((koffset, klength)); + let kbuf = data.memory().ptr((koffset, klength)); + let kbuf = unsafe { std::slice::from_raw_parts(kbuf, klength) }; let kstr = unsafe { std::str::from_utf8_unchecked(kbuf) }; + // Remove if the value offset is 0 if voffset == 0 { data.vars.remove(kstr); return Ok(()); } - let vlength = match memory!(data).block_length(voffset) { + let vlength = match data.memory().block_length(voffset) { Some(x) => x, None => return Err(Trap::new("Invalid offset for value in call to var_set")), }; - let vbuf = memory!(data).get((voffset, vlength)); + let vbuf = data.memory().get((voffset, vlength)); data.vars.insert(kstr.to_string(), vbuf.to_vec()); Ok(()) } +/// Make an HTTP request +/// Params: i64 (offset to JSON encoded HttpRequest), i64 (offset to body or 0) +/// Returns: i64 (offset) pub(crate) fn http_request( #[allow(unused_mut)] mut caller: Caller, input: &[Val], @@ -313,11 +364,11 @@ pub(crate) fn http_request( let data: &mut Internal = caller.data_mut(); let offset = input[0].unwrap_i64() as usize; - let length = match memory!(data).block_length(offset) { + let length = match data.memory().block_length(offset) { Some(x) => x, None => return Err(Trap::new("Invalid offset in call to http_request")), }; - let buf = memory!(data).get((offset, length)); + let buf = data.memory().get((offset, length)); let req: extism_manifest::HttpRequest = serde_json::from_slice(buf).map_err(|_| Trap::new("Invalid http request"))?; @@ -330,11 +381,11 @@ pub(crate) fn http_request( } let mut res = if body_offset > 0 { - let length = match memory!(data).block_length(body_offset) { + let length = match data.memory().block_length(body_offset) { Some(x) => x, None => return Err(Trap::new("Invalid offset in call to http_request")), }; - let buf = memory!(data).get((offset, length)); + let buf = data.memory().get((offset, length)); r.send_bytes(buf) .map_err(|e| Trap::new(&format!("Request error: {e:?}")))? .into_reader() @@ -348,13 +399,16 @@ pub(crate) fn http_request( res.read_to_end(&mut buf) .map_err(|e| Trap::new(format!("{:?}", e)))?; - let mem = memory!(mut data).alloc_bytes(buf)?; + let mem = data.memory_mut().alloc_bytes(buf)?; output[0] = Val::I64(mem.offset as i64); Ok(()) } } +/// Get the length of an allocated block given the offset +/// Params: i64 (offset) +/// Returns: i64 (length or 0) pub(crate) fn length( mut caller: Caller, input: &[Val], @@ -366,7 +420,7 @@ pub(crate) fn length( output[0] = Val::I64(0); return Ok(()); } - let length = match memory!(data).block_length(offset) { + let length = match data.memory().block_length(offset) { Some(x) => x, None => return Err(Trap::new("Unable to find length for offset")), }; @@ -374,7 +428,7 @@ pub(crate) fn length( Ok(()) } -pub(crate) fn log( +pub fn log( level: log::Level, caller: Caller, input: &[Val], @@ -383,11 +437,11 @@ pub(crate) fn log( let data: &Internal = caller.data(); let offset = input[0].unwrap_i64() as usize; - let length = match memory!(data).block_length(offset) { + let length = match data.memory().block_length(offset) { Some(x) => x, None => return Err(Trap::new("Invalid offset in call to http_request")), }; - let buf = memory!(data).get((offset, length)); + let buf = data.memory().get((offset, length)); match std::str::from_utf8(buf) { Ok(buf) => log::log!(level, "{}", buf), @@ -396,6 +450,9 @@ pub(crate) fn log( Ok(()) } +/// Write to logs (warning) +/// Params: i64 (offset) +/// Returns: none pub(crate) fn log_warn( caller: Caller, input: &[Val], @@ -404,6 +461,9 @@ pub(crate) fn log_warn( log(log::Level::Warn, caller, input, _output) } +/// Write to logs (info) +/// Params: i64 (offset) +/// Returns: none pub(crate) fn log_info( caller: Caller, input: &[Val], @@ -412,6 +472,9 @@ pub(crate) fn log_info( log(log::Level::Info, caller, input, _output) } +/// Write to logs (debug) +/// Params: i64 (offset) +/// Returns: none pub(crate) fn log_debug( caller: Caller, input: &[Val], @@ -420,6 +483,9 @@ pub(crate) fn log_debug( log(log::Level::Debug, caller, input, _output) } +/// Write to logs (error) +/// Params: i64 (offset) +/// Returns: none pub(crate) fn log_error( caller: Caller, input: &[Val], diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 64e07f7..f6a5a09 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -1,6 +1,7 @@ pub use anyhow::Error; pub(crate) use wasmtime::*; +mod context; pub(crate) mod export; pub mod manifest; mod memory; @@ -8,12 +9,25 @@ mod plugin; mod plugin_ref; pub mod sdk; +pub use context::Context; pub use manifest::Manifest; pub use memory::{MemoryBlock, PluginMemory}; -pub use plugin::{Internal, Plugin, PLUGINS}; +pub use plugin::{Internal, Plugin}; pub use plugin_ref::PluginRef; pub type Size = u64; pub type PluginIndex = i32; pub(crate) use log::{debug, error, info, trace}; + +/// Converts any type implementing `std::fmt::Debug` into a suitable CString to use +/// as an error message +pub(crate) fn error_string(e: impl std::fmt::Debug) -> std::ffi::CString { + let x = format!("{:?}", e).into_bytes(); + let x = if x[0] == b'"' && x[x.len() - 1] == b'"' { + x[1..x.len() - 1].to_vec() + } else { + x + }; + unsafe { std::ffi::CString::from_vec_unchecked(x) } +} diff --git a/runtime/src/manifest.rs b/runtime/src/manifest.rs index 61684e3..817a2ad 100644 --- a/runtime/src/manifest.rs +++ b/runtime/src/manifest.rs @@ -6,6 +6,7 @@ use sha2::Digest; use crate::*; +/// Manifest wraps the manifest exported by `extism_manifest` #[derive(Default, serde::Serialize, serde::Deserialize)] #[serde(transparent)] pub struct Manifest(extism_manifest::Manifest); @@ -59,11 +60,7 @@ fn check_hash(hash: &Option, data: &[u8]) -> Result<(), Error> { } } -fn hash_url(url: &str) -> String { - let digest = sha2::Sha256::digest(url.as_bytes()); - hex(&digest) -} - +/// Convert from manifest to a wasmtime Module fn to_module( engine: &Engine, wasm: &extism_manifest::ManifestWasm, @@ -74,6 +71,7 @@ fn to_module( return Err(anyhow::format_err!("File-based registration is disabled")); } + // Figure out a good name for the file let name = match name { None => { let name = path.with_extension(""); @@ -82,6 +80,7 @@ fn to_module( Some(n) => n.clone(), }; + // Load file let mut buf = Vec::new(); let mut file = std::fs::File::open(path)?; file.read_to_end(&mut buf)?; @@ -108,6 +107,7 @@ fn to_module( }, hash, } => { + // Get the file name let file_name = url.split('/').last().unwrap(); let name = match name { Some(name) => name.as_str(), @@ -124,7 +124,6 @@ fn to_module( } }; - let url_hash = hash_url(url); if let Some(h) = hash { if let Ok(Some(data)) = cache_get_file(h) { check_hash(hash, &data)?; @@ -140,7 +139,7 @@ fn to_module( #[cfg(feature = "register-http")] { - let url_hash = hash_url(url); + // Setup request let mut req = ureq::request(method.as_deref().unwrap_or("GET"), url); for (k, v) in header.iter() { @@ -152,6 +151,7 @@ fn to_module( let mut data = Vec::new(); r.read_to_end(&mut data)?; + // Try to cache file if let Some(hash) = hash { cache_add_file(hash, &data); } @@ -169,6 +169,7 @@ fn to_module( const WASM_MAGIC: [u8; 4] = [0x00, 0x61, 0x73, 0x6d]; impl Manifest { + /// Create a new Manifest, returns the manifest and a map of modules pub fn new(engine: &Engine, data: &[u8]) -> Result<(Self, BTreeMap), Error> { let has_magic = data.len() >= 4 && data[0..4] == WASM_MAGIC; let is_wast = data.starts_with(b"(module") || data.starts_with(b";;"); @@ -196,6 +197,8 @@ impl Manifest { } let mut modules = BTreeMap::new(); + + // If there's only one module, it should be called `main` if self.0.wasm.len() == 1 { let (_, m) = to_module(engine, &self.0.wasm[0])?; modules.insert("main".to_string(), m); diff --git a/runtime/src/memory.rs b/runtime/src/memory.rs index 096976b..d695cf6 100644 --- a/runtime/src/memory.rs +++ b/runtime/src/memory.rs @@ -20,6 +20,7 @@ const PAGE_SIZE: u32 = 65536; const BLOCK_SIZE_THRESHOLD: usize = 32; impl PluginMemory { + /// Create memory for a plugin pub fn new(store: Store, memory: Memory) -> Self { PluginMemory { free: Vec::new(), @@ -30,6 +31,7 @@ impl PluginMemory { } } + /// Write byte to memory pub(crate) fn store_u8(&mut self, offs: usize, data: u8) -> Result<(), MemoryAccessError> { trace!("store_u8: {data:x} at offset {offs}"); if offs >= self.size() { @@ -42,7 +44,7 @@ impl PluginMemory { Ok(()) } - /// Read from memory + /// Read byte from memory pub(crate) fn load_u8(&self, offs: usize) -> Result { trace!("load_u8: offset {offs}"); if offs >= self.size() { @@ -54,17 +56,18 @@ impl PluginMemory { Ok(self.memory.data(&self.store)[offs]) } + /// Write u32 to memory pub(crate) fn store_u32(&mut self, offs: usize, data: u32) -> Result<(), MemoryAccessError> { trace!("store_u32: {data:x} at offset {offs}"); let handle = MemoryBlock { offset: offs, length: 4, }; - self.write(handle, &data.to_ne_bytes())?; + self.write(handle, data.to_ne_bytes())?; Ok(()) } - /// Read from memory + /// Read u32 from memory pub(crate) fn load_u32(&self, offs: usize) -> Result { trace!("load_u32: offset {offs}"); let mut buf = [0; 4]; @@ -77,16 +80,18 @@ impl PluginMemory { Ok(u32::from_ne_bytes(buf)) } + /// Write u64 to memory pub(crate) fn store_u64(&mut self, offs: usize, data: u64) -> Result<(), MemoryAccessError> { trace!("store_u64: {data:x} at offset {offs}"); let handle = MemoryBlock { offset: offs, length: 8, }; - self.write(handle, &data.to_ne_bytes())?; + self.write(handle, data.to_ne_bytes())?; Ok(()) } + /// Read u64 from memory pub(crate) fn load_u64(&self, offs: usize) -> Result { trace!("load_u64: offset {offs}"); let mut buf = [0; 8]; @@ -98,7 +103,7 @@ impl PluginMemory { Ok(u64::from_ne_bytes(buf)) } - /// Write to memory + /// Write slice to memory pub fn write( &mut self, pos: impl Into, @@ -110,7 +115,7 @@ impl PluginMemory { .write(&mut self.store, pos.offset, data.as_ref()) } - /// Read from memory + /// Read slice from memory pub fn read( &self, pos: impl Into, @@ -232,13 +237,14 @@ impl PluginMemory { } } + /// Log entire memory as hexdump using the `trace` log level pub fn dump(&self) { let data = self.memory.data(&self.store); trace!("{:?}", data[..self.position].hex_dump()); } - /// Reset memory + /// Reset memory - clears free-list and live blocks and resets position pub fn reset(&mut self) { self.free.clear(); self.live_blocks.clear(); @@ -267,6 +273,12 @@ impl PluginMemory { &mut self.memory.data_mut(&mut self.store)[handle.offset..handle.offset + handle.length] } + /// Pointer to the provided memory handle + pub fn ptr(&self, handle: impl Into) -> *mut u8 { + let handle = handle.into(); + unsafe { self.memory.data_ptr(&self.store).add(handle.offset) } + } + /// Get the length of the block starting at `offs` pub fn block_length(&self, offs: usize) -> Option { self.live_blocks.get(&offs).cloned() diff --git a/runtime/src/plugin.rs b/runtime/src/plugin.rs index 28561e4..6a6b196 100644 --- a/runtime/src/plugin.rs +++ b/runtime/src/plugin.rs @@ -38,6 +38,22 @@ impl Internal { plugin: std::ptr::null_mut(), }) } + + pub fn plugin(&self) -> &Plugin { + unsafe { &*self.plugin } + } + + pub fn plugin_mut(&mut self) -> &mut Plugin { + unsafe { &mut *self.plugin } + } + + pub fn memory(&self) -> &PluginMemory { + &self.plugin().memory + } + + pub fn memory_mut(&mut self) -> &mut PluginMemory { + &mut self.plugin_mut().memory + } } const EXPORT_MODULE_NAME: &str = "env"; @@ -145,14 +161,7 @@ impl Plugin { /// Set `last_error` field pub fn set_error(&mut self, e: impl std::fmt::Debug) { debug!("Set error: {:?}", e); - let x = format!("{:?}", e).into_bytes(); - let x = if x[0] == b'"' && x[x.len() - 1] == b'"' { - x[1..x.len() - 1].to_vec() - } else { - x - }; - let e = unsafe { std::ffi::CString::from_vec_unchecked(x) }; - self.last_error = Some(e); + self.last_error = Some(error_string(e)); } pub fn error(&mut self, e: impl std::fmt::Debug, x: E) -> E { @@ -181,6 +190,3 @@ impl Plugin { self.memory.dump(); } } - -/// A registry for plugins -pub static mut PLUGINS: std::sync::Mutex> = std::sync::Mutex::new(Vec::new()); diff --git a/runtime/src/plugin_ref.rs b/runtime/src/plugin_ref.rs index 7e0e031..181fdb5 100644 --- a/runtime/src/plugin_ref.rs +++ b/runtime/src/plugin_ref.rs @@ -1,48 +1,50 @@ use crate::*; -// PluginRef is used to access a plugin from the global plugin registry +// PluginRef is used to access a plugin from a context-scoped plugin registry pub struct PluginRef<'a> { pub id: PluginIndex, - pub plugins: std::sync::MutexGuard<'a, Vec>, - plugin: *mut Plugin, + plugin: &'a mut Plugin, } impl<'a> PluginRef<'a> { - pub fn init(mut self) -> Self { - trace!( - "Resetting memory and clearing error message for plugin {}", - self.id, - ); - // Initialize - self.as_mut().clear_error(); + /// Initialize the plugin for a new call + /// + /// - Resets memory offsets + /// - Updates `input` pointer + pub fn init(mut self, data: *const u8, data_len: usize) -> Self { + trace!("PluginRef::init: {}", self.id,); self.as_mut().memory.reset(); - let internal = self.as_mut().memory.store.data_mut(); - internal.input = std::ptr::null(); - internal.input_length = 0; + self.plugin.set_input(data, data_len); self } - /// # Safety - /// - /// This function is used to access the static `PLUGINS` registry - pub unsafe fn new(plugin_id: PluginIndex) -> Self { - let mut plugins = match PLUGINS.lock() { - Ok(p) => p, - Err(e) => e.into_inner(), - }; - + /// Create a `PluginRef` from a context + pub fn new(ctx: &'a mut Context, plugin_id: PluginIndex, clear_error: bool) -> Self { trace!("Loading plugin {plugin_id}"); - if plugin_id < 0 || plugin_id as usize >= plugins.len() { - drop(plugins); - panic!("Invalid PluginIndex {plugin_id}"); + if plugin_id < 0 { + panic!("Invalid PluginIndex in PluginRef::new: {plugin_id}"); } - let plugin = plugins.get_unchecked_mut(plugin_id as usize) as *mut _; + if clear_error { + ctx.error = None; + } + + let plugin = ctx.plugin(plugin_id); + + let plugin = match plugin { + None => { + panic!("Plugin does not exist: {plugin_id}"); + } + Some(p) => p, + }; + + if clear_error { + plugin.clear_error(); + } PluginRef { id: plugin_id, - plugins, plugin, } } @@ -50,13 +52,13 @@ impl<'a> PluginRef<'a> { impl<'a> AsRef for PluginRef<'a> { fn as_ref(&self) -> &Plugin { - unsafe { &*self.plugin } + self.plugin } } impl<'a> AsMut for PluginRef<'a> { fn as_mut(&mut self) -> &mut Plugin { - unsafe { &mut *self.plugin } + self.plugin } } diff --git a/runtime/src/sdk.rs b/runtime/src/sdk.rs index e6faa68..598995c 100644 --- a/runtime/src/sdk.rs +++ b/runtime/src/sdk.rs @@ -5,74 +5,135 @@ use std::str::FromStr; use crate::*; +/// Create a new context #[no_mangle] -pub unsafe extern "C" fn extism_plugin_register( +pub unsafe extern "C" fn extism_context_new() -> *mut Context { + trace!("Creating new Context"); + Box::into_raw(Box::new(Context::new())) +} + +/// Free a context +#[no_mangle] +pub unsafe extern "C" fn extism_context_free(ctx: *mut Context) { + trace!("Freeing context"); + if ctx.is_null() { + return; + } + drop(Box::from_raw(ctx)) +} + +/// Create a new plugin +/// +/// `wasm`: is a WASM module (wat or wasm) or a JSON encoded manifest +/// `wasm_size`: the length of the `wasm` parameter +/// `with_wasi`: enables/disables WASI +#[no_mangle] +pub unsafe extern "C" fn extism_plugin_new( + ctx: *mut Context, wasm: *const u8, wasm_size: Size, with_wasi: bool, ) -> PluginIndex { - trace!( - "Call to extism_plugin_register with wasm pointer {:?}", - wasm - ); + trace!("Call to extism_plugin_new with wasm pointer {:?}", wasm); + let ctx = &mut *ctx; + let data = std::slice::from_raw_parts(wasm, wasm_size as usize); let plugin = match Plugin::new(data, with_wasi) { Ok(x) => x, Err(e) => { error!("Error creating Plugin: {:?}", e); + ctx.set_error(e); return -1; } }; - let mut plugins = match PLUGINS.lock() { - Ok(p) => p, - Err(e) => e.into_inner(), + // Allocate a new plugin ID + let id: i32 = match ctx.next_id() { + Ok(id) => id, + Err(e) => { + error!("Error creating Plugin: {:?}", e); + ctx.set_error(e); + return -1; + } }; - - plugins.push(plugin); - let id = (plugins.len() - 1) as PluginIndex; + ctx.plugins.insert(id, plugin); info!("New plugin added: {id}"); id } +/// Update a plugin, keeping the existing ID +/// +/// Similar to `extism_plugin_new` but takes an `index` argument to specify +/// which plugin to update +/// +/// Memory for this plugin will be reset upon update #[no_mangle] pub unsafe extern "C" fn extism_plugin_update( + ctx: *mut Context, index: PluginIndex, wasm: *const u8, wasm_size: Size, with_wasi: bool, ) -> bool { - let index = index as usize; trace!("Call to extism_plugin_update with wasm pointer {:?}", wasm); + let ctx = &mut *ctx; let data = std::slice::from_raw_parts(wasm, wasm_size as usize); let plugin = match Plugin::new(data, with_wasi) { Ok(x) => x, Err(e) => { error!("Error creating Plugin: {:?}", e); + ctx.set_error(e); return false; } }; - let mut plugins = match PLUGINS.lock() { - Ok(p) => p, - Err(e) => e.into_inner(), - }; - - if index < plugins.len() { - plugins[index] = plugin; + if !ctx.plugins.contains_key(&index) { + ctx.set_error("Plugin index does not exist"); + return false; } + ctx.plugins.insert(index, plugin); + info!("Plugin updated: {index}"); true } +/// Remove a plugin from the registry and free associated memory +#[no_mangle] +pub unsafe extern "C" fn extism_plugin_free(ctx: *mut Context, plugin: PluginIndex) { + if plugin < 0 || ctx.is_null() { + return; + } + + trace!("Freeing plugin {plugin}"); + + let ctx = &mut *ctx; + ctx.remove(plugin); +} + +/// Remove all plugins from the registry +#[no_mangle] +pub unsafe extern "C" fn extism_context_reset(ctx: *mut Context) { + let ctx = &mut *ctx; + + trace!( + "Resetting context, plugins cleared: {:?}", + ctx.plugins.keys().collect::>() + ); + + ctx.plugins.clear(); +} + +/// Update plugin config values, this will merge with the existing values #[no_mangle] pub unsafe extern "C" fn extism_plugin_config( + ctx: *mut Context, plugin: PluginIndex, json: *const u8, json_size: Size, ) -> bool { - let mut plugin = PluginRef::new(plugin); + let ctx = &mut *ctx; + let mut plugin = PluginRef::new(ctx, plugin, true); trace!( "Call to extism_plugin_config for {} with json pointer {:?}", @@ -81,50 +142,74 @@ pub unsafe extern "C" fn extism_plugin_config( ); let data = std::slice::from_raw_parts(json, json_size as usize); - let json: std::collections::BTreeMap = match serde_json::from_slice(data) { - Ok(x) => x, - Err(e) => { - plugin.as_mut().set_error(e); - return false; - } - }; + let json: std::collections::BTreeMap> = + match serde_json::from_slice(data) { + Ok(x) => x, + Err(e) => { + return plugin.as_mut().error(e, false); + } + }; let plugin = plugin.as_mut(); let wasi = &mut plugin.memory.store.data_mut().wasi; let config = &mut plugin.manifest.as_mut().config; for (k, v) in json.into_iter() { - trace!("Config, adding {k}"); - let _ = wasi.push_env(&k, &v); - config.insert(k, v); + match v { + Some(v) => { + trace!("Config, adding {k}"); + let _ = wasi.push_env(&k, &v); + config.insert(k, v); + } + None => { + config.remove(&k); + } + } } true } +/// Returns true if `func_name` exists #[no_mangle] -pub unsafe extern "C" fn extism_function_exists( +pub unsafe extern "C" fn extism_plugin_function_exists( + ctx: *mut Context, plugin: PluginIndex, func_name: *const c_char, ) -> bool { - let mut plugin = PluginRef::new(plugin); + let ctx = &mut *ctx; + let mut plugin = PluginRef::new(ctx, plugin, true); let name = std::ffi::CStr::from_ptr(func_name); + trace!("Call to extism_plugin_function_exists for: {:?}", name); + let name = match name.to_str() { Ok(x) => x, - Err(_) => return false, + Err(e) => { + return plugin.as_mut().error(e, false); + } }; plugin.as_mut().get_func(name).is_some() } +/// Call a function +/// +/// `func_name`: is the function to call +/// `data`: is the input data +/// `data_len`: is the length of `data` #[no_mangle] -pub unsafe extern "C" fn extism_call( +pub unsafe extern "C" fn extism_plugin_call( + ctx: *mut Context, plugin_id: PluginIndex, func_name: *const c_char, data: *const u8, data_len: Size, ) -> i32 { - let mut plugin = PluginRef::new(plugin_id).init(); + let ctx = &mut *ctx; + + // Get a `PluginRef` and call `init` to set up the plugin input and memory, this is only + // needed before a new call + let mut plugin = PluginRef::new(ctx, plugin_id, true).init(data, data_len as usize); let plugin = plugin.as_mut(); // Find function @@ -141,9 +226,6 @@ pub unsafe extern "C" fn extism_call( None => return plugin.error(format!("Function not found: {name}"), -1), }; - // Always needs to be called before `func.call()` - plugin.set_input(data, data_len as usize); - // Call function with offset+length pointing to input data. let mut results = vec![Val::I32(0)]; match func.call(&mut plugin.memory.store, &[], results.as_mut_slice()) { @@ -161,10 +243,25 @@ pub unsafe extern "C" fn extism_call( results[0].unwrap_i32() } +/// Get the error associated with a `Context` or `Plugin`, if `plugin` is `-1` then the context +/// error will be returned #[no_mangle] -pub unsafe extern "C" fn extism_error(plugin: PluginIndex) -> *const c_char { +pub unsafe extern "C" fn extism_error(ctx: *mut Context, plugin: PluginIndex) -> *const c_char { trace!("Call to extism_error for plugin {plugin}"); - let plugin = PluginRef::new(plugin); + + let ctx = &mut *ctx; + + if plugin < 0 { + match ctx.error.as_ref() { + Some(e) => return e.as_ptr() as *const _, + None => { + trace!("Error is NULL"); + return std::ptr::null(); + } + } + } + + let plugin = PluginRef::new(ctx, plugin, false); match &plugin.as_ref().last_error { Some(e) => e.as_ptr() as *const _, None => { @@ -174,21 +271,32 @@ pub unsafe extern "C" fn extism_error(plugin: PluginIndex) -> *const c_char { } } +/// Get the length of a plugin's output data #[no_mangle] -pub unsafe extern "C" fn extism_output_length(plugin: PluginIndex) -> Size { - trace!("Call to extism_output_length for plugin {plugin}"); - let plugin = PluginRef::new(plugin); +pub unsafe extern "C" fn extism_plugin_output_length( + ctx: *mut Context, + plugin: PluginIndex, +) -> Size { + trace!("Call to extism_plugin_output_length for plugin {plugin}"); + + let ctx = &mut *ctx; + let plugin = PluginRef::new(ctx, plugin, true); let len = plugin.as_ref().memory.store.data().output_length as Size; trace!("Output length: {len}"); len } +/// Get the length of a plugin's output data #[no_mangle] -pub unsafe extern "C" fn extism_output_get(plugin: PluginIndex) -> *const u8 { - trace!("Call to extism_output_get for plugin {plugin}"); +pub unsafe extern "C" fn extism_plugin_output_data( + ctx: *mut Context, + plugin: PluginIndex, +) -> *const u8 { + trace!("Call to extism_plugin_output_data for plugin {plugin}"); - let plugin = PluginRef::new(plugin); + let ctx = &mut *ctx; + let plugin = PluginRef::new(ctx, plugin, true); let data = plugin.as_ref().memory.store.data(); plugin @@ -198,6 +306,7 @@ pub unsafe extern "C" fn extism_output_get(plugin: PluginIndex) -> *const u8 { .as_ptr() } +/// Set log file and level #[no_mangle] pub unsafe extern "C" fn extism_log_file( filename: *const c_char, @@ -254,8 +363,7 @@ pub unsafe extern "C" fn extism_log_file( } else { match FileAppender::builder().encoder(encoder).build(file) { Ok(x) => Box::new(x), - Err(e) => { - error!("Unable to set up log encoder: {e:?}"); + Err(_) => { return false; } } diff --git a/rust/build.rs b/rust/build.rs index 6436fba..bd7bfbf 100644 --- a/rust/build.rs +++ b/rust/build.rs @@ -1,20 +1,10 @@ fn main() { - let out_dir = std::env::var("OUT_DIR").unwrap(); - if std::path::PathBuf::from("libextism.so").exists() { - std::process::Command::new("cp") - .arg("libextism.so") - .arg(&out_dir) - .status() - .unwrap(); - } else { - std::process::Command::new("cp") - .arg("libextism.dylib") - .arg(&out_dir) - .status() - .unwrap(); + println!("cargo:rustc-link-search=/usr/local/lib"); + + if let Ok(home) = std::env::var("HOME") { + let path = std::path::PathBuf::from(home).join(".local").join("lib"); + println!("cargo:rustc-link-search={}", path.display()); } - println!("cargo:rustc-link-search={}", out_dir); + println!("cargo:rustc-link-lib=extism"); - println!("cargo:rerun-if-changed=libextism.so"); - println!("cargo:rerun-if-changed=libextism.dylib"); } diff --git a/rust/src/bindings.rs b/rust/src/bindings.rs index 24a03dc..4dbe496 100644 --- a/rust/src/bindings.rs +++ b/rust/src/bindings.rs @@ -3,10 +3,22 @@ pub type __uint8_t = ::std::os::raw::c_uchar; pub type __int32_t = ::std::os::raw::c_int; pub type __uint64_t = ::std::os::raw::c_ulong; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ExtismContext { + _unused: [u8; 0], +} pub type ExtismPlugin = i32; pub type ExtismSize = u64; extern "C" { - pub fn extism_plugin_register( + pub fn extism_context_new() -> *mut ExtismContext; +} +extern "C" { + pub fn extism_context_free(ctx: *mut ExtismContext); +} +extern "C" { + pub fn extism_plugin_new( + ctx: *mut ExtismContext, wasm: *const u8, wasm_size: ExtismSize, with_wasi: bool, @@ -14,27 +26,37 @@ extern "C" { } extern "C" { pub fn extism_plugin_update( + ctx: *mut ExtismContext, index: ExtismPlugin, wasm: *const u8, wasm_size: ExtismSize, with_wasi: bool, ) -> bool; } +extern "C" { + pub fn extism_plugin_free(ctx: *mut ExtismContext, plugin: ExtismPlugin); +} +extern "C" { + pub fn extism_context_reset(ctx: *mut ExtismContext); +} extern "C" { pub fn extism_plugin_config( + ctx: *mut ExtismContext, plugin: ExtismPlugin, json: *const u8, json_size: ExtismSize, ) -> bool; } extern "C" { - pub fn extism_function_exists( + pub fn extism_plugin_function_exists( + ctx: *mut ExtismContext, plugin: ExtismPlugin, func_name: *const ::std::os::raw::c_char, ) -> bool; } extern "C" { - pub fn extism_call( + pub fn extism_plugin_call( + ctx: *mut ExtismContext, plugin_id: ExtismPlugin, func_name: *const ::std::os::raw::c_char, data: *const u8, @@ -42,13 +64,17 @@ extern "C" { ) -> i32; } extern "C" { - pub fn extism_error(plugin: ExtismPlugin) -> *const ::std::os::raw::c_char; + pub fn extism_error( + ctx: *mut ExtismContext, + plugin: ExtismPlugin, + ) -> *const ::std::os::raw::c_char; } extern "C" { - pub fn extism_output_length(plugin: ExtismPlugin) -> ExtismSize; + pub fn extism_plugin_output_length(ctx: *mut ExtismContext, plugin: ExtismPlugin) + -> ExtismSize; } extern "C" { - pub fn extism_output_get(plugin: ExtismPlugin) -> *const u8; + pub fn extism_plugin_output_data(ctx: *mut ExtismContext, plugin: ExtismPlugin) -> *const u8; } extern "C" { pub fn extism_log_file( diff --git a/rust/src/context.rs b/rust/src/context.rs new file mode 100644 index 0000000..c25eec5 --- /dev/null +++ b/rust/src/context.rs @@ -0,0 +1,34 @@ +use crate::*; + +pub struct Context { + pub(crate) pointer: *mut bindings::ExtismContext, +} + +impl Default for Context { + fn default() -> Context { + Context::new() + } +} + +impl Context { + /// Create a new context + pub fn new() -> Context { + let pointer = unsafe { bindings::extism_context_new() }; + Context { pointer } + } + + /// Remove all registered plugins + pub fn reset(&mut self) { + unsafe { bindings::extism_context_reset(self.pointer) } + } +} + +impl Drop for Context { + fn drop(&mut self) { + if self.pointer.is_null() { + return; + } + unsafe { bindings::extism_context_free(self.pointer) } + self.pointer = std::ptr::null_mut(); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index df00eeb..aebe13d 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -3,14 +3,17 @@ use std::collections::BTreeMap; use extism_manifest::Manifest; #[allow(non_camel_case_types)] -mod bindings; +pub(crate) mod bindings; -#[repr(transparent)] -pub struct Plugin(isize); +mod context; +mod plugin; + +pub use context::Context; +pub use plugin::Plugin; #[derive(Debug)] pub enum Error { - UnableToLoadPlugin, + UnableToLoadPlugin(String), Message(String), Json(serde_json::Error), } @@ -21,115 +24,14 @@ impl From for Error { } } -impl Plugin { - pub fn new_with_manifest(manifest: &Manifest, wasi: bool) -> Result { - let data = serde_json::to_vec(manifest)?; - Self::new(data, wasi) - } - - pub fn new(data: impl AsRef<[u8]>, wasi: bool) -> Result { - let plugin = unsafe { - bindings::extism_plugin_register( - data.as_ref().as_ptr(), - data.as_ref().len() as u64, - wasi, - ) - }; - - if plugin < 0 { - return Err(Error::UnableToLoadPlugin); - } - - Ok(Plugin(plugin as isize)) - } - - pub fn update(&mut self, data: impl AsRef<[u8]>, wasi: bool) -> bool { - unsafe { - bindings::extism_plugin_update( - self.0 as i32, - data.as_ref().as_ptr(), - data.as_ref().len() as u64, - wasi, - ) - } - } - - pub fn update_manifest(&mut self, manifest: &Manifest, wasi: bool) -> Result { - let data = serde_json::to_vec(manifest)?; - Ok(self.update(data, wasi)) - } - - pub fn set_config(&self, config: &BTreeMap) -> Result<(), Error> { - let encoded = serde_json::to_vec(config)?; - unsafe { - bindings::extism_plugin_config( - self.0 as i32, - encoded.as_ptr() as *const _, - encoded.len() as u64, - ) - }; - Ok(()) - } - - pub fn with_config(self, config: &BTreeMap) -> Result { - self.set_config(config)?; - Ok(self) - } - - pub fn set_log_file( - &self, - filename: impl AsRef, - log_level: Option, - ) { - let log_level = log_level.map(|x| x.as_str()); - unsafe { - bindings::extism_log_file( - filename.as_ref().as_os_str().to_string_lossy().as_ptr() as *const _, - log_level.map(|x| x.as_ptr()).unwrap_or(std::ptr::null()) as *const _, - ); - } - } - - pub fn with_log_file( - self, - filename: impl AsRef, - log_level: Option, - ) -> Self { - self.set_log_file(filename, log_level); - self - } - - pub fn has_function(&self, name: impl AsRef) -> bool { - let name = std::ffi::CString::new(name.as_ref()).expect("Invalid function name"); - unsafe { bindings::extism_function_exists(self.0 as i32, name.as_ptr() as *const _) } - } - - pub fn call(&self, name: impl AsRef, input: impl AsRef<[u8]>) -> Result<&[u8], Error> { - let name = std::ffi::CString::new(name.as_ref()).expect("Invalid function name"); - let rc = unsafe { - bindings::extism_call( - self.0 as i32, - name.as_ptr() as *const _, - input.as_ref().as_ptr() as *const _, - input.as_ref().len() as u64, - ) - }; - - if rc != 0 { - let err = unsafe { bindings::extism_error(self.0 as i32) }; - if !err.is_null() { - let s = unsafe { std::ffi::CStr::from_ptr(err) }; - return Err(Error::Message(s.to_str().unwrap().to_string())); - } - - return Err(Error::Message("extism_call failed".to_string())); - } - - let out_len = unsafe { bindings::extism_output_length(self.0 as i32) }; - unsafe { - let ptr = bindings::extism_output_get(self.0 as i32); - Ok(std::slice::from_raw_parts(ptr, out_len as usize)) - } +/// Set the log file and level, this is a global setting +pub fn set_log_file(filename: impl AsRef, log_level: Option) { + let log_level = log_level.map(|x| x.as_str()); + unsafe { + bindings::extism_log_file( + filename.as_ref().as_os_str().to_string_lossy().as_ptr() as *const _, + log_level.map(|x| x.as_ptr()).unwrap_or(std::ptr::null()) as *const _, + ); } } @@ -138,13 +40,14 @@ mod tests { use super::*; use std::time::Instant; + const WASM: &[u8] = include_bytes!("../../wasm/code.wasm"); + #[test] fn it_works() { - let wasm = include_bytes!("../../wasm/code.wasm"); let wasm_start = Instant::now(); - let plugin = Plugin::new(wasm, false) - .unwrap() - .with_log_file("test.log", Some(log::LevelFilter::Info)); + set_log_file("test.log", Some(log::Level::Info)); + let context = Context::new(); + let plugin = Plugin::new(&context, WASM, false).unwrap(); println!("register loaded plugin: {:?}", wasm_start.elapsed()); let repeat = 1182; @@ -230,4 +133,27 @@ mod tests { println!("wasm function call (avg, N = {}): {:?}", num_tests, avg); } + + #[test] + fn test_threads() { + use std::io::Write; + std::thread::spawn(|| { + let context = Context::new(); + let plugin = Plugin::new(&context, WASM, false).unwrap(); + let output = plugin.call("count_vowels", "this is a test").unwrap(); + std::io::stdout().write_all(output).unwrap(); + }); + + std::thread::spawn(|| { + let context = Context::new(); + let plugin = Plugin::new(&context, WASM, false).unwrap(); + let output = plugin.call("count_vowels", "this is a test aaa").unwrap(); + std::io::stdout().write_all(output).unwrap(); + }); + + let context = Context::new(); + let plugin = Plugin::new(&context, WASM, false).unwrap(); + let output = plugin.call("count_vowels", "abc123").unwrap(); + std::io::stdout().write_all(output).unwrap(); + } } diff --git a/rust/src/plugin.rs b/rust/src/plugin.rs new file mode 100644 index 0000000..aa69a3c --- /dev/null +++ b/rust/src/plugin.rs @@ -0,0 +1,144 @@ +use crate::*; + +pub struct Plugin<'a> { + id: bindings::ExtismPlugin, + context: &'a Context, +} + +impl<'a> Plugin<'a> { + /// Create a new plugin from the given manifest + pub fn new_with_manifest( + ctx: &'a Context, + manifest: &Manifest, + wasi: bool, + ) -> Result, Error> { + let data = serde_json::to_vec(manifest)?; + Self::new(ctx, data, wasi) + } + + /// Create a new plugin from a WASM module + pub fn new(ctx: &'a Context, data: impl AsRef<[u8]>, wasi: bool) -> Result { + let plugin = unsafe { + bindings::extism_plugin_new( + ctx.pointer, + data.as_ref().as_ptr(), + data.as_ref().len() as u64, + wasi, + ) + }; + + if plugin < 0 { + let err = unsafe { bindings::extism_error(ctx.pointer, -1) }; + let buf = unsafe { std::ffi::CStr::from_ptr(err) }; + let buf = buf.to_str().unwrap().to_string(); + return Err(Error::UnableToLoadPlugin(buf)); + } + + Ok(Plugin { + id: plugin, + context: ctx, + }) + } + + /// Update a plugin with the given WASM module + pub fn update(&mut self, data: impl AsRef<[u8]>, wasi: bool) -> Result<(), Error> { + let b = unsafe { + bindings::extism_plugin_update( + self.context.pointer, + self.id, + data.as_ref().as_ptr(), + data.as_ref().len() as u64, + wasi, + ) + }; + if b { + return Ok(()); + } + + let err = unsafe { bindings::extism_error(self.context.pointer, -1) }; + if !err.is_null() { + let s = unsafe { std::ffi::CStr::from_ptr(err) }; + return Err(Error::Message(s.to_str().unwrap().to_string())); + } + + Err(Error::Message("extism_plugin_update failed".to_string())) + } + + /// Update a plugin with the given manifest + pub fn update_manifest(&mut self, manifest: &Manifest, wasi: bool) -> Result<(), Error> { + let data = serde_json::to_vec(manifest)?; + self.update(data, wasi) + } + + /// Set configuration values + pub fn set_config(&self, config: &BTreeMap>) -> Result<(), Error> { + let encoded = serde_json::to_vec(config)?; + unsafe { + bindings::extism_plugin_config( + self.context.pointer, + self.id, + encoded.as_ptr() as *const _, + encoded.len() as u64, + ) + }; + Ok(()) + } + + /// Set configuration values, builder-style + pub fn with_config(self, config: &BTreeMap>) -> Result { + self.set_config(config)?; + Ok(self) + } + + /// Returns true if the plugin has a function matching `name` + pub fn has_function(&self, name: impl AsRef) -> bool { + let name = std::ffi::CString::new(name.as_ref()).expect("Invalid function name"); + unsafe { + bindings::extism_plugin_function_exists( + self.context.pointer, + self.id, + name.as_ptr() as *const _, + ) + } + } + + /// Call a function with the given input + pub fn call(&self, name: impl AsRef, input: impl AsRef<[u8]>) -> Result<&[u8], Error> { + let name = std::ffi::CString::new(name.as_ref()).expect("Invalid function name"); + let rc = unsafe { + bindings::extism_plugin_call( + self.context.pointer, + self.id, + name.as_ptr() as *const _, + input.as_ref().as_ptr() as *const _, + input.as_ref().len() as u64, + ) + }; + + if rc != 0 { + let err = unsafe { bindings::extism_error(self.context.pointer, self.id) }; + if !err.is_null() { + let s = unsafe { std::ffi::CStr::from_ptr(err) }; + return Err(Error::Message(s.to_str().unwrap().to_string())); + } + + return Err(Error::Message("extism_call failed".to_string())); + } + + let out_len = + unsafe { bindings::extism_plugin_output_length(self.context.pointer, self.id) }; + unsafe { + let ptr = bindings::extism_plugin_output_data(self.context.pointer, self.id); + Ok(std::slice::from_raw_parts(ptr, out_len as usize)) + } + } +} + +impl<'a> Drop for Plugin<'a> { + fn drop(&mut self) { + if self.context.pointer.is_null() { + return; + } + unsafe { bindings::extism_plugin_free(self.context.pointer, self.id) } + } +} diff --git a/scripts/header.py b/scripts/header.py index cdd1828..ef10e0a 100644 --- a/scripts/header.py +++ b/scripts/header.py @@ -1,38 +1,46 @@ from pycparser import c_ast, parse_file + class Function: + def __init__(self, name, return_type, args): self.name = name self.return_type = return_type self.args = args - - + + typemap = { "_Bool": "bool", "ExtismPlugin": "int32_t", "ExtismSize": "uint64_t", } - + + class Type: + def __init__(self, name, const=False, pointer=False): self.name = typemap.get(name) or name self.const = const self.pointer = pointer - + + class Arg: + def __init__(self, name, type): self.name = name - self.type = type + self.type = type + class Visitor(c_ast.NodeVisitor): + def __init__(self, header): self.header = header - + def args(self, args): dest = [] for arg in args: name = arg.name - + if isinstance(arg.type, c_ast.PtrDecl): t = arg.type.type is_ptr = True @@ -40,12 +48,15 @@ class Visitor(c_ast.NodeVisitor): t = arg.type is_ptr = False - type_name = t.type.names[0] + if hasattr(t.type, 'name'): + type_name = t.type.name + else: + type_name = t.type.names[0] const = hasattr(t.type, 'quals') and 'const' in t.type.quals t = Type(type_name, const=const, pointer=is_ptr) dest.append(Arg(name, t)) return dest - + def visit_FuncDecl(self, node: c_ast.FuncDecl): if isinstance(node.type, c_ast.PtrDecl): t = node.type.type @@ -56,21 +67,26 @@ class Visitor(c_ast.NodeVisitor): name = t.declname args = self.args(node.args) - return_type_name = t.type.names[0] + if hasattr(t.type, 'name'): + return_type_name = t.type.name + else: + return_type_name = t.type.names[0] const = hasattr(t.type, 'quals') and 'const' in t.type.quals return_type = Type(return_type_name, const=const, pointer=is_ptr) self.header.functions.append(Function(name, return_type, args)) - + + class Header: + def __init__(self, filename='runtime/extism.h'): self.functions = [] self.header = parse_file(filename, use_cpp=True, cpp_args='-w') self.visitor = Visitor(self) self.visitor.visit(self.header) - + def __iter__(self): return self.functions.__iter__() - + def __getitem__(self, func): for f in self.functions: if f.name == func: