diff --git a/runtime/src/current_plugin.rs b/runtime/src/current_plugin.rs index a2fb0be..ea80018 100644 --- a/runtime/src/current_plugin.rs +++ b/runtime/src/current_plugin.rs @@ -187,7 +187,7 @@ impl CurrentPlugin { anyhow::bail!("{} unable to locate extism memory", self.id) } - pub fn host_context(&mut self) -> Result { + pub fn host_context(&mut self) -> Result<&mut T, Error> { let (linker, store) = self.linker_and_store(); let Some(Extern::Global(xs)) = linker.get(&mut *store, EXTISM_ENV_MODULE, "extism_context") else { @@ -198,9 +198,15 @@ impl CurrentPlugin { anyhow::bail!("expected extism_context to be an externref value",) }; - match xs.data(&mut *store)?.downcast_ref::().cloned() { - Some(xs) => Ok(xs.clone()), - None => anyhow::bail!("could not downcast extism_context",), + match xs + .data_mut(&mut *store)? + .downcast_mut::>() + { + Some(xs) => match xs.downcast_mut::() { + Some(xs) => Ok(xs), + None => anyhow::bail!("could not downcast extism_context inner value"), + }, + None => anyhow::bail!("could not downcast extism_context"), } } diff --git a/runtime/src/plugin.rs b/runtime/src/plugin.rs index 1e423a3..8209284 100644 --- a/runtime/src/plugin.rs +++ b/runtime/src/plugin.rs @@ -85,6 +85,8 @@ pub struct Plugin { pub(crate) error_msg: Option>, pub(crate) fuel: Option, + + pub(crate) host_context: Rooted, } unsafe impl Send for Plugin {} @@ -216,13 +218,21 @@ fn add_module( Ok(()) } +#[allow(clippy::type_complexity)] fn relink( engine: &Engine, mut store: &mut Store, imports: &[Function], modules: &BTreeMap, with_wasi: bool, -) -> Result<(InstancePre, Linker), Error> { +) -> Result< + ( + InstancePre, + Linker, + Rooted, + ), + Error, +> { let mut linker = Linker::new(engine); linker.allow_shadowing(true); @@ -282,9 +292,12 @@ fn relink( )?; } + let inner: Box = Box::new(()); + let host_context = ExternRef::new(store, inner)?; + let main = &modules[MAIN_KEY]; let instance_pre = linker.instantiate_pre(main)?; - Ok((instance_pre, linker)) + Ok((instance_pre, linker, host_context)) } impl Plugin { @@ -366,8 +379,8 @@ impl Plugin { } let imports: Vec = imports.into_iter().collect(); - let (instance_pre, linker) = relink(&engine, &mut store, &imports, &modules, with_wasi)?; - + let (instance_pre, linker, host_context) = + relink(&engine, &mut store, &imports, &modules, with_wasi)?; let timer_tx = Timer::tx(); let mut plugin = Plugin { modules, @@ -386,6 +399,7 @@ impl Plugin { _functions: imports, error_msg: None, fuel, + host_context, }; plugin.current_plugin_mut().store = &mut plugin.store; @@ -423,7 +437,7 @@ impl Plugin { self.store.set_fuel(fuel)?; } - let (instance_pre, linker) = relink( + let (instance_pre, linker, host_context) = relink( &engine, &mut self.store, &self._functions, @@ -432,6 +446,7 @@ impl Plugin { )?; self.linker = linker; self.instance_pre = instance_pre; + self.host_context = host_context; let store = &mut self.store as *mut _; let linker = &mut self.linker as *mut _; let current_plugin = self.current_plugin_mut(); @@ -725,12 +740,12 @@ impl Plugin { // Implements the build of the `call` function, `raw_call` is also used in the SDK // code - pub(crate) fn raw_call( + pub(crate) fn raw_call( &mut self, lock: &mut std::sync::MutexGuard>, name: impl AsRef, input: impl AsRef<[u8]>, - host_context: Option>, + host_context: Option, ) -> Result { let name = name.as_ref(); let input = input.as_ref(); @@ -744,7 +759,22 @@ impl Plugin { self.instantiate(lock).map_err(|e| (e, -1))?; - self.set_input(input.as_ptr(), input.len(), host_context) + // Set host context + let r = if let Some(host_context) = host_context { + let inner = self + .host_context + .data_mut(&mut self.store) + .map_err(|x| (x, -1))?; + if let Some(inner) = inner.downcast_mut::>() { + let x: Box = Box::new(host_context); + *inner = x; + } + Some(self.host_context) + } else { + None + }; + + self.set_input(input.as_ptr(), input.len(), r) .map_err(|x| (x, -1))?; let func = match self.get_func(lock, name) { @@ -784,6 +814,14 @@ impl Plugin { let mut results = vec![wasmtime::Val::I32(0); n_results]; let mut res = func.call(self.store_mut(), &[], results.as_mut_slice()); + // Reset host context + if let Ok(inner) = self.host_context.data_mut(&mut self.store) { + if let Some(inner) = inner.downcast_mut::>() { + let x: Box = Box::new(()); + *inner = x; + } + } + // Stop timer self.store .epoch_deadline_callback(|_| Ok(UpdateDeadline::Continue(1))); @@ -929,7 +967,7 @@ impl Plugin { let lock = self.instance.clone(); let mut lock = lock.lock().unwrap(); let data = input.to_bytes()?; - self.raw_call(&mut lock, name, data, None) + self.raw_call(&mut lock, name, data, None::<()>) .map_err(|e| e.0) .and_then(move |rc| { if rc != 0 { @@ -954,8 +992,7 @@ impl Plugin { let lock = self.instance.clone(); let mut lock = lock.lock().unwrap(); let data = input.to_bytes()?; - let ctx = ExternRef::new(&mut self.store, host_context)?; - self.raw_call(&mut lock, name, data, Some(ctx)) + self.raw_call(&mut lock, name, data, Some(host_context)) .map_err(|e| e.0) .and_then(move |_| self.output()) } @@ -974,7 +1011,7 @@ impl Plugin { let lock = self.instance.clone(); let mut lock = lock.lock().unwrap(); let data = input.to_bytes().map_err(|e| (e, -1))?; - self.raw_call(&mut lock, name, data, None) + self.raw_call(&mut lock, name, data, None::<()>) .and_then(move |_| self.output().map_err(|e| (e, -1))) } diff --git a/runtime/src/sdk.rs b/runtime/src/sdk.rs index a2f983c..26b5c26 100644 --- a/runtime/src/sdk.rs +++ b/runtime/src/sdk.rs @@ -96,7 +96,7 @@ pub unsafe extern "C" fn extism_current_plugin_host_context( let plugin = &mut *plugin; if let Ok(CVoidContainer(ptr)) = plugin.host_context::() { - ptr + *ptr } else { std::ptr::null_mut() } @@ -565,6 +565,13 @@ pub unsafe extern "C" fn extism_plugin_call_with_host_context( let lock = plugin.instance.clone(); let mut lock = lock.lock().unwrap(); + if let Err(e) = plugin.reset_store(&mut lock) { + error!( + plugin = plugin.id.to_string(), + "call to Plugin::reset_store failed: {e:?}" + ); + } + plugin.error_msg = None; // Get function name @@ -580,11 +587,12 @@ pub unsafe extern "C" fn extism_plugin_call_with_host_context( name ); let input = std::slice::from_raw_parts(data, data_len as usize); - let r = match ExternRef::new(&mut plugin.store, CVoidContainer(host_context)) { - Err(e) => return plugin.return_error(&mut lock, e, -1), - Ok(x) => x, + let r = if host_context.is_null() { + None + } else { + Some(CVoidContainer(host_context)) }; - let res = plugin.raw_call(&mut lock, name, input, Some(r)); + let res = plugin.raw_call(&mut lock, name, input, r); match res { Err((e, rc)) => plugin.return_error(&mut lock, e, rc), Ok(x) => x, diff --git a/runtime/src/tests/runtime.rs b/runtime/src/tests/runtime.rs index a735557..1df064c 100644 --- a/runtime/src/tests/runtime.rs +++ b/runtime/src/tests/runtime.rs @@ -334,8 +334,10 @@ fn test_multiple_instantiations() { #[test] fn test_globals() { let mut plugin = Plugin::new(WASM_GLOBALS, [], true).unwrap(); - for i in 0..100000 { - let Json(count) = plugin.call::<_, Json>("globals", "").unwrap(); + for i in 0..100001 { + let Json(count) = plugin + .call_with_host_context::<_, Json, _>("globals", "", ()) + .unwrap(); assert_eq!(count.count, i); } } @@ -366,7 +368,7 @@ fn test_call_with_host_context() { [PTR], UserData::default(), |current_plugin, _val, ret, _user_data: UserData<()>| { - let foo = current_plugin.host_context::()?; + let foo = current_plugin.host_context::()?.clone(); let hnd = current_plugin.memory_new(foo.message)?; ret[0] = current_plugin.memory_to_val(hnd); Ok(())