From 0f8954c2039cdffbb53352df6f613e4bbbb74235 Mon Sep 17 00:00:00 2001 From: zach Date: Wed, 17 May 2023 11:35:16 -0700 Subject: [PATCH] feat!: add ability to create plugins without an existing `Context` (#335) EIP: https://github.com/extism/proposals/pull/8 This PR makes minor breaking changes to several SDKs, but not to runtime C API. The threadsafety updates in the Rust SDK are kind of specific to Rust, I'm not sure if it makes sense to add the locks to all the other SDKs at this point. For the most part the `Context` and `Plugin` types in the SDKs should be safe to use protected by a mutex but they aren't inherently threadsafe. That kind of locking should probably be done by the user. - Runtime - improve thread safety - reinstantiates less - fixes a potential resource exhaustion bug from re-instantiating using the same store too many times - Rust SDK - adds `Send` and `Sync` implementations for `Context` - adds test sharing a context between threads - adds `Plugin::call_map` to call a plugin and handle the output with the lock held - adds testing sharing an `Arc>` between threads - adds `Plugin::create` and `Plugin::create_from_manifest` to create a plugin without a `Context` - Python - BREAKING - changes `Plugin` constructor to take `context` as an optional named argument, to update use `Plugin(data, context=context)` instead - Ruby - BREAKING - changes `Plugin` constructor to take `context` as an optional named argument, to update use `Plugin.new(data, context=context)` instead - Go - adds `NewPlugin` and `NewPluginFromManifest` functions - Node - BREAKING - changes `Plugin` constructor to take `context` as an optional named argument, to update use `new Plugin(data, wasi, config, host, context)` instead of `new Plugin(context, data, wasi, functions, config)` (most people are probably using `context.plugin` instead of the Plugin constructor anyway) - OCaml - BREAKING - changes `Plugin.create` and `Plugin.of_manifest` to take `context` as an optional named argument, to update `Plugin.create ~context data` and `Plugin.of_manifest ~context data` instead - Haskell - adds `createPlugin` and `createPluginFromManifest` functions - Elixir - adds `Plugin.new` to make a plugin without going through `Context.new_plugin` - Java - adds new `Plugin` constructors without a `Context` argument - C++ - BREAKING - Updates `Plugin` constructor to take an optional context as the last argument, instead of requiring it to be the first argument - Use `Plugin(wasm, wasi, functions, ctx)` instead of `Plugin(ctx, wasm, wasi, functions)` - Zig - Adds `Plugin.create` and `Plugin.createWithManifest` to create plugins in their own context. --------- Co-authored-by: zach Co-authored-by: Benjamin Eckel --- .github/workflows/ci-rust.yml | 1 - cpp/example.cpp | 3 +- cpp/extism.hpp | 38 ++- cpp/test/test.cpp | 11 +- dotnet/src/Extism.Sdk/Context.cs | 2 +- dotnet/src/Extism.Sdk/Plugin.cs | 14 +- dotnet/test/Extism.Sdk/BasicTests.cs | 13 +- elixir/lib/extism/cancel_handle.ex | 2 +- elixir/lib/extism/plugin.ex | 15 +- elixir/native/extism_nif/src/lib.rs | 4 +- extism.go | 12 + go/main.go | 5 +- haskell/Example.hs | 3 +- haskell/src/Extism.hs | 12 + haskell/test/Test.hs | 40 +-- java/src/main/java/org/extism/sdk/Plugin.java | 10 + node/example.js | 21 +- node/src/index.ts | 75 +++--- ocaml/bin/main.ml | 3 +- ocaml/lib/extism.mli | 4 +- ocaml/lib/plugin.ml | 55 ++-- python/example.py | 34 ++- python/extism/extism.py | 11 +- python/tests/test_extism.py | 2 + ruby/example.rb | 18 +- ruby/lib/extism.rb | 9 +- runtime/build.rs | 2 +- runtime/extism.h | 2 +- runtime/src/context.rs | 7 +- runtime/src/function.rs | 4 +- runtime/src/internal.rs | 132 ++++++++++ runtime/src/lib.rs | 4 +- runtime/src/memory.rs | 93 +++++-- runtime/src/pdk.rs | 51 ++-- runtime/src/plugin.rs | 240 ++++++++---------- runtime/src/plugin_ref.rs | 41 ++- runtime/src/sdk.rs | 139 +++++----- rust/src/context.rs | 4 + rust/src/lib.rs | 82 ++++-- rust/src/plugin.rs | 160 +++++++++--- rust/src/plugin_builder.rs | 28 +- zig/src/plugin.zig | 24 +- 42 files changed, 916 insertions(+), 514 deletions(-) create mode 100644 runtime/src/internal.rs diff --git a/.github/workflows/ci-rust.yml b/.github/workflows/ci-rust.yml index 259743a..4e58c7c 100644 --- a/.github/workflows/ci-rust.yml +++ b/.github/workflows/ci-rust.yml @@ -87,7 +87,6 @@ jobs: run: cargo clippy --release --all-features --no-deps -p ${{ env.RUNTIME_CRATE }} - name: Test run: cargo test --all-features --release -p ${{ env.RUNTIME_CRATE }} - run: cat test.log rust: name: Rust diff --git a/cpp/example.cpp b/cpp/example.cpp index 64138ea..aecca69 100644 --- a/cpp/example.cpp +++ b/cpp/example.cpp @@ -15,7 +15,6 @@ std::vector read(const char *filename) { int main(int argc, char *argv[]) { auto wasm = read("../wasm/code-functions.wasm"); - Context context = Context(); std::string tmp = "Testing"; // A lambda can be used as a host function @@ -34,7 +33,7 @@ int main(int argc, char *argv[]) { [](void *x) { std::cout << "Free user data" << std::endl; }), }; - Plugin plugin = context.plugin(wasm, true, functions); + Plugin plugin(wasm, true, functions); const char *input = argc > 1 ? argv[1] : "this is a test"; ExtismSize length = strlen(input); diff --git a/cpp/extism.hpp b/cpp/extism.hpp index 30f194e..c149066 100644 --- a/cpp/extism.hpp +++ b/cpp/extism.hpp @@ -342,9 +342,10 @@ class Plugin { public: // Create a new plugin - Plugin(std::shared_ptr ctx, const uint8_t *wasm, - ExtismSize length, bool with_wasi = false, - std::vector functions = std::vector()) + Plugin(const uint8_t *wasm, ExtismSize length, bool with_wasi = false, + std::vector functions = std::vector(), + std::shared_ptr ctx = std::shared_ptr( + extism_context_new(), extism_context_free)) : functions(functions) { std::vector ptrs; for (auto i : this->functions) { @@ -359,6 +360,19 @@ public: this->context = ctx; } + Plugin(const std::string &str, bool with_wasi = false, + std::vector functions = {}, + std::shared_ptr ctx = std::shared_ptr( + extism_context_new(), extism_context_free)) + : Plugin((const uint8_t *)str.c_str(), str.size(), with_wasi, functions, + ctx) {} + + Plugin(const std::vector &data, bool with_wasi = false, + std::vector functions = {}, + std::shared_ptr ctx = std::shared_ptr( + extism_context_new(), extism_context_free)) + : Plugin(data.data(), data.size(), with_wasi, functions, ctx) {} + CancelHandle cancel_handle() { return CancelHandle( extism_plugin_cancel_handle(this->context.get(), this->id())); @@ -366,8 +380,10 @@ public: #ifndef EXTISM_NO_JSON // Create a new plugin from Manifest - Plugin(std::shared_ptr ctx, const Manifest &manifest, - bool with_wasi = false, std::vector functions = {}) { + Plugin(const Manifest &manifest, bool with_wasi = false, + std::vector functions = {}, + std::shared_ptr ctx = std::shared_ptr( + extism_context_new(), extism_context_free)) { std::vector ptrs; for (auto i : this->functions) { ptrs.push_back(i.get()); @@ -506,28 +522,28 @@ public: // Create plugin from uint8_t* Plugin plugin(const uint8_t *wasm, size_t length, bool with_wasi = false, std::vector functions = {}) const { - return Plugin(this->pointer, wasm, length, with_wasi, functions); + return Plugin(wasm, length, with_wasi, functions, this->pointer); } // Create plugin from std::string Plugin plugin(const std::string &str, bool with_wasi = false, std::vector functions = {}) const { - return Plugin(this->pointer, (const uint8_t *)str.c_str(), str.size(), - with_wasi, functions); + return Plugin((const uint8_t *)str.c_str(), str.size(), with_wasi, + functions, this->pointer); } // Create plugin from uint8_t vector Plugin plugin(const std::vector &data, bool with_wasi = false, std::vector functions = {}) const { - return Plugin(this->pointer, data.data(), data.size(), with_wasi, - functions); + return Plugin(data.data(), data.size(), with_wasi, functions, + this->pointer); } #ifndef EXTISM_NO_JSON // Create plugin from Manifest Plugin plugin(const Manifest &manifest, bool with_wasi = false, std::vector functions = {}) const { - return Plugin(this->pointer, manifest, with_wasi, functions); + return Plugin(manifest, with_wasi, functions, this->pointer); } #endif diff --git a/cpp/test/test.cpp b/cpp/test/test.cpp index 25a68c8..61831f9 100644 --- a/cpp/test/test.cpp +++ b/cpp/test/test.cpp @@ -21,21 +21,19 @@ TEST(Context, Basic) { } TEST(Plugin, Manifest) { - Context context; Manifest manifest = Manifest::path(code); manifest.set_config("a", "1"); - ASSERT_NO_THROW(Plugin plugin = context.plugin(manifest)); - Plugin plugin = context.plugin(manifest); + ASSERT_NO_THROW(Plugin plugin(manifest)); + Plugin plugin(manifest); Buffer buf = plugin.call("count_vowels", "this is a test"); ASSERT_EQ((std::string)buf, "{\"count\": 4}"); } TEST(Plugin, BadManifest) { - Context context; Manifest manifest; - ASSERT_THROW(Plugin plugin = context.plugin(manifest), Error); + ASSERT_THROW(Plugin plugin(manifest), Error); } TEST(Plugin, Bytes) { @@ -68,7 +66,6 @@ TEST(Plugin, FunctionExists) { } TEST(Plugin, HostFunction) { - Context context; auto wasm = read("../../wasm/code-functions.wasm"); auto t = std::vector{ValType::I64}; Function hello_world = @@ -82,7 +79,7 @@ TEST(Plugin, HostFunction) { auto functions = std::vector{ hello_world, }; - Plugin plugin = context.plugin(wasm, true, functions); + Plugin plugin(wasm, true, functions); auto buf = plugin.call("count_vowels", "aaa"); ASSERT_EQ(buf.length, 4); ASSERT_EQ((std::string)buf, "test"); diff --git a/dotnet/src/Extism.Sdk/Context.cs b/dotnet/src/Extism.Sdk/Context.cs index c2fea45..0089aeb 100644 --- a/dotnet/src/Extism.Sdk/Context.cs +++ b/dotnet/src/Extism.Sdk/Context.cs @@ -188,4 +188,4 @@ public unsafe class Context : IDisposable return LibExtism.extism_log_file(logPath, logLevel); } -} \ No newline at end of file +} diff --git a/dotnet/src/Extism.Sdk/Plugin.cs b/dotnet/src/Extism.Sdk/Plugin.cs index ac11e4b..4d4626b 100644 --- a/dotnet/src/Extism.Sdk/Plugin.cs +++ b/dotnet/src/Extism.Sdk/Plugin.cs @@ -14,6 +14,18 @@ public class Plugin : IDisposable private readonly HostFunction[] _functions; private int _disposed; + /// + /// Create a and load a plug-in + /// Using this constructor will give the plug-in it's own internal Context + /// + /// A WASM module (wat or wasm) or a JSON encoded manifest. + /// List of host functions expected by the plugin. + /// Enable/Disable WASI. + public static Plugin Create(ReadOnlySpan wasm, HostFunction[] functions, bool withWasi) { + var context = new Context(); + return context.CreatePlugin(wasm, functions, withWasi); + } + internal Plugin(Context context, HostFunction[] functions, int index) { _context = context; @@ -195,4 +207,4 @@ public class Plugin : IDisposable { Dispose(false); } -} \ No newline at end of file +} diff --git a/dotnet/test/Extism.Sdk/BasicTests.cs b/dotnet/test/Extism.Sdk/BasicTests.cs index f5671cb..37612ac 100644 --- a/dotnet/test/Extism.Sdk/BasicTests.cs +++ b/dotnet/test/Extism.Sdk/BasicTests.cs @@ -10,6 +10,17 @@ namespace Extism.Sdk.Tests; public class BasicTests { + [Fact] + public void CountHelloWorldVowelsWithoutContext() + { + var binDirectory = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!; + var wasm = File.ReadAllBytes(Path.Combine(binDirectory, "code.wasm")); + using var plugin = Plugin.Create(wasm, Array.Empty(), withWasi: true); + + var response = plugin.CallFunction("count_vowels", Encoding.UTF8.GetBytes("Hello World")); + Assert.Equal("{\"count\": 3}", Encoding.UTF8.GetString(response)); + } + [Fact] public void CountHelloWorldVowels() { @@ -59,4 +70,4 @@ public class BasicTests outputs[0].v.i64 = plugin.WriteString(output); } } -} \ No newline at end of file +} diff --git a/elixir/lib/extism/cancel_handle.ex b/elixir/lib/extism/cancel_handle.ex index ae86af0..803a2b3 100644 --- a/elixir/lib/extism/cancel_handle.ex +++ b/elixir/lib/extism/cancel_handle.ex @@ -5,7 +5,7 @@ defmodule Extism.CancelHandle do """ defstruct [ # The actual NIF Resource. PluginIndex and the context - handle: nil, + handle: nil ] def wrap_resource(handle) do diff --git a/elixir/lib/extism/plugin.ex b/elixir/lib/extism/plugin.ex index 1ccd608..74450fa 100644 --- a/elixir/lib/extism/plugin.ex +++ b/elixir/lib/extism/plugin.ex @@ -15,12 +15,25 @@ defmodule Extism.Plugin do } end + @doc """ + Creates a new plugin + """ + def new(manifest, wasi \\ false, context \\ nil) do + ctx = context || Extism.Context.new() + {:ok, manifest_payload} = JSON.encode(manifest) + + case Extism.Native.plugin_new_with_manifest(ctx.ptr, manifest_payload, wasi) do + {:error, err} -> {:error, err} + res -> {:ok, Extism.Plugin.wrap_resource(ctx, res)} + end + end + @doc """ Call a plugin's function by name ## Examples - iex> {:ok, plugin} = Extism.Context.new_plugin(ctx, manifest, false) + iex> {:ok, plugin} = Extism.Plugin.new(manifest, false) iex> {:ok, output} = Extism.Plugin.call(plugin, "count_vowels", "this is a test") # {:ok, "{\"count\": 4}"} diff --git a/elixir/native/extism_nif/src/lib.rs b/elixir/native/extism_nif/src/lib.rs index cf9524f..dd46cc0 100644 --- a/elixir/native/extism_nif/src/lib.rs +++ b/elixir/native/extism_nif/src/lib.rs @@ -62,8 +62,8 @@ fn plugin_new_with_manifest( manifest_payload: String, wasi: bool, ) -> Result { - let context = &ctx.ctx.write().unwrap(); - let result = match Plugin::new(context, manifest_payload, [], wasi) { + let context = ctx.ctx.write().unwrap(); + let result = match Plugin::new(&context, manifest_payload, [], wasi) { Err(e) => Err(to_rustler_error(e)), Ok(plugin) => { let plugin_id = plugin.as_i32(); diff --git a/extism.go b/extism.go index 084be77..15319cb 100644 --- a/extism.go +++ b/extism.go @@ -323,6 +323,18 @@ func update(ctx *Context, plugin int32, data []byte, functions []Function, wasi ) } +// NewPlugin creates a plugin in its own context +func NewPlugin(module io.Reader, functions []Function, wasi bool) (Plugin, error) { + ctx := NewContext() + return ctx.Plugin(module, functions, wasi) +} + +// NewPlugin creates a plugin in its own context from a manifest +func NewPluginFromManifest(manifest Manifest, functions []Function, wasi bool) (Plugin, error) { + ctx := NewContext() + return ctx.PluginFromManifest(manifest, functions, wasi) +} + // PluginFromManifest creates a plugin from a `Manifest` func (ctx *Context) PluginFromManifest(manifest Manifest, functions []Function, wasi bool) (Plugin, error) { data, err := json.Marshal(manifest) diff --git a/go/main.go b/go/main.go index 230a2c9..84faa2e 100644 --- a/go/main.go +++ b/go/main.go @@ -36,9 +36,6 @@ func main() { version := extism.ExtismVersion() fmt.Println("Extism Version: ", version) - ctx := extism.NewContext() - defer ctx.Free() // this will free the context and all associated plugins - // set some input data to provide to the plugin module var data []byte if len(os.Args) > 1 { @@ -49,7 +46,7 @@ func main() { manifest := extism.Manifest{Wasm: []extism.Wasm{extism.WasmFile{Path: "../wasm/code-functions.wasm"}}} f := extism.NewFunction("hello_world", []extism.ValType{extism.I64}, []extism.ValType{extism.I64}, C.hello_world, "Hello again!") defer f.Free() - plugin, err := ctx.PluginFromManifest(manifest, []extism.Function{f}, true) + plugin, err := extism.NewPluginFromManifest(manifest, []extism.Function{f}, true) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/haskell/Example.hs b/haskell/Example.hs index 85c30c0..9dd3c9d 100644 --- a/haskell/Example.hs +++ b/haskell/Example.hs @@ -9,8 +9,7 @@ unwrap (Left (ExtismError msg)) = do main = do let m = manifest [wasmFile "../wasm/code.wasm"] - context <- Extism.newContext - plugin <- unwrap <$> Extism.pluginFromManifest context m False + plugin <- unwrap <$> Extism.createPluginFromManifest m False res <- unwrap <$> Extism.call plugin "count_vowels" (Extism.toByteString "this is a test") putStrLn (Extism.fromByteString res) Extism.free plugin diff --git a/haskell/src/Extism.hs b/haskell/src/Extism.hs index 3a63427..f8de37c 100644 --- a/haskell/src/Extism.hs +++ b/haskell/src/Extism.hs @@ -79,6 +79,12 @@ plugin c wasm useWasi = return $ Left (ExtismError e) else return $ Right (Plugin c p)) + +-- | Create a 'Plugin' with its own 'Context' +createPlugin :: B.ByteString -> Bool -> IO (Result Plugin) +createPlugin c useWasi = do + ctx <- newContext + plugin ctx c useWasi -- | Create a 'Plugin' from a 'Manifest' pluginFromManifest :: Context -> Manifest -> Bool -> IO (Result Plugin) @@ -86,6 +92,12 @@ pluginFromManifest ctx manifest useWasi = let wasm = toByteString $ toString manifest in plugin ctx wasm useWasi +-- | Create a 'Plugin' with its own 'Context' from a 'Manifest' +createPluginFromManifest :: Manifest -> Bool -> IO (Result Plugin) +createPluginFromManifest manifest useWasi = do + ctx <- newContext + pluginFromManifest ctx manifest useWasi + -- | Update a 'Plugin' with a new WASM module update :: Plugin -> B.ByteString -> Bool -> IO (Result ()) update (Plugin (Context ctx) id) wasm useWasi = diff --git a/haskell/test/Test.hs b/haskell/test/Test.hs index f66d5a3..88190fb 100644 --- a/haskell/test/Test.hs +++ b/haskell/test/Test.hs @@ -9,45 +9,45 @@ unwrap (Left (ExtismError msg)) = defaultManifest = manifest [wasmFile "../../wasm/code.wasm"] -initPlugin :: Context -> IO Plugin -initPlugin context = - Extism.pluginFromManifest context defaultManifest False >>= unwrap +initPlugin :: Maybe Context -> IO Plugin +initPlugin Nothing = + Extism.createPluginFromManifest defaultManifest False >>= unwrap +initPlugin (Just ctx) = + Extism.pluginFromManifest ctx defaultManifest False >>= unwrap pluginFunctionExists = do - withContext (\ctx -> do - p <- initPlugin ctx - exists <- functionExists p "count_vowels" - assertBool "function exists" exists - exists' <- functionExists p "function_doesnt_exist" - assertBool "function doesn't exist" (not exists')) + p <- initPlugin Nothing + exists <- functionExists p "count_vowels" + assertBool "function exists" exists + exists' <- functionExists p "function_doesnt_exist" + assertBool "function doesn't exist" (not exists') checkCallResult p = do - res <- call p "count_vowels" (toByteString "this is a test") >>= unwrap - assertEqual "count vowels output" "{\"count\": 4}" (fromByteString res) + res <- call p "count_vowels" (toByteString "this is a test") >>= unwrap + assertEqual "count vowels output" "{\"count\": 4}" (fromByteString res) pluginCall = do - withContext (\ctx -> do - p <- initPlugin ctx - checkCallResult p) + p <- initPlugin Nothing + checkCallResult p pluginMultiple = do - withContext (\ctx -> do - p <- initPlugin ctx + withContext(\ctx -> do + p <- initPlugin (Just ctx) checkCallResult p - q <- initPlugin ctx - r <- initPlugin ctx + q <- initPlugin (Just ctx) + r <- initPlugin (Just ctx) checkCallResult q checkCallResult r) pluginUpdate = do withContext (\ctx -> do - p <- initPlugin ctx + p <- initPlugin (Just ctx) updateManifest p defaultManifest True >>= unwrap checkCallResult p) pluginConfig = do withContext (\ctx -> do - p <- initPlugin ctx + p <- initPlugin (Just ctx) b <- setConfig p [("a", Just "1"), ("b", Just "2"), ("c", Just "3"), ("d", Nothing)] assertBool "set config" b) diff --git a/java/src/main/java/org/extism/sdk/Plugin.java b/java/src/main/java/org/extism/sdk/Plugin.java index 9ab7a0e..f53d4cd 100644 --- a/java/src/main/java/org/extism/sdk/Plugin.java +++ b/java/src/main/java/org/extism/sdk/Plugin.java @@ -62,6 +62,16 @@ public class Plugin implements AutoCloseable { this(context, serialize(manifest), withWASI, functions); } + + public Plugin(byte[] manifestBytes, boolean withWASI, HostFunction[] functions) { + this(new Context(), manifestBytes, withWASI, functions); + } + + + public Plugin(Manifest manifest, boolean withWASI, HostFunction[] functions) { + this(new Context(), serialize(manifest), withWASI, functions); + } + private static byte[] serialize(Manifest manifest) { Objects.requireNonNull(manifest, "manifest"); return JsonSerde.toJson(manifest).getBytes(StandardCharsets.UTF_8); diff --git a/node/example.js b/node/example.js index 653ea72..543736d 100644 --- a/node/example.js +++ b/node/example.js @@ -1,6 +1,5 @@ const { - withContext, - Context, + Plugin, HostFunction, ValType, } = require("./dist/index.js"); @@ -13,29 +12,31 @@ function f(currentPlugin, inputs, outputs, userData) { outputs[0] = inputs[0]; } -let hello_world = new HostFunction( +const hello_world = new HostFunction( "hello_world", [ValType.I64], [ValType.I64], f, - "Hello again!" + "Hello again!", ); -let functions = [hello_world]; +async function main() { + const functions = [hello_world]; -withContext(async function (context) { - let wasm = readFileSync("../wasm/code-functions.wasm"); - let p = context.plugin(wasm, true, functions); + const wasm = readFileSync("../wasm/code-functions.wasm"); + const p = new Plugin(wasm, true, functions); if (!p.functionExists("count_vowels")) { console.log("no function 'count_vowels' in wasm"); process.exit(1); } - let buf = await p.call("count_vowels", process.argv[2] || "this is a test"); + const buf = await p.call("count_vowels", process.argv[2] || "this is a test"); console.log(JSON.parse(buf.toString())["count"]); p.free(); -}); +} + +main(); // or, use a context like this: // let ctx = new Context(); diff --git a/node/src/index.ts b/node/src/index.ts index 6dd3734..7fc4695 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -104,7 +104,7 @@ interface LibExtism { data_len: number, functions: Buffer, nfunctions: number, - wasi: boolean + wasi: boolean, ) => number; extism_plugin_update: ( ctx: Buffer, @@ -113,7 +113,7 @@ interface LibExtism { data_len: number, functions: Buffer, nfunctions: number, - wasi: boolean + wasi: boolean, ) => boolean; extism_error: (ctx: Buffer, plugin_id: number) => string; extism_plugin_call: ( @@ -121,7 +121,7 @@ interface LibExtism { plugin_id: number, func: string, input: string, - input_len: number + input_len: number, ) => number; extism_plugin_output_length: (ctx: Buffer, plugin_id: number) => number; extism_plugin_output_data: (ctx: Buffer, plugin_id: number) => Uint8Array; @@ -129,13 +129,13 @@ interface LibExtism { extism_plugin_function_exists: ( ctx: Buffer, plugin_id: number, - func: string + func: string, ) => boolean; extism_plugin_config: ( ctx: Buffer, plugin_id: number, data: string | Buffer, - data_len: number + data_len: number, ) => void; extism_plugin_free: (ctx: Buffer, plugin_id: number) => void; extism_context_reset: (ctx: Buffer) => void; @@ -148,7 +148,7 @@ interface LibExtism { nOutputs: number, f: Buffer, user_data: Buffer | null, - free: Buffer | null + free: Buffer | null, ) => Buffer; extism_function_set_namespace: (f: Buffer, s: string) => void; extism_function_free: (f: Buffer) => void; @@ -321,9 +321,9 @@ export class Context { manifest: ManifestData, wasi: boolean = false, functions: HostFunction[] = [], - config?: PluginConfig + config?: PluginConfig, ) { - return new Plugin(this, manifest, wasi, functions, config); + return new Plugin(manifest, wasi, functions, config, this); } /** @@ -385,7 +385,7 @@ export class CurrentPlugin { return Buffer.from( lib.extism_current_plugin_memory(this.pointer).buffer, offset, - length + length, ); } @@ -442,7 +442,7 @@ export class CurrentPlugin { * @param input - The input to read */ inputBytes(input: typeof Val): Buffer { - return this.memory(input.v.i64) + return this.memory(input.v.i64); } /** @@ -450,7 +450,7 @@ export class CurrentPlugin { * @param input - The input to read */ inputString(input: typeof Val): string { - return this.memory(input.v.i64).toString() + return this.memory(input.v.i64).toString(); } } @@ -489,7 +489,7 @@ export class HostFunction { nInputs: number, outputs: Buffer, nOutputs: number, - user_data + user_data, ) => { let inputArr = []; let outputArr = []; @@ -506,13 +506,13 @@ export class HostFunction { new CurrentPlugin(currentPlugin), inputArr, outputArr, - ...this.userData + ...this.userData, ); for (var i = 0; i < nOutputs; i++) { Val.set(outputs, i, outputArr[i]); } - } + }, ); this.name = name; this.inputs = new ValTypeArray(inputs); @@ -525,23 +525,23 @@ export class HostFunction { this.outputs.length, this.callback, null, - null + null, ); this.userData = userData; functionRegistry.register(this, this.pointer, this.pointer); } - /** + /** * Set function namespace */ setNamespace(name: string) { if (this.pointer !== null) { - lib.extism_function_set_namespace(this.pointer, name) + lib.extism_function_set_namespace(this.pointer, name); } } - withNamespace(name: string) : HostFunction { - this.setNamespace(name) + withNamespace(name: string): HostFunction { + this.setNamespace(name); return this; } @@ -560,18 +560,18 @@ export class HostFunction { } /** - * CancelHandle is used to cancel a running Plugin - */ + * CancelHandle is used to cancel a running Plugin + */ export class CancelHandle { - handle: Buffer + handle: Buffer; constructor(handle: Buffer) { this.handle = handle; } /** - * Cancel execution of the Plugin associated with the CancelHandle - */ + * Cancel execution of the Plugin associated with the CancelHandle + */ cancel(): boolean { return lib.extism_plugin_cancel(this.handle); } @@ -589,19 +589,22 @@ export class Plugin { /** * Constructor for a plugin. @see {@link Context#plugin}. * - * @param ctx - The context to manage this plugin * @param manifest - The {@link Manifest} * @param wasi - Set to true to enable WASI support * @param functions - An array of {@link HostFunction} * @param config - The plugin config + * @param ctx - The context to manage this plugin, or null to use a new context */ constructor( - ctx: Context, manifest: ManifestData, wasi: boolean = false, functions: HostFunction[] = [], - config?: PluginConfig + config?: PluginConfig, + ctx: Context | null = null, ) { + if (ctx == null) { + ctx = new Context(); + } let dataRaw: string | Buffer; if (Buffer.isBuffer(manifest) || typeof manifest === "string") { dataRaw = manifest; @@ -621,7 +624,7 @@ export class Plugin { Buffer.byteLength(dataRaw, "utf-8"), this.functions, functions.length, - wasi + wasi, ); if (plugin < 0) { var err = lib.extism_error(ctx.pointer, -1); @@ -640,7 +643,7 @@ export class Plugin { ctx.pointer, this.id, s, - Buffer.byteLength(s, "utf-8") + Buffer.byteLength(s, "utf-8"), ); } } @@ -666,7 +669,7 @@ export class Plugin { manifest: ManifestData, wasi: boolean = false, functions: HostFunction[] = [], - config?: PluginConfig + config?: PluginConfig, ) { let dataRaw: string | Buffer; if (Buffer.isBuffer(manifest) || typeof manifest === "string") { @@ -688,7 +691,7 @@ export class Plugin { Buffer.byteLength(dataRaw, "utf-8"), this.functions, functions.length, - wasi + wasi, ); if (!ok) { var err = lib.extism_error(this.ctx.pointer, -1); @@ -704,7 +707,7 @@ export class Plugin { this.ctx.pointer, this.id, s, - Buffer.byteLength(s, "utf-8") + Buffer.byteLength(s, "utf-8"), ); } } @@ -721,7 +724,7 @@ export class Plugin { return lib.extism_plugin_function_exists( this.ctx.pointer, this.id, - functionName + functionName, ); } @@ -739,7 +742,7 @@ export class Plugin { * * @param functionName - The name of the function * @param input - The input data - *@returns A Buffer repreesentation of the output + * @returns A Buffer repreesentation of the output */ async call(functionName: string, input: string | Buffer): Promise { return new Promise((resolve, reject) => { @@ -749,7 +752,7 @@ export class Plugin { this.id, functionName, input.toString(), - Buffer.byteLength(input, "utf-8") + Buffer.byteLength(input, "utf-8"), ); if (rc !== 0) { var err = lib.extism_error(this.ctx.pointer, this.id); @@ -763,7 +766,7 @@ export class Plugin { var buf = Buffer.from( lib.extism_plugin_output_data(this.ctx.pointer, this.id).buffer, 0, - out_len + out_len, ); resolve(buf); }); diff --git a/ocaml/bin/main.ml b/ocaml/bin/main.ml index a14f3d5..ac72b45 100644 --- a/ocaml/bin/main.ml +++ b/ocaml/bin/main.ml @@ -4,10 +4,9 @@ open Cmdliner let read_stdin () = In_channel.input_all stdin let main file func_name input = - with_context @@ fun ctx -> let input = if String.equal input "-" then read_stdin () else input in let file = In_channel.with_open_bin file In_channel.input_all in - let plugin = Plugin.create ctx file ~wasi:true |> Result.get_ok in + let plugin = Plugin.create file ~wasi:true |> Result.get_ok in let res = Plugin.call plugin ~name:func_name input |> Result.get_ok in print_endline res diff --git a/ocaml/lib/extism.mli b/ocaml/lib/extism.mli index 1fa3314..7ae5c6e 100644 --- a/ocaml/lib/extism.mli +++ b/ocaml/lib/extism.mli @@ -208,7 +208,7 @@ module Plugin : sig ?config:Manifest.config -> ?wasi:bool -> ?functions:Function.t list -> - Context.t -> + ?context:Context.t -> string -> (t, Error.t) result (** Make a new plugin from raw WebAssembly or JSON encoded manifest *) @@ -216,7 +216,7 @@ module Plugin : sig val of_manifest : ?wasi:bool -> ?functions:Function.t list -> - Context.t -> + ?context:Context.t -> Manifest.t -> (t, Error.t) result (** Make a new plugin from a [Manifest] *) diff --git a/ocaml/lib/plugin.ml b/ocaml/lib/plugin.ml index bf28acf..8cdc5a3 100644 --- a/ocaml/lib/plugin.ml +++ b/ocaml/lib/plugin.ml @@ -26,7 +26,8 @@ let free t = if not (Ctypes.is_null t.ctx.pointer) then Bindings.extism_plugin_free t.ctx.pointer t.id -let create ?config ?(wasi = false) ?(functions = []) ctx wasm = +let create ?config ?(wasi = false) ?(functions = []) ?context wasm = + let ctx = match context with Some c -> c | None -> Context.create () in let func_ptrs = List.map (fun x -> x.Function.pointer) functions in let arr = Ctypes.CArray.of_list Ctypes.(ptr void) func_ptrs in let n_funcs = Ctypes.CArray.length arr in @@ -48,16 +49,15 @@ let create ?config ?(wasi = false) ?(functions = []) ctx wasm = let () = Gc.finalise free t in Ok t -let of_manifest ?wasi ?functions ctx manifest = +let of_manifest ?wasi ?functions ?context manifest = let data = Manifest.to_json manifest in - create ctx ?wasi ?functions data + create ?wasi ?functions ?context data let%test "free plugin" = let manifest = Manifest.(create [ Wasm.file "test/code.wasm" ]) in - with_context (fun ctx -> - let plugin = of_manifest ctx manifest |> Error.unwrap in - free plugin; - true) + let plugin = of_manifest manifest |> Error.unwrap in + free plugin; + true let update plugin ?config ?(wasi = false) ?(functions = []) wasm = let { id; ctx; _ } = plugin in @@ -85,11 +85,10 @@ let update_manifest plugin ?wasi manifest = let%test "update plugin manifest and config" = let manifest = Manifest.(create [ Wasm.file "test/code.wasm" ]) in - with_context (fun ctx -> - let config = [ ("a", Some "1") ] in - let plugin = of_manifest ctx manifest |> Error.unwrap in - let manifest = Manifest.with_config manifest config in - update_manifest plugin manifest |> Result.is_ok) + let config = [ ("a", Some "1") ] in + let plugin = of_manifest manifest |> Error.unwrap in + let manifest = Manifest.with_config manifest config in + update_manifest plugin manifest |> Result.is_ok let call' f { id; ctx; _ } ~name input len = let rc = f ctx.pointer id name input len in @@ -114,11 +113,10 @@ let call_bigstring (t : t) ~name input = let%test "call_bigstring" = let manifest = Manifest.(create [ Wasm.file "test/code.wasm" ]) in - with_context (fun ctx -> - let plugin = of_manifest ctx manifest |> Error.unwrap in - call_bigstring plugin ~name:"count_vowels" - (Bigstringaf.of_string ~off:0 ~len:14 "this is a test") - |> Error.unwrap |> Bigstringaf.to_string = "{\"count\": 4}") + let plugin = of_manifest manifest |> Error.unwrap in + call_bigstring plugin ~name:"count_vowels" + (Bigstringaf.of_string ~off:0 ~len:14 "this is a test") + |> Error.unwrap |> Bigstringaf.to_string = "{\"count\": 4}" let call (t : t) ~name input = let len = String.length input in @@ -127,10 +125,9 @@ let call (t : t) ~name input = let%test "call" = let manifest = Manifest.(create [ Wasm.file "test/code.wasm" ]) in - with_context (fun ctx -> - let plugin = of_manifest ctx manifest |> Error.unwrap in - call plugin ~name:"count_vowels" "this is a test" - |> Error.unwrap = "{\"count\": 4}") + let plugin = of_manifest manifest |> Error.unwrap in + call plugin ~name:"count_vowels" "this is a test" + |> Error.unwrap = "{\"count\": 4}" let%test "call_functions" = let open Types.Val_type in @@ -147,22 +144,18 @@ let%test "call_functions" = in let functions = [ hello_world ] in let manifest = Manifest.(create [ Wasm.file "test/code-functions.wasm" ]) in - with_context (fun ctx -> - let plugin = - of_manifest ctx manifest ~functions ~wasi:true |> Error.unwrap - in - call plugin ~name:"count_vowels" "this is a test" - |> Error.unwrap = "{\"count\": 4}") + let plugin = of_manifest manifest ~functions ~wasi:true |> Error.unwrap in + call plugin ~name:"count_vowels" "this is a test" + |> Error.unwrap = "{\"count\": 4}" let function_exists { id; ctx; _ } name = Bindings.extism_plugin_function_exists ctx.pointer id name let%test "function exists" = let manifest = Manifest.(create [ Wasm.file "test/code.wasm" ]) in - with_context (fun ctx -> - let plugin = of_manifest ctx manifest |> Error.unwrap in - function_exists plugin "count_vowels" - && not (function_exists plugin "function_does_not_exist")) + let plugin = of_manifest manifest |> Error.unwrap in + function_exists plugin "count_vowels" + && not (function_exists plugin "function_does_not_exist") module Cancel_handle = struct type t = { inner : unit Ctypes.ptr } diff --git a/python/example.py b/python/example.py index 27c42f5..a772a0f 100644 --- a/python/example.py +++ b/python/example.py @@ -5,11 +5,11 @@ import hashlib import pathlib sys.path.append(".") -from extism import Context, Function, host_fn, ValType +from extism import Function, host_fn, ValType, Plugin @host_fn -def hello_world(plugin, input_, output, context, a_string): +def hello_world(plugin, input_, output, a_string): print("Hello from Python!") print(a_string) print(input_) @@ -33,24 +33,20 @@ def main(args): ) wasm = wasm_file_path.read_bytes() hash = hashlib.sha256(wasm).hexdigest() - config = {"wasm": [{"data": wasm, "hash": hash}], "memory": {"max": 5}} + manifest = {"wasm": [{"data": wasm, "hash": hash}], "memory": {"max": 5}} - # a Context provides a scope for plugins to be managed within. creating multiple contexts - # is expected and groups plugins based on source/tenant/lifetime etc. - with Context() as context: - functions = [ - Function( - "hello_world", - [ValType.I64], - [ValType.I64], - hello_world, - context, - "Hello again!", - ) - ] - plugin = context.plugin(config, wasi=True, functions=functions) - # Call `count_vowels` - wasm_vowel_count = json.loads(plugin.call("count_vowels", data)) + functions = [ + Function( + "hello_world", + [ValType.I64], + [ValType.I64], + hello_world, + "Hello again!", + ) + ] + plugin = Plugin(manifest, wasi=True, functions=functions) + # Call `count_vowels` + wasm_vowel_count = json.loads(plugin.call("count_vowels", data)) print("Number of vowels:", wasm_vowel_count["count"]) diff --git a/python/extism/extism.py b/python/extism/extism.py index fd93ac1..e0dc38f 100644 --- a/python/extism/extism.py +++ b/python/extism/extism.py @@ -193,7 +193,9 @@ class Context: Plugin The created plugin """ - return Plugin(self, manifest, wasi, config, functions) + return Plugin( + manifest, context=self, wasi=wasi, config=config, functions=functions + ) class Function: @@ -247,16 +249,19 @@ class Plugin: def __init__( self, - context: Context, plugin: Union[str, bytes, dict], + context=None, wasi=False, config=None, functions=None, ): """ - Construct a Plugin. Please use Context#plugin instead. + Construct a Plugin """ + if context is None: + context = Context() + wasm = _wasm(plugin) # Register plugin diff --git a/python/tests/test_extism.py b/python/tests/test_extism.py index 1dd8e6a..f9cffa6 100644 --- a/python/tests/test_extism.py +++ b/python/tests/test_extism.py @@ -104,9 +104,11 @@ class TestExtism(unittest.TestCase): with extism.Context() as ctx: plugin = ctx.plugin(self._loop_manifest()) cancel_handle = plugin.cancel_handle() + def cancel(handle): time.sleep(0.5) handle.cancel() + Thread(target=cancel, args=[cancel_handle]).run() self.assertRaises(extism.Error, lambda: plugin.call("infinite_loop", b"")) diff --git a/ruby/example.rb b/ruby/example.rb index 0c5cefd..c6ac749 100644 --- a/ruby/example.rb +++ b/ruby/example.rb @@ -1,17 +1,11 @@ require "./lib/extism" require "json" -# a Context provides a scope for plugins to be managed within. creating multiple contexts -# is expected and groups plugins based on source/tenant/lifetime etc. -# We recommend you use `Extism.with_context` unless you have a reason to keep your context around. -# If you do you can create a context with `Extism#new`, example: `ctx = Extism.new` -Extism.with_context do |ctx| - manifest = { - :wasm => [{ :path => "../wasm/code.wasm" }], - } +manifest = { + :wasm => [{ :path => "../wasm/code.wasm" }], +} - plugin = ctx.plugin(manifest) - res = JSON.parse(plugin.call("count_vowels", ARGV[0] || "this is a test")) +plugin = Extism::Plugin.new(manifest) +res = JSON.parse(plugin.call("count_vowels", ARGV[0] || "this is a test")) - puts res["count"] -end +puts res["count"] diff --git a/ruby/lib/extism.rb b/ruby/lib/extism.rb index af5e881..d6cb01d 100644 --- a/ruby/lib/extism.rb +++ b/ruby/lib/extism.rb @@ -91,7 +91,7 @@ module Extism # @param config [Hash] The plugin config # @return [Plugin] def plugin(wasm, wasi = false, config = nil) - Plugin.new(self, wasm, wasi, config) + Plugin.new(wasm, context=self, wasi, config) end end @@ -132,11 +132,14 @@ module Extism # Intialize a plugin # # @see Extism::Context#plugin - # @param context [Context] The context to manager this plugin # @param wasm [Hash, String] The manifest or WASM binary. See https://extism.org/docs/concepts/manifest/. + # @param context [Context] The context to manager this plugin # @param wasi [Boolean] Enable WASI support # @param config [Hash] The plugin config - def initialize(context, wasm, wasi = false, config = nil) + def initialize(wasm, context = nil, wasi = false, config = nil) + if context.nil? then + context = Context.new + end @context = context if wasm.class == Hash wasm = JSON.generate(wasm) diff --git a/runtime/build.rs b/runtime/build.rs index 0c27641..9d026fc 100644 --- a/runtime/build.rs +++ b/runtime/build.rs @@ -16,7 +16,7 @@ fn main() { .rename_item("Context", "ExtismContext") .rename_item("ValType", "ExtismValType") .rename_item("ValUnion", "ExtismValUnion") - .rename_item("Plugin", "ExtismCurrentPlugin") + .rename_item("Internal", "ExtismCurrentPlugin") .with_style(cbindgen::Style::Type) .generate() { diff --git a/runtime/extism.h b/runtime/extism.h index 9738454..7a085c4 100644 --- a/runtime/extism.h +++ b/runtime/extism.h @@ -54,7 +54,7 @@ typedef struct ExtismCancelHandle ExtismCancelHandle; typedef struct ExtismFunction ExtismFunction; /** - * Plugin contains everything needed to execute a WASM function + * Internal stores data that is available to the caller in PDK functions */ typedef struct ExtismCurrentPlugin ExtismCurrentPlugin; diff --git a/runtime/src/context.rs b/runtime/src/context.rs index a4b692f..c8e5c6e 100644 --- a/runtime/src/context.rs +++ b/runtime/src/context.rs @@ -1,4 +1,3 @@ -use std::cell::UnsafeCell; use std::collections::{BTreeMap, VecDeque}; use crate::*; @@ -8,7 +7,7 @@ static mut TIMER: std::sync::Mutex> = std::sync::Mutex::new(None); /// A `Context` is used to store and manage plugins pub struct Context { /// Plugin registry - pub plugins: BTreeMap>, + pub plugins: BTreeMap, /// Error message pub error: Option, @@ -91,7 +90,7 @@ impl Context { return -1; } }; - self.plugins.insert(id, UnsafeCell::new(plugin)); + self.plugins.insert(id, plugin); id } @@ -127,7 +126,7 @@ impl Context { /// Get a plugin from the context pub fn plugin(&mut self, id: PluginIndex) -> Option<*mut Plugin> { match self.plugins.get_mut(&id) { - Some(x) => Some(x.get_mut()), + Some(x) => Some(x), None => None, } } diff --git a/runtime/src/function.rs b/runtime/src/function.rs index 849966c..29e4f8e 100644 --- a/runtime/src/function.rs +++ b/runtime/src/function.rs @@ -169,7 +169,7 @@ impl Function { ) -> Function where F: 'static - + Fn(&mut crate::Plugin, &[Val], &mut [Val], UserData) -> Result<(), Error> + + Fn(&mut Internal, &[Val], &mut [Val], UserData) -> Result<(), Error> + Sync + Send, { @@ -182,7 +182,7 @@ impl Function { returns.into_iter().map(wasmtime::ValType::from), ), f: std::sync::Arc::new(move |mut caller, inp, outp| { - f(caller.data_mut().plugin_mut(), inp, outp, data.make_copy()) + f(caller.data_mut(), inp, outp, data.make_copy()) }), namespace: None, _user_data: std::sync::Arc::new(user_data), diff --git a/runtime/src/internal.rs b/runtime/src/internal.rs new file mode 100644 index 0000000..bf47a79 --- /dev/null +++ b/runtime/src/internal.rs @@ -0,0 +1,132 @@ +use std::collections::BTreeMap; + +use crate::*; + +/// Internal stores data that is available to the caller in PDK functions +pub struct Internal { + /// Call input length + pub input_length: usize, + + /// Pointer to call input + pub input: *const u8, + + /// Memory offset that points to the output + pub output_offset: usize, + + /// Length of output in memory + pub output_length: usize, + + /// WASI context + pub wasi: Option, + + /// Keep track of the status from the last HTTP request + pub http_status: u16, + + /// Store plugin-specific error messages + pub last_error: std::cell::RefCell>, + + /// Plugin variables + pub vars: BTreeMap>, + + /// A pointer to the plugin memory, this should mostly be used from the PDK + pub memory: *mut PluginMemory, +} + +/// InternalExt provides a unified way of acessing `memory`, `store` and `internal` values +pub trait InternalExt { + fn memory(&self) -> &PluginMemory; + fn memory_mut(&mut self) -> &mut PluginMemory; + + fn store(&self) -> &Store { + self.memory().store() + } + + fn store_mut(&mut self) -> &mut Store { + self.memory_mut().store_mut() + } + + fn internal(&self) -> &Internal { + self.store().data() + } + + fn internal_mut(&mut self) -> &mut Internal { + self.store_mut().data_mut() + } +} + +/// WASI context +pub struct Wasi { + /// wasi + pub ctx: wasmtime_wasi::WasiCtx, + + /// wasi-nn + #[cfg(feature = "nn")] + pub nn: wasmtime_wasi_nn::WasiNnCtx, + #[cfg(not(feature = "nn"))] + pub nn: (), +} + +impl Internal { + pub(crate) fn new(manifest: &Manifest, wasi: bool) -> Result { + let wasi = if wasi { + let auth = wasmtime_wasi::ambient_authority(); + let mut ctx = wasmtime_wasi::WasiCtxBuilder::new(); + for (k, v) in manifest.as_ref().config.iter() { + ctx = ctx.env(k, v)?; + } + + if let Some(a) = &manifest.as_ref().allowed_paths { + for (k, v) in a.iter() { + let d = wasmtime_wasi::Dir::open_ambient_dir(k, auth)?; + ctx = ctx.preopened_dir(d, v)?; + } + } + + #[cfg(feature = "nn")] + let nn = wasmtime_wasi_nn::WasiNnCtx::new()?; + + #[cfg(not(feature = "nn"))] + #[allow(clippy::let_unit_value)] + let nn = (); + + Some(Wasi { + ctx: ctx.build(), + nn, + }) + } else { + None + }; + + Ok(Internal { + input_length: 0, + output_offset: 0, + output_length: 0, + input: std::ptr::null(), + wasi, + memory: std::ptr::null_mut(), + http_status: 0, + last_error: std::cell::RefCell::new(None), + vars: BTreeMap::new(), + }) + } + + pub fn set_error(&self, e: impl std::fmt::Debug) { + debug!("Set error: {:?}", e); + *self.last_error.borrow_mut() = Some(error_string(e)); + } + + /// Unset `last_error` field + pub fn clear_error(&self) { + *self.last_error.borrow_mut() = None; + } +} + +impl InternalExt for Internal { + fn memory(&self) -> &PluginMemory { + unsafe { &*self.memory } + } + + fn memory_mut(&mut self) -> &mut PluginMemory { + unsafe { &mut *self.memory } + } +} diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 19fd434..2ef5d15 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -3,6 +3,7 @@ pub(crate) use wasmtime::*; mod context; mod function; +mod internal; pub mod manifest; mod memory; pub(crate) mod pdk; @@ -13,9 +14,10 @@ mod timer; pub use context::Context; pub use function::{Function, UserData, Val, ValType}; +pub use internal::{Internal, InternalExt, Wasi}; pub use manifest::Manifest; pub use memory::{MemoryBlock, PluginMemory, ToMemoryBlock}; -pub use plugin::{Internal, Plugin, Wasi}; +pub use plugin::Plugin; pub use plugin_ref::PluginRef; pub(crate) use timer::{Timer, TimerAction}; diff --git a/runtime/src/memory.rs b/runtime/src/memory.rs index 4fa7015..98d62fe 100644 --- a/runtime/src/memory.rs +++ b/runtime/src/memory.rs @@ -6,13 +6,26 @@ use pretty_hex::PrettyHex; /// Handles memory for plugins pub struct PluginMemory { - pub store: Store, + /// wasmtime Store + pub store: Option>, + + /// WASM memory pub memory: Memory, + + /// Tracks allocated blocks pub live_blocks: BTreeMap, + + /// Tracks free blocks pub free: Vec, + + /// Tracks current offset in memory pub position: usize, + + /// Extism manifest + pub manifest: Manifest, } +/// `ToMemoryBlock` is used to convert from Rust values to blocks of WASM memory pub trait ToMemoryBlock { fn to_memory_block(&self, mem: &PluginMemory) -> Result; } @@ -49,26 +62,54 @@ const BLOCK_SIZE_THRESHOLD: usize = 32; impl PluginMemory { /// Create memory for a plugin - pub fn new(store: Store, memory: Memory) -> Self { + pub fn new(store: Store, memory: Memory, manifest: Manifest) -> Self { PluginMemory { free: Vec::new(), live_blocks: BTreeMap::new(), - store, + store: Some(store), memory, position: 1, + manifest, } } + pub fn store(&self) -> &Store { + self.store.as_ref().unwrap() + } + + pub fn store_mut(&mut self) -> &mut Store { + self.store.as_mut().unwrap() + } + + /// Moves module to a new store + pub fn reinstantiate(&mut self) -> Result<(), Error> { + if let Some(store) = self.store.take() { + let engine = store.engine().clone(); + let internal = store.into_data(); + let mut store = Store::new(&engine, internal); + store.epoch_deadline_callback(|_internal| Err(Error::msg("timeout"))); + self.memory = Memory::new( + &mut store, + MemoryType::new(4, self.manifest.as_ref().memory.max_pages), + )?; + self.store = Some(store); + } + + self.reset(); + + Ok(()) + } + /// Write byte to memory pub(crate) fn store_u8(&mut self, offs: usize, data: u8) -> Result<(), MemoryAccessError> { trace!("store_u8: offset={offs} data={data:#04x}"); if offs >= self.size() { // This should raise MemoryAccessError let buf = &mut [0]; - self.memory.read(&self.store, offs, buf)?; + self.memory.read(&self.store.as_ref().unwrap(), offs, buf)?; return Ok(()); } - self.memory.data_mut(&mut self.store)[offs] = data; + self.memory.data_mut(&mut self.store.as_mut().unwrap())[offs] = data; Ok(()) } @@ -78,10 +119,10 @@ impl PluginMemory { if offs >= self.size() { // This should raise MemoryAccessError let buf = &mut [0]; - self.memory.read(&self.store, offs, buf)?; + self.memory.read(&self.store.as_ref().unwrap(), offs, buf)?; return Ok(0); } - Ok(self.memory.data(&self.store)[offs]) + Ok(self.memory.data(&self.store.as_ref().unwrap())[offs]) } /// Write u64 to memory @@ -112,7 +153,7 @@ impl PluginMemory { let pos = pos.to_memory_block(self)?; assert!(data.as_ref().len() <= pos.length); self.memory - .write(&mut self.store, pos.offset, data.as_ref())?; + .write(&mut self.store.as_mut().unwrap(), pos.offset, data.as_ref())?; Ok(()) } @@ -120,18 +161,19 @@ impl PluginMemory { pub fn read(&self, pos: impl ToMemoryBlock, mut data: impl AsMut<[u8]>) -> Result<(), Error> { let pos = pos.to_memory_block(self)?; assert!(data.as_mut().len() <= pos.length); - self.memory.read(&self.store, pos.offset, data.as_mut())?; + self.memory + .read(&self.store.as_ref().unwrap(), pos.offset, data.as_mut())?; Ok(()) } /// Size of memory in bytes pub fn size(&self) -> usize { - self.memory.data_size(&self.store) + self.memory.data_size(&self.store.as_ref().unwrap()) } /// Size of memory in pages pub fn pages(&self) -> u32 { - self.memory.size(&self.store) as u32 + self.memory.size(&self.store.as_ref().unwrap()) as u32 } /// Reserve `n` bytes of memory @@ -175,7 +217,8 @@ impl PluginMemory { debug!("Requesting {pages_needed} more pages"); // This will fail if we've already allocated the maximum amount of memory allowed - self.memory.grow(&mut self.store, pages_needed)?; + self.memory + .grow(&mut self.store.as_mut().unwrap(), pages_needed)?; } let mem = MemoryBlock { @@ -237,7 +280,7 @@ impl PluginMemory { /// Log entire memory as hexdump using the `trace` log level pub fn dump(&self) { - let data = self.memory.data(&self.store); + let data = self.memory.data(self.store.as_ref().unwrap()); trace!("{:?}", data[..self.position].hex_dump()); } @@ -251,34 +294,34 @@ impl PluginMemory { /// Get memory as a slice of bytes pub fn data(&self) -> &[u8] { - self.memory.data(&self.store) + self.memory.data(self.store.as_ref().unwrap()) } /// Get memory as a mutable slice of bytes pub fn data_mut(&mut self) -> &mut [u8] { - self.memory.data_mut(&mut self.store) + self.memory.data_mut(self.store.as_mut().unwrap()) } /// Get bytes occupied by the provided memory handle pub fn get(&self, handle: impl ToMemoryBlock) -> Result<&[u8], Error> { let handle = handle.to_memory_block(self)?; - Ok(&self.memory.data(&self.store)[handle.offset..handle.offset + handle.length]) + Ok(&self.memory.data(self.store.as_ref().unwrap()) + [handle.offset..handle.offset + handle.length]) } /// Get mutable bytes occupied by the provided memory handle pub fn get_mut(&mut self, handle: impl ToMemoryBlock) -> Result<&mut [u8], Error> { let handle = handle.to_memory_block(self)?; - Ok( - &mut self.memory.data_mut(&mut self.store) - [handle.offset..handle.offset + handle.length], - ) + Ok(&mut self.memory.data_mut(self.store.as_mut().unwrap()) + [handle.offset..handle.offset + handle.length]) } /// Get str occupied by the provided memory handle pub fn get_str(&self, handle: impl ToMemoryBlock) -> Result<&str, Error> { let handle = handle.to_memory_block(self)?; Ok(std::str::from_utf8( - &self.memory.data(&self.store)[handle.offset..handle.offset + handle.length], + &self.memory.data(self.store.as_ref().unwrap()) + [handle.offset..handle.offset + handle.length], )?) } @@ -286,7 +329,7 @@ impl PluginMemory { pub fn get_mut_str(&mut self, handle: impl ToMemoryBlock) -> Result<&mut str, Error> { let handle = handle.to_memory_block(self)?; Ok(std::str::from_utf8_mut( - &mut self.memory.data_mut(&mut self.store) + &mut self.memory.data_mut(self.store.as_mut().unwrap()) [handle.offset..handle.offset + handle.length], )?) } @@ -294,7 +337,11 @@ impl PluginMemory { /// Pointer to the provided memory handle pub fn ptr(&self, handle: impl ToMemoryBlock) -> Result<*mut u8, Error> { let handle = handle.to_memory_block(self)?; - Ok(unsafe { self.memory.data_ptr(&self.store).add(handle.offset) }) + Ok(unsafe { + self.memory + .data_ptr(&self.store.as_ref().unwrap()) + .add(handle.offset) + }) } /// Get the length of the block starting at `offs` diff --git a/runtime/src/pdk.rs b/runtime/src/pdk.rs index c5d9eb5..ac2bb31 100644 --- a/runtime/src/pdk.rs +++ b/runtime/src/pdk.rs @@ -180,13 +180,12 @@ pub(crate) fn error_set( let offset = args!(input, 0, i64) as usize; if offset == 0 { - data.plugin_mut().clear_error(); + *data.last_error.borrow_mut() = None; return Ok(()); } - let plugin = data.plugin_mut(); - let s = plugin.memory.get_str(offset)?; - plugin.set_error(s); + let s = data.memory().get_str(offset)?; + data.set_error(s); Ok(()) } @@ -199,13 +198,16 @@ pub(crate) fn config_get( output: &mut [Val], ) -> Result<(), Error> { let data: &mut Internal = caller.data_mut(); - let plugin = data.plugin_mut(); let offset = args!(input, 0, i64) as usize; - let key = plugin.memory.get_str(offset)?; - let val = plugin.manifest.as_ref().config.get(key); - let mem = match val { - Some(f) => plugin.memory.alloc_bytes(f)?, + let key = data.memory().get_str(offset)?; + let val = data.memory().manifest.as_ref().config.get(key); + let ptr = val.map(|x| (x.len(), x.as_ptr())); + let mem = match ptr { + Some((len, ptr)) => { + let bytes = unsafe { std::slice::from_raw_parts(ptr, len) }; + data.memory_mut().alloc_bytes(bytes)? + } None => { output[0] = Val::I64(0); return Ok(()); @@ -224,14 +226,17 @@ pub(crate) fn var_get( output: &mut [Val], ) -> Result<(), Error> { let data: &mut Internal = caller.data_mut(); - let plugin = data.plugin_mut(); let offset = args!(input, 0, i64) as usize; - let key = plugin.memory.get_str(offset)?; - let val = plugin.vars.get(key); + let key = data.memory().get_str(offset)?; + let val = data.vars.get(key); + let ptr = val.map(|x| (x.len(), x.as_ptr())); - let mem = match val { - Some(f) => plugin.memory.alloc_bytes(f)?, + let mem = match ptr { + Some((len, ptr)) => { + let bytes = unsafe { std::slice::from_raw_parts(ptr, len) }; + data.memory_mut().alloc_bytes(bytes)? + } None => { output[0] = Val::I64(0); return Ok(()); @@ -251,10 +256,9 @@ pub(crate) fn var_set( _output: &mut [Val], ) -> Result<(), Error> { let data: &mut Internal = caller.data_mut(); - let plugin = data.plugin_mut(); let mut size = 0; - for v in plugin.vars.values() { + for v in data.vars.values() { size += v.len(); } @@ -266,18 +270,23 @@ pub(crate) fn var_set( } let key_offs = args!(input, 0, i64) as usize; - let key = plugin.memory.get_str(key_offs)?; + let key = { + let key = data.memory().get_str(key_offs)?; + let key_len = key.len(); + let key_ptr = key.as_ptr(); + unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(key_ptr, key_len)) } + }; // Remove if the value offset is 0 if voffset == 0 { - plugin.vars.remove(key); + data.vars.remove(key); return Ok(()); } - let value = plugin.memory.get(voffset)?; + let value = data.memory().get(voffset)?; // Insert the value from memory into the `vars` map - plugin.vars.insert(key.to_string(), value.to_vec()); + data.vars.insert(key.to_string(), value.to_vec()); Ok(()) } @@ -314,7 +323,7 @@ pub(crate) fn http_request( Ok(u) => u, Err(e) => return Err(Error::msg(format!("Invalid URL: {e:?}"))), }; - let allowed_hosts = &data.plugin().manifest.as_ref().allowed_hosts; + let allowed_hosts = &data.memory().manifest.as_ref().allowed_hosts; let host_str = url.host_str().unwrap_or_default(); let host_matches = if let Some(allowed_hosts) = allowed_hosts { allowed_hosts.iter().any(|url| { diff --git a/runtime/src/plugin.rs b/runtime/src/plugin.rs index 2e34e8a..f631e61 100644 --- a/runtime/src/plugin.rs +++ b/runtime/src/plugin.rs @@ -4,92 +4,41 @@ use crate::*; /// Plugin contains everything needed to execute a WASM function pub struct Plugin { - pub module: Module, + /// All modules that were provided to the linker + pub modules: BTreeMap, + + /// Used to define functions and create new instances pub linker: Linker, + + /// Instance provides the ability to call functions in a module pub instance: Instance, - pub last_error: std::cell::RefCell>, - pub memory: PluginMemory, - pub manifest: Manifest, - pub vars: BTreeMap>, + + /// Keep track of the number of times we're instantiated, this exists + /// to avoid issues with memory piling up since `Instance`s are only + /// actually cleaned up along with a `Store` + pub instantiations: usize, + + /// Handles interactions with WASM memory + pub memory: std::cell::UnsafeCell, + + /// The ID used to identify this plugin with the `Timer` pub timer_id: uuid::Uuid, + + /// A handle used to cancel execution of a plugin pub(crate) cancel_handle: sdk::ExtismCancelHandle, + + /// Runtime determines any initialization and cleanup functions needed + /// to run a module pub(crate) runtime: Option, } -pub struct Internal { - pub input_length: usize, - pub input: *const u8, - pub output_offset: usize, - pub output_length: usize, - pub plugin: *mut Plugin, - pub wasi: Option, - pub http_status: u16, -} - -pub struct Wasi { - pub ctx: wasmtime_wasi::WasiCtx, - #[cfg(feature = "nn")] - pub nn: wasmtime_wasi_nn::WasiNnCtx, - #[cfg(not(feature = "nn"))] - pub nn: (), -} - -impl Internal { - fn new(manifest: &Manifest, wasi: bool) -> Result { - let wasi = if wasi { - let auth = wasmtime_wasi::ambient_authority(); - let mut ctx = wasmtime_wasi::WasiCtxBuilder::new(); - for (k, v) in manifest.as_ref().config.iter() { - ctx = ctx.env(k, v)?; - } - - if let Some(a) = &manifest.as_ref().allowed_paths { - for (k, v) in a.iter() { - let d = wasmtime_wasi::Dir::open_ambient_dir(k, auth)?; - ctx = ctx.preopened_dir(d, v)?; - } - } - - #[cfg(feature = "nn")] - let nn = wasmtime_wasi_nn::WasiNnCtx::new()?; - - #[cfg(not(feature = "nn"))] - #[allow(clippy::let_unit_value)] - let nn = (); - - Some(Wasi { - ctx: ctx.build(), - nn, - }) - } else { - None - }; - - Ok(Internal { - input_length: 0, - output_offset: 0, - output_length: 0, - input: std::ptr::null(), - wasi, - plugin: std::ptr::null_mut(), - http_status: 0, - }) +impl InternalExt for Plugin { + fn memory(&self) -> &PluginMemory { + unsafe { &*self.memory.get() } } - pub fn plugin(&self) -> &Plugin { - unsafe { &*self.plugin } - } - - pub fn plugin_mut(&mut self) -> &mut Plugin { - unsafe { &mut *self.plugin } - } - - pub fn memory(&self) -> &PluginMemory { - &self.plugin().memory - } - - pub fn memory_mut(&mut self) -> &mut PluginMemory { - &mut self.plugin_mut().memory + fn memory_mut(&mut self) -> &mut PluginMemory { + self.memory.get_mut() } } @@ -102,6 +51,8 @@ impl Plugin { imports: impl IntoIterator, with_wasi: bool, ) -> Result { + // Create a new engine, if the `EXITSM_DEBUG` environment variable is set + // then we enable debug info let engine = Engine::new( Config::new() .epoch_interruption(true) @@ -110,18 +61,19 @@ impl Plugin { let mut imports = imports.into_iter(); let (manifest, modules) = Manifest::new(&engine, wasm.as_ref())?; let mut store = Store::new(&engine, Internal::new(&manifest, with_wasi)?); - store.epoch_deadline_callback(|_internal| Err(Error::msg("timeout"))); + // Create memory let memory = Memory::new( &mut store, MemoryType::new(4, manifest.as_ref().memory.max_pages), )?; - let mut memory = PluginMemory::new(store, memory); + let mut memory = PluginMemory::new(store, memory, manifest); let mut linker = Linker::new(&engine); linker.allow_shadowing(true); + // If wasi is enabled then add it to the linker if with_wasi { wasmtime_wasi::add_to_linker(&mut linker, |x: &mut Internal| { &mut x.wasi.as_mut().unwrap().ctx @@ -138,14 +90,14 @@ impl Plugin { (entry.0.as_str(), entry.1) }); + // Define PDK functions macro_rules! define_funcs { ($m:expr, { $($name:ident($($args:expr),*) $(-> $($r:expr),*)?);* $(;)?}) => { match $m { $( concat!("extism_", stringify!($name)) => { let t = FuncType::new([$($args),*], [$($($r),*)?]); - let f = Func::new(&mut memory.store, t, pdk::$name); - linker.define(&mut memory.store, EXPORT_MODULE_NAME, concat!("extism_", stringify!($name)), Extern::Func(f))?; + linker.func_new(EXPORT_MODULE_NAME, concat!("extism_", stringify!($name)), t, pdk::$name)?; continue } )* @@ -189,10 +141,9 @@ impl Plugin { for f in &mut imports { let name = f.name().to_string(); let ns = f.namespace().unwrap_or(EXPORT_MODULE_NAME); - let func = Func::new(&mut memory.store, f.ty().clone(), unsafe { + linker.func_new(ns, &name, f.ty().clone(), unsafe { &*std::sync::Arc::as_ptr(&f.f) - }); - linker.define(&mut memory.store, ns, &name, func)?; + })?; } } } @@ -201,21 +152,18 @@ impl Plugin { // Add modules to linker for (name, module) in modules.iter() { if name != main_name { - linker.module(&mut memory.store, name, module)?; - linker.alias_module(name, "env")?; + linker.module(&mut memory.store_mut(), name, module)?; } } - let instance = linker.instantiate(&mut memory.store, main)?; + let instance = linker.instantiate(&mut memory.store_mut(), main)?; let timer_id = uuid::Uuid::new_v4(); let mut plugin = Plugin { - module: main.clone(), + modules, linker, - memory, + memory: std::cell::UnsafeCell::new(memory), instance, - last_error: std::cell::RefCell::new(None), - manifest, - vars: BTreeMap::new(), + instantiations: 1, runtime: None, timer_id, cancel_handle: sdk::ExtismCancelHandle { @@ -223,6 +171,11 @@ impl Plugin { epoch_timer_tx: None, }, }; + + // Make sure `Internal::memory` is initialized + plugin.internal_mut().memory = plugin.memory.get(); + + // Then detect runtime before returning the new plugin plugin.detect_runtime(); Ok(plugin) } @@ -230,52 +183,70 @@ impl Plugin { /// Get a function by name pub fn get_func(&mut self, function: impl AsRef) -> Option { self.instance - .get_func(&mut self.memory.store, function.as_ref()) - } - - /// Set `last_error` field - pub fn set_error(&self, e: impl std::fmt::Debug) { - debug!("Set error: {:?}", e); - *self.last_error.borrow_mut() = Some(error_string(e)); + .get_func(&mut self.memory.get_mut().store_mut(), function.as_ref()) } + // A convenience method to set the plugin error and return a value pub fn error(&self, e: impl std::fmt::Debug, x: E) -> E { - self.set_error(e); + self.store().data().set_error(e); x } - /// Unset `last_error` field - pub fn clear_error(&self) { - *self.last_error.borrow_mut() = None; - } - /// Store input in memory and initialize `Internal` pointer pub fn set_input(&mut self, input: *const u8, mut len: usize) { if input.is_null() { len = 0; } - let ptr = self as *mut _; - let internal = self.memory.store.data_mut(); + let ptr = self.memory.get(); + let internal = self.internal_mut(); internal.input = input; internal.input_length = len; - internal.plugin = ptr; + internal.memory = ptr } + /// Dump memory using trace! logging pub fn dump_memory(&self) { - self.memory.dump(); + self.memory().dump(); } + /// Create a new instance from the same modules pub fn reinstantiate(&mut self) -> Result<(), Error> { + let (main_name, main) = self + .modules + .get("main") + .map(|x| ("main", x)) + .unwrap_or_else(|| { + let entry = self.modules.iter().last().unwrap(); + (entry.0.as_str(), entry.1) + }); + + // Avoid running into resource limits, after 5 instantiations reset the store. This will + // release any old `Instance` objects + if self.instantiations > 5 { + self.memory.get_mut().reinstantiate()?; + + // Get the `main` module, or the last one if `main` doesn't exist + for (name, module) in self.modules.iter() { + if name != main_name { + self.linker + .module(&mut self.memory.get_mut().store_mut(), name, module)?; + } + } + self.instantiations = 0; + } + let instance = self .linker - .instantiate(&mut self.memory.store, &self.module)?; + .instantiate(&mut self.memory.get_mut().store_mut(), &main)?; self.instance = instance; self.detect_runtime(); + self.instantiations += 1; Ok(()) } + /// Determine if wasi is enabled pub fn has_wasi(&self) -> bool { - self.memory.store.data().wasi.is_some() + self.memory().store().data().wasi.is_some() } fn detect_runtime(&mut self) { @@ -284,10 +255,13 @@ impl Plugin { // by calling the `hs_init` export if let Some(init) = self.get_func("hs_init") { if let Some(cleanup) = self.get_func("hs_exit") { - if init.typed::<(i32, i32), ()>(&self.memory.store).is_err() { + if init + .typed::<(i32, i32), ()>(&self.memory().store()) + .is_err() + { trace!( "hs_init function found with type {:?}", - init.ty(&self.memory.store) + init.ty(&self.memory().store()) ); } self.runtime = Some(Runtime::Haskell { init, cleanup }); @@ -299,19 +273,19 @@ impl Plugin { // initialize certain interfaces. if self.has_wasi() { if let Some(init) = self.get_func("__wasm_call_ctors") { - if init.typed::<(), ()>(&self.memory.store).is_err() { + if init.typed::<(), ()>(&self.memory().store()).is_err() { trace!( "__wasm_call_ctors function found with type {:?}", - init.ty(&self.memory.store) + init.ty(&self.memory().store()) ); return; } trace!("WASI runtime detected"); if let Some(cleanup) = self.get_func("__wasm_call_dtors") { - if cleanup.typed::<(), ()>(&self.memory.store).is_err() { + if cleanup.typed::<(), ()>(&self.memory().store()).is_err() { trace!( "__wasm_call_dtors function found with type {:?}", - cleanup.ty(&self.memory.store) + cleanup.ty(&self.memory().store()) ); return; } @@ -339,9 +313,9 @@ impl Plugin { match runtime { Runtime::Haskell { init, cleanup: _ } => { let mut results = - vec![Val::null(); init.ty(&self.memory.store).results().len()]; + vec![Val::null(); init.ty(&self.memory().store()).results().len()]; init.call( - &mut self.memory.store, + &mut self.memory.get_mut().store_mut(), &[Val::I32(0), Val::I32(0)], results.as_mut_slice(), )?; @@ -349,7 +323,7 @@ impl Plugin { } Runtime::Wasi { init, cleanup: _ } => { debug!("Calling __wasm_call_ctors"); - init.call(&mut self.memory.store, &[], &mut [])?; + init.call(&mut self.memory.get_mut().store_mut(), &[], &mut [])?; } } } @@ -367,7 +341,7 @@ impl Plugin { cleanup: Some(cleanup), } => { debug!("Calling __wasm_call_dtors"); - cleanup.call(&mut self.memory.store, &[], &mut [])?; + cleanup.call(&mut self.memory_mut().store_mut(), &[], &mut [])?; } Runtime::Wasi { init: _, @@ -377,8 +351,12 @@ impl Plugin { // by calling the `hs_exit` export Runtime::Haskell { init: _, cleanup } => { let mut results = - vec![Val::null(); cleanup.ty(&self.memory.store).results().len()]; - cleanup.call(&mut self.memory.store, &[], results.as_mut_slice())?; + vec![Val::null(); cleanup.ty(&self.memory().store()).results().len()]; + cleanup.call( + &mut self.memory_mut().store_mut(), + &[], + results.as_mut_slice(), + )?; debug!("Cleaned up Haskell language runtime"); } } @@ -387,18 +365,21 @@ impl Plugin { Ok(()) } + /// Start the timer for a Plugin - this is used for both timeouts + /// and cancellation pub(crate) fn start_timer( &mut self, tx: &std::sync::mpsc::SyncSender, ) -> Result<(), Error> { let duration = self + .memory() .manifest .as_ref() .timeout_ms .map(std::time::Duration::from_millis); self.cancel_handle.epoch_timer_tx = Some(tx.clone()); - self.memory.store.set_epoch_deadline(1); - let engine: Engine = self.memory.store.engine().clone(); + self.memory_mut().store_mut().set_epoch_deadline(1); + let engine: Engine = self.memory().store().engine().clone(); tx.send(TimerAction::Start { id: self.timer_id, duration, @@ -407,20 +388,13 @@ impl Plugin { Ok(()) } + /// Send TimerAction::Stop pub(crate) fn stop_timer(&mut self) -> Result<(), Error> { if let Some(tx) = &self.cancel_handle.epoch_timer_tx { tx.send(TimerAction::Stop { id: self.timer_id })?; } Ok(()) } - - pub fn cancel(&self) -> Result<(), Error> { - if let Some(tx) = &self.cancel_handle.epoch_timer_tx { - tx.send(TimerAction::Cancel { id: self.timer_id })?; - } - - Ok(()) - } } // Enumerates the supported PDK language runtimes diff --git a/runtime/src/plugin_ref.rs b/runtime/src/plugin_ref.rs index a13bae2..9d6478d 100644 --- a/runtime/src/plugin_ref.rs +++ b/runtime/src/plugin_ref.rs @@ -15,9 +15,18 @@ impl<'a> PluginRef<'a> { /// - Updates `input` pointer pub fn init(mut self, data: *const u8, data_len: usize) -> Self { trace!("PluginRef::init: {}", self.id,); - self.as_mut().memory.reset(); - self.as_mut().set_input(data, data_len); - + let plugin = self.as_mut(); + plugin.memory_mut().reset(); + if plugin.has_wasi() || plugin.runtime.is_some() { + if let Err(e) = plugin.reinstantiate() { + error!("Failed to reinstantiate: {e:?}"); + plugin + .internal() + .set_error(format!("Failed to reinstantiate: {e:?}")); + return self; + } + } + plugin.set_input(data, data_len); self } @@ -29,31 +38,19 @@ impl<'a> PluginRef<'a> { let epoch_timer_tx = ctx.epoch_timer_tx.clone(); - if !ctx.plugin_exists(plugin_id) { + let plugin = if let Some(plugin) = ctx.plugin(plugin_id) { + plugin + } else { error!("Plugin does not exist: {plugin_id}"); return ctx.error(format!("Plugin does not exist: {plugin_id}"), None); - } + }; if clear_error { trace!("Clearing context error"); ctx.error = None; - } - - // `unwrap` is okay here because we already checked with `ctx.plugin_exists` above - let plugin = ctx.plugin(plugin_id).unwrap(); - - { - let plugin = unsafe { &mut *plugin }; - if clear_error { - trace!("Clearing plugin error: {plugin_id}"); - plugin.clear_error(); - } - - if plugin.has_wasi() || plugin.runtime.is_some() { - if let Err(e) = plugin.reinstantiate() { - error!("Failed to reinstantiate: {e:?}"); - return plugin.error(format!("Failed to reinstantiate: {e:?}"), None); - } + trace!("Clearing plugin error: {plugin_id}"); + unsafe { + (&*plugin).internal().clear_error(); } } diff --git a/runtime/src/sdk.rs b/runtime/src/sdk.rs index 031fc54..7c27950 100644 --- a/runtime/src/sdk.rs +++ b/runtime/src/sdk.rs @@ -33,7 +33,7 @@ impl From for ExtismFunction { /// Host function signature pub type ExtismFunctionType = extern "C" fn( - plugin: *mut Plugin, + plugin: *mut Internal, inputs: *const ExtismVal, n_inputs: Size, outputs: *mut ExtismVal, @@ -93,28 +93,31 @@ pub unsafe extern "C" fn extism_context_free(ctx: *mut Context) { /// Returns a pointer to the memory of the currently running plugin /// NOTE: this should only be called from host functions. #[no_mangle] -pub unsafe extern "C" fn extism_current_plugin_memory(plugin: *mut Plugin) -> *mut u8 { +pub unsafe extern "C" fn extism_current_plugin_memory(plugin: *mut Internal) -> *mut u8 { if plugin.is_null() { return std::ptr::null_mut(); } let plugin = &mut *plugin; - plugin.memory.data_mut().as_mut_ptr() + plugin.memory_mut().data_mut().as_mut_ptr() } /// Allocate a memory block in the currently running plugin /// NOTE: this should only be called from host functions. #[no_mangle] -pub unsafe extern "C" fn extism_current_plugin_memory_alloc(plugin: *mut Plugin, n: Size) -> u64 { +pub unsafe extern "C" fn extism_current_plugin_memory_alloc(plugin: *mut Internal, n: Size) -> u64 { if plugin.is_null() { return 0; } let plugin = &mut *plugin; - let mem = match plugin.memory.alloc(n as usize) { + let mem = match plugin.memory_mut().alloc(n as usize) { Ok(x) => x, - Err(e) => return plugin.error(e, 0), + Err(e) => { + plugin.set_error(e); + return 0; + } }; mem.offset as u64 @@ -123,14 +126,17 @@ pub unsafe extern "C" fn extism_current_plugin_memory_alloc(plugin: *mut Plugin, /// Get the length of an allocated block /// NOTE: this should only be called from host functions. #[no_mangle] -pub unsafe extern "C" fn extism_current_plugin_memory_length(plugin: *mut Plugin, n: Size) -> Size { +pub unsafe extern "C" fn extism_current_plugin_memory_length( + plugin: *mut Internal, + n: Size, +) -> Size { if plugin.is_null() { return 0; } let plugin = &mut *plugin; - match plugin.memory.block_length(n as usize) { + match plugin.memory().block_length(n as usize) { Some(x) => x as Size, None => 0, } @@ -139,14 +145,13 @@ pub unsafe extern "C" fn extism_current_plugin_memory_length(plugin: *mut Plugin /// Free an allocated memory block /// NOTE: this should only be called from host functions. #[no_mangle] -pub unsafe extern "C" fn extism_current_plugin_memory_free(plugin: *mut Plugin, ptr: u64) { +pub unsafe extern "C" fn extism_current_plugin_memory_free(plugin: *mut Internal, ptr: u64) { if plugin.is_null() { return; } let plugin = &mut *plugin; - - plugin.memory.free(ptr as usize); + plugin.memory_mut().free(ptr as usize); } /// Create a new host function @@ -339,8 +344,7 @@ pub unsafe extern "C" fn extism_plugin_update( return false; } - ctx.plugins - .insert(index, std::cell::UnsafeCell::new(plugin)); + ctx.plugins.insert(index, plugin); debug!("Plugin updated: {index}"); true @@ -412,44 +416,49 @@ pub unsafe extern "C" fn extism_plugin_config( json_size: Size, ) -> bool { let ctx = &mut *ctx; - let mut plugin = match PluginRef::new(ctx, plugin, true) { + let mut plugin_ref = match PluginRef::new(ctx, plugin, true) { None => return false, Some(p) => p, }; - trace!( "Call to extism_plugin_config for {} with json pointer {:?}", - plugin.id, + plugin_ref.id, json ); + let plugin = plugin_ref.as_mut(); let data = std::slice::from_raw_parts(json, json_size as usize); let json: std::collections::BTreeMap> = match serde_json::from_slice(data) { Ok(x) => x, Err(e) => { - return plugin.as_mut().error(e, false); + return plugin.error(e, false); } }; - let plugin = plugin.as_mut(); + let wasi = &mut plugin.memory.get_mut().store_mut().data_mut().wasi; + if let Some(Wasi { ctx, .. }) = wasi { + for (k, v) in json.iter() { + match v { + Some(v) => { + let _ = ctx.push_env(&k, &v); + } + None => { + let _ = ctx.push_env(&k, ""); + } + } + } + } - let wasi = &mut plugin.memory.store.data_mut().wasi; - let config = &mut plugin.manifest.as_mut().config; + let config = &mut plugin.memory.get_mut().manifest.as_mut().config; for (k, v) in json.into_iter() { match v { Some(v) => { trace!("Config, adding {k}"); - if let Some(Wasi { ctx, .. }) = wasi { - let _ = ctx.push_env(&k, &v); - } config.insert(k, v); } None => { trace!("Config, removing {k}"); - if let Some(Wasi { ctx, .. }) = wasi { - let _ = ctx.push_env(&k, ""); - } config.remove(&k); } } @@ -505,38 +514,39 @@ pub unsafe extern "C" fn extism_plugin_call( None => return -1, Some(p) => p.init(data, data_len as usize), }; + let tx = plugin_ref.epoch_timer_tx.clone(); + let plugin = plugin_ref.as_mut(); + + if plugin.internal().last_error.borrow().is_some() { + return -1; + } // Find function let name = std::ffi::CStr::from_ptr(func_name); let name = match name.to_str() { Ok(name) => name, - Err(e) => return plugin_ref.as_ref().error(e, -1), + Err(e) => return plugin.error(e, -1), }; let is_start = name == "_start"; - let func = match plugin_ref.as_mut().get_func(name) { + let func = match plugin.get_func(name) { Some(x) => x, - None => { - return plugin_ref - .as_ref() - .error(format!("Function not found: {name}"), -1) - } + None => return plugin.error(format!("Function not found: {name}"), -1), }; // Start timer - let tx = plugin_ref.epoch_timer_tx.clone(); - if let Err(e) = plugin_ref.as_mut().start_timer(&tx) { - let id = plugin_ref.as_ref().timer_id; - return plugin_ref.as_ref().error( + if let Err(e) = plugin.start_timer(&tx) { + let id = plugin.timer_id; + return plugin.error( format!("Unable to start timeout manager for {id}: {e:?}"), -1, ); } // Check the number of results, reject functions with more than 1 result - let n_results = func.ty(&plugin_ref.as_ref().memory.store).results().len(); + let n_results = func.ty(plugin.store()).results().len(); if n_results > 1 { - return plugin_ref.as_ref().error( + return plugin.error( format!("Function {name} has {n_results} results, expected 0 or 1"), -1, ); @@ -544,10 +554,8 @@ pub unsafe extern "C" fn extism_plugin_call( // Initialize runtime if !is_start { - if let Err(e) = plugin_ref.as_mut().initialize_runtime() { - return plugin_ref - .as_ref() - .error(format!("Failed to initialize runtime: {e:?}"), -1); + if let Err(e) = plugin.initialize_runtime() { + return plugin.error(format!("Failed to initialize runtime: {e:?}"), -1); } } @@ -555,27 +563,21 @@ pub unsafe extern "C" fn extism_plugin_call( // Call the function let mut results = vec![wasmtime::Val::null(); n_results]; - let res = func.call( - &mut plugin_ref.as_mut().memory.store, - &[], - results.as_mut_slice(), - ); + let res = func.call(&mut plugin.store_mut(), &[], results.as_mut_slice()); - plugin_ref.as_ref().dump_memory(); + plugin.dump_memory(); // Cleanup runtime if !is_start { - if let Err(e) = plugin_ref.as_mut().cleanup_runtime() { - return plugin_ref - .as_ref() - .error(format!("Failed to cleanup runtime: {e:?}"), -1); + if let Err(e) = plugin.cleanup_runtime() { + return plugin.error(format!("Failed to cleanup runtime: {e:?}"), -1); } } // Stop timer - if let Err(e) = plugin_ref.as_mut().stop_timer() { - let id = plugin_ref.as_ref().timer_id; - return plugin_ref.as_ref().error( + if let Err(e) = plugin.stop_timer() { + let id = plugin.timer_id; + return plugin.error( format!("Failed to stop timeout manager for {id}: {e:?}"), -1, ); @@ -584,7 +586,6 @@ pub unsafe extern "C" fn extism_plugin_call( match res { Ok(()) => (), Err(e) => { - let plugin = plugin_ref.as_ref(); if let Some(exit) = e.downcast_ref::() { trace!("WASI return code: {}", exit.0); if exit.0 != 0 { @@ -634,12 +635,13 @@ pub unsafe extern "C" fn extism_error(ctx: *mut Context, plugin: PluginIndex) -> return get_context_error(ctx); } - let plugin = match PluginRef::new(ctx, plugin, false) { + let plugin_ref = match PluginRef::new(ctx, plugin, false) { None => return std::ptr::null(), Some(p) => p, }; + let plugin = plugin_ref.as_ref(); - let err = plugin.as_ref().last_error.borrow(); + let err = plugin.internal().last_error.borrow(); match err.as_ref() { Some(e) => e.as_ptr() as *const _, None => { @@ -658,12 +660,13 @@ pub unsafe extern "C" fn extism_plugin_output_length( trace!("Call to extism_plugin_output_length for plugin {plugin}"); let ctx = &mut *ctx; - let plugin = match PluginRef::new(ctx, plugin, true) { + let plugin_ref = match PluginRef::new(ctx, plugin, true) { None => return 0, Some(p) => p, }; + let plugin = plugin_ref.as_ref(); - let len = plugin.as_ref().memory.store.data().output_length as Size; + let len = plugin.internal().output_length as Size; trace!("Output length: {len}"); len } @@ -677,16 +680,18 @@ pub unsafe extern "C" fn extism_plugin_output_data( trace!("Call to extism_plugin_output_data for plugin {plugin}"); let ctx = &mut *ctx; - let plugin = match PluginRef::new(ctx, plugin, true) { + let plugin_ref = match PluginRef::new(ctx, plugin, true) { None => return std::ptr::null(), Some(p) => p, }; - let data = plugin.as_ref().memory.store.data(); - + let plugin = plugin_ref.as_ref(); + let internal = plugin.internal(); plugin - .as_ref() - .memory - .ptr(MemoryBlock::new(data.output_offset, data.output_length)) + .memory() + .ptr(MemoryBlock::new( + internal.output_offset, + internal.output_length, + )) .map(|x| x as *const _) .unwrap_or(std::ptr::null()) } diff --git a/rust/src/context.rs b/rust/src/context.rs index e1cd2a7..4198ccc 100644 --- a/rust/src/context.rs +++ b/rust/src/context.rs @@ -1,5 +1,6 @@ use crate::*; +#[derive(Clone)] pub struct Context(pub(crate) std::sync::Arc>); impl Default for Context { @@ -8,6 +9,9 @@ impl Default for Context { } } +unsafe impl Sync for Context {} +unsafe impl Send for Context {} + impl Context { /// Create a new context pub fn new() -> Context { diff --git a/rust/src/lib.rs b/rust/src/lib.rs index c4d561b..fb65f90 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,6 +1,6 @@ pub use extism_manifest::{self as manifest, Manifest}; pub use extism_runtime::{ - sdk as bindings, Function, MemoryBlock, Plugin as CurrentPlugin, UserData, Val, ValType, + sdk as bindings, Function, Internal as CurrentPlugin, MemoryBlock, UserData, Val, ValType, }; mod context; @@ -86,8 +86,7 @@ mod tests { ) .with_namespace("test"); - let functions = [&f, &g]; - let mut plugin = Plugin::new(&context, WASM, functions, true).unwrap(); + let mut plugin = Plugin::new(&context, WASM, [f, g], true).unwrap(); println!("register loaded plugin: {:?}", wasm_start.elapsed()); let repeat = 1182; @@ -175,7 +174,7 @@ mod tests { } #[test] - fn test_threads() { + fn test_context_threads() { use std::io::Write; std::thread::spawn(|| { let context = Context::new(); @@ -186,7 +185,7 @@ mod tests { None, hello_world, ); - let mut plugin = Plugin::new(&context, WASM, [&f], true).unwrap(); + let mut plugin = Plugin::new(&context, WASM, [f], true).unwrap(); let output = plugin.call("count_vowels", "this is a test").unwrap(); std::io::stdout().write_all(output).unwrap(); }); @@ -199,22 +198,63 @@ mod tests { hello_world, ); - let g = f.clone(); - std::thread::spawn(move || { - let context = Context::new(); - let mut plugin = PluginBuilder::new_with_module(WASM) - .with_function(&g) - .with_wasi(true) - .build(&context) - .unwrap(); - let output = plugin.call("count_vowels", "this is a test aaa").unwrap(); - std::io::stdout().write_all(output).unwrap(); - }); - + // One context shared between two threads let context = Context::new(); - let mut plugin = Plugin::new(&context, WASM, [&f], true).unwrap(); - let output = plugin.call("count_vowels", "abc123").unwrap(); - std::io::stdout().write_all(output).unwrap(); + let mut threads = vec![]; + for _ in 0..3 { + let ctx = context.clone(); + let g = f.clone(); + let a = std::thread::spawn(move || { + let mut plugin = PluginBuilder::new_with_module(WASM) + .with_function(g) + .with_wasi(true) + .build(Some(&ctx)) + .unwrap(); + for _ in 0..10 { + let output = plugin.call("count_vowels", "this is a test aaa").unwrap(); + assert_eq!(b"{\"count\": 7}", output); + } + }); + threads.push(a); + } + for thread in threads { + thread.join().unwrap(); + } + } + + #[test] + fn test_plugin_threads() { + let f = Function::new( + "hello_world", + [ValType::I64], + [ValType::I64], + None, + hello_world, + ); + + let p = std::sync::Arc::new(std::sync::Mutex::new( + PluginBuilder::new_with_module(WASM) + .with_function(f) + .with_wasi(true) + .build(None) + .unwrap(), + )); + + let mut threads = vec![]; + for _ in 0..3 { + let plugin = p.clone(); + let a = std::thread::spawn(move || { + let mut plugin = plugin.lock().unwrap(); + for _ in 0..10 { + let output = plugin.call("count_vowels", "this is a test aaa").unwrap(); + assert_eq!(b"{\"count\": 7}", output); + } + }); + threads.push(a); + } + for thread in threads { + thread.join().unwrap(); + } } #[test] @@ -228,7 +268,7 @@ mod tests { ); let context = Context::new(); - let mut plugin = Plugin::new(&context, WASM_LOOP, [&f], true).unwrap(); + let mut plugin = Plugin::new(&context, WASM_LOOP, [f], true).unwrap(); let handle = plugin.cancel_handle(); std::thread::spawn(move || { diff --git a/rust/src/plugin.rs b/rust/src/plugin.rs index f758fdf..220a3f2 100644 --- a/rust/src/plugin.rs +++ b/rust/src/plugin.rs @@ -1,9 +1,36 @@ use crate::*; use std::collections::BTreeMap; +enum RefOrOwned<'a, T> { + Ref(&'a T), + Owned(T), +} + pub struct Plugin<'a> { id: extism_runtime::PluginIndex, - context: &'a Context, + context: RefOrOwned<'a, Context>, + functions: Vec, +} + +impl<'a, T> From<&'a T> for RefOrOwned<'a, T> { + fn from(value: &'a T) -> Self { + RefOrOwned::Ref(value) + } +} + +impl<'a, T> From for RefOrOwned<'a, T> { + fn from(value: T) -> Self { + RefOrOwned::Owned(value) + } +} + +impl<'a, T> AsRef for RefOrOwned<'a, T> { + fn as_ref(&self) -> &T { + match self { + RefOrOwned::Ref(x) => x, + RefOrOwned::Owned(x) => x, + } + } } pub struct CancelHandle(pub(crate) *const extism_runtime::sdk::ExtismCancelHandle); @@ -23,32 +50,44 @@ impl<'a> Plugin<'a> { /// # Safety /// This function does not check to ensure the provided ID is valid pub unsafe fn from_id(id: i32, context: &'a Context) -> Plugin<'a> { - Plugin { id, context } + let context = RefOrOwned::Ref(context); + Plugin { + id, + context, + functions: vec![], + } + } + + pub fn context(&self) -> &Context { + match &self.context { + RefOrOwned::Ref(x) => x, + RefOrOwned::Owned(x) => x, + } } pub fn as_i32(&self) -> i32 { self.id } - /// Create a new plugin from the given manifest - pub fn new_with_manifest( - ctx: &'a Context, + /// Create a new plugin from the given manifest in its own context + pub fn create_with_manifest( manifest: &Manifest, - functions: impl IntoIterator, + functions: impl IntoIterator, wasi: bool, ) -> Result, Error> { let data = serde_json::to_vec(manifest)?; - Self::new(ctx, data, functions, wasi) + Self::create(data, functions, wasi) } - /// Create a new plugin from a WASM module - pub fn new( - ctx: &'a Context, + /// Create a new plugin from a WASM module in its own context + pub fn create( data: impl AsRef<[u8]>, - functions: impl IntoIterator, + functions: impl IntoIterator, wasi: bool, - ) -> Result { - let plugin = ctx.lock().new_plugin(data, functions, wasi); + ) -> Result, Error> { + let ctx = Context::new(); + let functions = functions.into_iter().collect(); + let plugin = ctx.lock().new_plugin(data, &functions, wasi); if plugin < 0 { let err = unsafe { bindings::extism_error(&mut *ctx.lock(), -1) }; @@ -59,7 +98,43 @@ impl<'a> Plugin<'a> { Ok(Plugin { id: plugin, - context: ctx, + context: ctx.into(), + functions, + }) + } + + /// Create a new plugin from the given manifest + pub fn new_with_manifest( + ctx: &'a Context, + manifest: &Manifest, + functions: impl IntoIterator, + wasi: bool, + ) -> Result, Error> { + let data = serde_json::to_vec(manifest)?; + Self::new(ctx, data, functions, wasi) + } + + /// Create a new plugin from a WASM module + pub fn new( + ctx: &'a Context, + data: impl AsRef<[u8]>, + functions: impl IntoIterator, + wasi: bool, + ) -> Result, Error> { + let functions = functions.into_iter().collect(); + let plugin = ctx.lock().new_plugin(data, &functions, wasi); + + if plugin < 0 { + let err = unsafe { bindings::extism_error(&mut *ctx.lock(), -1) }; + let buf = unsafe { std::ffi::CStr::from_ptr(err) }; + let buf = buf.to_str().unwrap(); + return Err(Error::msg(buf)); + } + + Ok(Plugin { + id: plugin, + context: ctx.into(), + functions, }) } @@ -67,7 +142,7 @@ impl<'a> Plugin<'a> { pub fn update_with_manifest( &mut self, manifest: &Manifest, - functions: impl IntoIterator, + functions: impl IntoIterator, wasi: bool, ) -> Result<(), Error> { let data = serde_json::to_vec(manifest)?; @@ -78,11 +153,13 @@ impl<'a> Plugin<'a> { pub fn update( &mut self, data: impl AsRef<[u8]>, - functions: impl IntoIterator, + functions: impl IntoIterator, wasi: bool, ) -> Result<(), Error> { - let functions = functions - .into_iter() + self.functions = functions.into_iter().collect(); + let functions = self + .functions + .iter() .map(|x| bindings::ExtismFunction::from(x.clone())); let mut functions = functions .into_iter() @@ -90,7 +167,7 @@ impl<'a> Plugin<'a> { .collect::>(); let b = unsafe { bindings::extism_plugin_update( - &mut *self.context.lock(), + &mut *self.context.as_ref().lock(), self.id, data.as_ref().as_ptr(), data.as_ref().len() as u64, @@ -103,7 +180,7 @@ impl<'a> Plugin<'a> { return Ok(()); } - let err = unsafe { bindings::extism_error(&mut *self.context.lock(), -1) }; + let err = unsafe { bindings::extism_error(&mut *self.context.as_ref().lock(), -1) }; if !err.is_null() { let s = unsafe { std::ffi::CStr::from_ptr(err) }; return Err(Error::msg(s.to_str().unwrap())); @@ -117,7 +194,7 @@ impl<'a> Plugin<'a> { let encoded = serde_json::to_vec(config)?; unsafe { bindings::extism_plugin_config( - &mut *self.context.lock(), + &mut *self.context.as_ref().lock(), self.id, encoded.as_ptr() as *const _, encoded.len() as u64, @@ -137,7 +214,7 @@ impl<'a> Plugin<'a> { let name = std::ffi::CString::new(name.as_ref()).expect("Invalid function name"); unsafe { bindings::extism_plugin_function_exists( - &mut *self.context.lock(), + &mut *self.context.as_ref().lock(), self.id, name.as_ptr() as *const _, ) @@ -145,18 +222,27 @@ impl<'a> Plugin<'a> { } pub fn cancel_handle(&self) -> CancelHandle { - let ptr = - unsafe { bindings::extism_plugin_cancel_handle(&mut *self.context.lock(), self.id) }; + let ptr = unsafe { + bindings::extism_plugin_cancel_handle(&mut *self.context.as_ref().lock(), self.id) + }; CancelHandle(ptr) } - /// Call a function with the given input - pub fn call(&mut self, name: impl AsRef, input: impl AsRef<[u8]>) -> Result<&[u8], Error> { + /// Call a function with the given input and call a callback with the output, this should be preferred when + /// a single plugin may be acessed from multiple threads because the lock on the plugin is held during the + /// callback, ensuring the output value is protected from modification. + pub fn call_map Result>( + &mut self, + name: impl AsRef, + input: impl AsRef<[u8]>, + f: F, + ) -> Result { + let context = &mut *self.context.as_ref().lock(); let name = std::ffi::CString::new(name.as_ref()).expect("Invalid function name"); let rc = unsafe { bindings::extism_plugin_call( - &mut *self.context.lock(), + context, self.id, name.as_ptr() as *const _, input.as_ref().as_ptr() as *const _, @@ -165,7 +251,7 @@ impl<'a> Plugin<'a> { }; if rc != 0 { - let err = unsafe { bindings::extism_error(&mut *self.context.lock(), self.id) }; + let err = unsafe { bindings::extism_error(context, self.id) }; if !err.is_null() { let s = unsafe { std::ffi::CStr::from_ptr(err) }; return Err(Error::msg(s.to_str().unwrap())); @@ -174,17 +260,25 @@ impl<'a> Plugin<'a> { return Err(Error::msg("extism_call failed")); } - let out_len = - unsafe { bindings::extism_plugin_output_length(&mut *self.context.lock(), self.id) }; + let out_len = unsafe { bindings::extism_plugin_output_length(context, self.id) }; unsafe { - let ptr = bindings::extism_plugin_output_data(&mut *self.context.lock(), self.id); - Ok(std::slice::from_raw_parts(ptr, out_len as usize)) + let ptr = bindings::extism_plugin_output_data(context, self.id); + f(std::slice::from_raw_parts(ptr, out_len as usize)) } } + + /// Call a function with the given input + pub fn call( + &mut self, + name: impl AsRef, + input: impl AsRef<[u8]>, + ) -> Result<&'a [u8], Error> { + self.call_map(name, input, |x| Ok(x)) + } } impl<'a> Drop for Plugin<'a> { fn drop(&mut self) { - unsafe { bindings::extism_plugin_free(&mut *self.context.lock(), self.id) } + unsafe { bindings::extism_plugin_free(&mut *self.context.as_ref().lock(), self.id) } } } diff --git a/rust/src/plugin_builder.rs b/rust/src/plugin_builder.rs index 3bf9bf9..76378c8 100644 --- a/rust/src/plugin_builder.rs +++ b/rust/src/plugin_builder.rs @@ -6,13 +6,13 @@ enum Source { } /// PluginBuilder is used to configure and create `Plugin` instances -pub struct PluginBuilder<'a> { +pub struct PluginBuilder { source: Source, wasi: bool, - functions: Vec<&'a Function>, + functions: Vec, } -impl<'a> PluginBuilder<'a> { +impl PluginBuilder { /// Create a new `PluginBuilder` with the given WebAssembly module pub fn new_with_module(data: impl Into>) -> Self { PluginBuilder { @@ -38,23 +38,29 @@ impl<'a> PluginBuilder<'a> { } /// Add a single host function - pub fn with_function(mut self, f: &'a Function) -> Self { + pub fn with_function(mut self, f: Function) -> Self { self.functions.push(f); self } /// Add multiple host functions - pub fn with_functions(mut self, f: impl IntoIterator) -> Self { + pub fn with_functions(mut self, f: impl IntoIterator) -> Self { self.functions.extend(f); self } - pub fn build(self, context: &'a Context) -> Result, Error> { - match self.source { - Source::Manifest(m) => { - Plugin::new_with_manifest(context, &m, self.functions, self.wasi) - } - Source::Data(d) => Plugin::new(context, d, self.functions, self.wasi), + pub fn build<'a>(self, context: Option<&'a Context>) -> Result, Error> { + match context { + Some(context) => match self.source { + Source::Manifest(m) => { + Plugin::new_with_manifest(context, &m, self.functions, self.wasi) + } + Source::Data(d) => Plugin::new(context, d, self.functions, self.wasi), + }, + None => match self.source { + Source::Manifest(m) => Plugin::create_with_manifest(&m, self.functions, self.wasi), + Source::Data(d) => Plugin::create(d, self.functions, self.wasi), + }, } } } diff --git a/zig/src/plugin.zig b/zig/src/plugin.zig index 4e05025..636b54a 100644 --- a/zig/src/plugin.zig +++ b/zig/src/plugin.zig @@ -9,13 +9,14 @@ const utils = @import("utils.zig"); const Self = @This(); ctx: *Context, +owns_context: bool, id: i32, // We have to use this until ziglang/zig#2647 is resolved. error_info: ?[]const u8, /// Create a new plugin from a WASM module -pub fn init(allocator: std.mem.Allocator, ctx: *Context, data: []const u8, functions: []Function, wasi: bool) !Self { +pub fn init(allocator: std.mem.Allocator, ctx: *Context, data: []const u8, functions: []const Function, wasi: bool) !Self { ctx.mutex.lock(); defer ctx.mutex.unlock(); var plugin: i32 = -1; @@ -45,20 +46,39 @@ pub fn init(allocator: std.mem.Allocator, ctx: *Context, data: []const u8, funct .id = plugin, .ctx = ctx, .error_info = null, + .owns_context = false, }; } /// Create a new plugin from the given manifest -pub fn initFromManifest(allocator: std.mem.Allocator, ctx: *Context, manifest: Manifest, functions: []Function, wasi: bool) !Self { +pub fn initFromManifest(allocator: std.mem.Allocator, ctx: *Context, manifest: Manifest, functions: []const Function, wasi: bool) !Self { const json = try utils.stringifyAlloc(allocator, manifest); defer allocator.free(json); return init(allocator, ctx, json, functions, wasi); } +/// Create a new plugin from a WASM module in its own context +pub fn create(allocator: std.mem.Allocator, data: []const u8, functions: []const Function, wasi: bool) !Self { + const ctx = Context.init(); + var plugin = init(allocator, ctx, data, functions, wasi); + plugin.owns_context = true; + return plugin; +} + +/// Create a new plugin from the given manifest in its own context +pub fn createFromManifest(allocator: std.mem.Allocator, manifest: Manifest, functions: []const Function, wasi: bool) !Self { + const json = try utils.stringifyAlloc(allocator, manifest); + defer allocator.free(json); + return create(allocator, json, functions, wasi); +} + pub fn deinit(self: *Self) void { self.ctx.mutex.lock(); defer self.ctx.mutex.unlock(); c.extism_plugin_free(self.ctx.ctx, self.id); + if (self.owns_context) { + self.ctx.deinit(); + } } pub fn cancelHandle(self: *Self) CancelHandle {