feat: Add Host Functions support for Java SDK (#248)

This commit is contained in:
Etienne ANNE
2023-03-02 22:46:13 +01:00
committed by GitHub
parent 581e9cea99
commit 3e69ceeede
13 changed files with 444 additions and 57 deletions

5
.gitignore vendored
View File

@@ -36,3 +36,8 @@ vendor
zig/zig-*
zig/example-out/
zig/*.log
java/*.iml
java/*.log
java/.idea
java/.DS_Store

View File

@@ -28,10 +28,11 @@ public class Context implements AutoCloseable {
*
* @param manifest The manifest for the plugin
* @param withWASI Set to true to enable WASI
* @param functions List of Host functions
* @return the plugin instance
*/
public Plugin newPlugin(Manifest manifest, boolean withWASI) {
return new Plugin(this, manifest, withWASI);
public Plugin newPlugin(Manifest manifest, boolean withWASI, HostFunction[] functions) {
return new Plugin(this, manifest, withWASI, functions);
}
/**

View File

@@ -42,7 +42,7 @@ public class Extism {
*/
public static String invokeFunction(Manifest manifest, String function, String input) throws ExtismException {
try (var ctx = new Context()) {
try (var plugin = ctx.newPlugin(manifest, false)) {
try (var plugin = ctx.newPlugin(manifest, false, null)) {
return plugin.call(function, input);
}
}

View File

@@ -0,0 +1,78 @@
package org.extism.sdk;
import com.sun.jna.Pointer;
import java.nio.charset.StandardCharsets;
public class ExtismCurrentPlugin {
public Pointer pointer;
public ExtismCurrentPlugin(Pointer pointer) {
this.pointer = pointer;
}
public Pointer memory() {
return LibExtism.INSTANCE.extism_current_plugin_memory(this.pointer);
}
public int alloc(int n) {
return LibExtism.INSTANCE.extism_current_plugin_memory_alloc(this.pointer, n);
}
public void free(long offset) {
LibExtism.INSTANCE.extism_current_plugin_memory_free(this.pointer, offset);
}
public long memoryLength(long offset) {
return LibExtism.INSTANCE.extism_current_plugin_memory_length(this.pointer, offset);
}
/**
* Return a string from a host function
* @param output - The output to set
* @param s - The string to return
*/
public void returnString(LibExtism.ExtismVal output, String s) {
returnBytes(output, s.getBytes(StandardCharsets.UTF_8));
}
/**
* Return bytes from a host function
* @param output - The output to set
* @param b - The buffer to return
*/
public void returnBytes(LibExtism.ExtismVal output, byte[] b) {
int offs = this.alloc(b.length);
Pointer ptr = this.memory();
ptr.write(offs, b, 0, b.length);
output.v.i64 = offs;
}
/**
* Get bytes from host function parameter
* @param input - The input to read
*/
public byte[] inputBytes(LibExtism.ExtismVal input) {
switch (input.t) {
case 0:
return this.memory()
.getByteArray(input.v.i32,
LibExtism.INSTANCE.extism_current_plugin_memory_length(this.pointer, input.v.i32));
case 1:
return this.memory()
.getByteArray(input.v.i64,
LibExtism.INSTANCE.extism_current_plugin_memory_length(this.pointer, input.v.i64));
default:
throw new ExtismException("inputBytes error: ExtismValType " + LibExtism.ExtismValType.values()[input.t] + " not implemtented");
}
}
/**
* Get string from host function parameter
* @param input - The input to read
*/
public String inputString(LibExtism.ExtismVal input) {
return new String(this.inputBytes(input));
}
}

View File

@@ -0,0 +1,12 @@
package org.extism.sdk;
import java.util.Optional;
public interface ExtismFunction<T extends HostUserData> {
void invoke(
ExtismCurrentPlugin plugin,
LibExtism.ExtismVal[] params,
LibExtism.ExtismVal[] returns,
Optional<T> data
);
}

View File

@@ -0,0 +1,81 @@
package org.extism.sdk;
import com.sun.jna.Pointer;
import com.sun.jna.PointerType;
import java.util.Arrays;
import java.util.Optional;
public class HostFunction<T extends HostUserData> {
private final LibExtism.InternalExtismFunction callback;
public final Pointer pointer;
public final String name;
public final LibExtism.ExtismValType[] params;
public final LibExtism.ExtismValType[] returns;
public final Optional<T> userData;
public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.ExtismValType[] returns, ExtismFunction f, Optional<T> userData) {
this.name = name;
this.params = params;
this.returns = returns;
this.userData = userData;
this.callback = (Pointer currentPlugin,
LibExtism.ExtismVal inputs,
int nInputs,
LibExtism.ExtismVal outs,
int nOutputs,
Pointer data) -> {
LibExtism.ExtismVal[] outputs = (LibExtism.ExtismVal []) outs.toArray(nOutputs);
f.invoke(
new ExtismCurrentPlugin(currentPlugin),
(LibExtism.ExtismVal []) inputs.toArray(nInputs),
outputs,
userData
);
for (LibExtism.ExtismVal output : outputs) {
convertOutput(output, output);
}
};
this.pointer = LibExtism.INSTANCE.extism_function_new(
this.name,
Arrays.stream(this.params).mapToInt(r -> r.v).toArray(),
this.params.length,
Arrays.stream(this.returns).mapToInt(r -> r.v).toArray(),
this.returns.length,
this.callback,
userData.map(PointerType::getPointer).orElse(null),
null
);
}
void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) {
if (fromHostFunction.t != original.t)
throw new ExtismException(String.format("Output type mismatch, got %d but expected %d", fromHostFunction.t, original.t));
if (fromHostFunction.t == LibExtism.ExtismValType.I32.v) {
original.v.setType(Integer.TYPE);
original.v.i32 = fromHostFunction.v.i32;
} else if (fromHostFunction.t == LibExtism.ExtismValType.I64.v) {
original.v.setType(Long.TYPE);
original.v.i64 = fromHostFunction.v.i64;
} else if (fromHostFunction.t == LibExtism.ExtismValType.F32.v) {
original.v.setType(Float.TYPE);
original.v.f32 = fromHostFunction.v.f32;
} else if (fromHostFunction.t == LibExtism.ExtismValType.F64.v) {
original.v.setType(Double.TYPE);
original.v.f64 = fromHostFunction.v.f64;
} else
throw new ExtismException(String.format("Unsupported return type: %s", original.t));
}
}

View File

@@ -0,0 +1,7 @@
package org.extism.sdk;
import com.sun.jna.PointerType;
public class HostUserData extends PointerType {
}

View File

@@ -1,8 +1,6 @@
package org.extism.sdk;
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import com.sun.jna.*;
/**
* Wrapper around the Extism library.
@@ -15,7 +13,80 @@ public interface LibExtism extends Library {
*/
LibExtism INSTANCE = Native.load("extism", LibExtism.class);
interface InternalExtismFunction extends Callback {
void invoke(
Pointer currentPlugin,
ExtismVal inputs,
int nInputs,
ExtismVal outputs,
int nOutputs,
Pointer data
);
}
@Structure.FieldOrder({"t", "v"})
class ExtismVal extends Structure {
public int t;
public ExtismValUnion v;
}
class ExtismValUnion extends Union {
public int i32;
public long i64;
public float f32;
public double f64;
}
enum ExtismValType {
I32(0),
I64(1),
F32(2),
F64(3),
V128(4),
FuncRef(5),
ExternRef(6);
public final int v;
ExtismValType(int value) {
this.v = value;
}
}
Pointer extism_function_new(String name,
int[] inputs,
int nInputs,
int[] outputs,
int nOutputs,
InternalExtismFunction func,
Pointer userData,
Pointer freeUserData);
/**
* Get the length of an allocated block
* NOTE: this should only be called from host functions.
*/
int extism_current_plugin_memory_length(Pointer plugin, long n);
/**
* Returns a pointer to the memory of the currently running plugin
* NOTE: this should only be called from host functions.
*/
Pointer extism_current_plugin_memory(Pointer plugin);
/**
* Allocate a memory block in the currently running plugin
* NOTE: this should only be called from host functions.
*/
int extism_current_plugin_memory_alloc(Pointer plugin, long n);
/**
* Free an allocated memory block
* NOTE: this should only be called from host functions.
*/
void extism_current_plugin_memory_free(Pointer plugin, long ptr);
/**
* Create a new context
*/
Pointer extism_context_new();
@@ -61,24 +132,13 @@ public interface LibExtism extends Library {
* @param withWASI enables/disables WASI
* @return id of the plugin or {@literal -1} in case of error
*/
int extism_plugin_new(Pointer contextPointer, byte[] wasm, long wasmSize, Pointer functions, int nFunctions, boolean withWASI);
int extism_plugin_new(Pointer contextPointer, byte[] wasm, long wasmSize, Pointer[] functions, int nFunctions, boolean withWASI);
/**
* Returns the Extism version string
*/
String extism_version();
/**
* Create a new plugin.
*
* @param contextPointer pointer to the {@link Context}.
* @param wasm is a WASM module (wat or wasm) or a JSON encoded manifest
* @param length the length of the `wasm` parameter
* @param withWASI enables/disables WASI
* @return id of the plugin or {@literal -1} in case of error
* @see #extism_plugin_new(long, byte[], long, boolean)
*/
int extism_plugin_new(Pointer contextPointer, byte[] wasm, int length, boolean withWASI);
/**
* Calls a function from the @{@link Plugin} at the given {@code pluginIndex}.
@@ -112,7 +172,7 @@ public interface LibExtism extends Library {
/**
* Update a plugin, keeping the existing ID.
* Similar to {@link #extism_plugin_new(long, byte[], long, boolean)} but takes an {@code pluginIndex} argument to specify which plugin to update.
* Similar to {@link #extism_plugin_new(Pointer, byte[], long, Pointer[], int, boolean)} but takes an {@code pluginIndex} argument to specify which plugin to update.
* Note: Memory for this plugin will be reset upon update.
*
* @param contextPointer
@@ -124,7 +184,7 @@ public interface LibExtism extends Library {
* @param withWASI
* @return {@literal true} if update was successful
*/
boolean extism_plugin_update(Pointer contextPointer, int pluginIndex, byte[] wasm, int length, Pointer functions, int nFunctions, boolean withWASI);
boolean extism_plugin_update(Pointer contextPointer, int pluginIndex, byte[] wasm, int length, Pointer[] functions, int nFunctions, boolean withWASI);
/**
* Remove a plugin from the registry and free associated memory.

View File

@@ -28,15 +28,27 @@ public class Plugin implements AutoCloseable {
*
* @param context The context to manage the plugin
* @param manifestBytes The manifest for the plugin
* @param functions The Host functions for th eplugin
* @param withWASI Set to true to enable WASI
*/
public Plugin(Context context, byte[] manifestBytes, boolean withWASI) {
public Plugin(Context context, byte[] manifestBytes, boolean withWASI, HostFunction[] functions) {
Objects.requireNonNull(context, "context");
Objects.requireNonNull(manifestBytes, "manifestBytes");
Pointer[] ptrArr = new Pointer[functions == null ? 0 : functions.length];
if (functions != null)
for (int i = 0; i < functions.length; i++) {
ptrArr[i] = functions[i].pointer;
}
Pointer contextPointer = context.getPointer();
int index = LibExtism.INSTANCE.extism_plugin_new(contextPointer, manifestBytes, manifestBytes.length, null, 0, withWASI);
int index = LibExtism.INSTANCE.extism_plugin_new(contextPointer, manifestBytes, manifestBytes.length,
ptrArr,
functions == null ? 0 : functions.length,
withWASI);
if (index == -1) {
String error = context.error(this);
throw new ExtismException(error);
@@ -46,8 +58,8 @@ public class Plugin implements AutoCloseable {
this.context = context;
}
public Plugin(Context context, Manifest manifest, boolean withWASI) {
this(context, serialize(manifest), withWASI);
public Plugin(Context context, Manifest manifest, boolean withWASI, HostFunction[] functions) {
this(context, serialize(manifest), withWASI, functions);
}
private static byte[] serialize(Manifest manifest) {
@@ -112,8 +124,8 @@ public class Plugin implements AutoCloseable {
* @param withWASI Set to true to enable WASI
* @return {@literal true} if update was successful
*/
public boolean update(Manifest manifest, boolean withWASI) {
return update(serialize(manifest), withWASI);
public boolean update(Manifest manifest, boolean withWASI, HostFunction[] functions) {
return update(serialize(manifest), withWASI, functions);
}
/**
@@ -123,9 +135,19 @@ public class Plugin implements AutoCloseable {
* @param withWASI Set to true to enable WASI
* @return {@literal true} if update was successful
*/
public boolean update(byte[] manifestBytes, boolean withWASI) {
public boolean update(byte[] manifestBytes, boolean withWASI, HostFunction[] functions) {
Objects.requireNonNull(manifestBytes, "manifestBytes");
return LibExtism.INSTANCE.extism_plugin_update(context.getPointer(), index, manifestBytes, manifestBytes.length, null, 0, withWASI);
Pointer[] ptrArr = new Pointer[functions == null ? 0 : functions.length];
if (functions != null)
for (int i = 0; i < functions.length; i++) {
ptrArr[i] = functions[i].pointer;
}
return LibExtism.INSTANCE.extism_plugin_update(context.getPointer(), index, manifestBytes, manifestBytes.length,
ptrArr,
functions == null ? 0 : functions.length,
withWASI);
}
/**

View File

@@ -1,18 +1,14 @@
package org.extism.sdk.support;
import com.google.gson.FieldNamingPolicy;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonParseException;
import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import com.google.gson.*;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonToken;
import com.google.gson.stream.JsonWriter;
import org.extism.sdk.manifest.Manifest;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
public class JsonSerde {
@@ -23,7 +19,7 @@ public class JsonSerde {
GSON = new GsonBuilder() //
.disableHtmlEscaping() //
// needed to convert the byte[] to a base64 encoded String
.registerTypeHierarchyAdapter(byte[].class, new ByteArrayToBase64TypeAdapter()) //
.registerTypeHierarchyAdapter(byte[].class, new ByteArrayAdapter()) //
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) //
.setPrettyPrinting() //
.create();
@@ -33,14 +29,28 @@ public class JsonSerde {
return GSON.toJson(manifest);
}
private static class ByteArrayToBase64TypeAdapter implements JsonSerializer<byte[]>, JsonDeserializer<byte[]> {
private static class ByteArrayAdapter extends TypeAdapter<byte[]> {
public byte[] deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) throws JsonParseException {
return Base64.getDecoder().decode(json.getAsString());
@Override
public void write(JsonWriter out, byte[] byteValue) throws IOException {
out.value(new String(Base64.getEncoder().encode(byteValue)));
}
public JsonElement serialize(byte[] src, Type typeOfSrc, JsonSerializationContext context) {
return new JsonPrimitive(Base64.getEncoder().withoutPadding().encodeToString(src));
@Override
public byte[] read(JsonReader in) {
try {
if (in.peek() == JsonToken.NULL) {
in.nextNull();
return new byte[]{};
}
String byteValue = in.nextString();
if (byteValue != null) {
return Base64.getDecoder().decode(byteValue);
}
return new byte[]{};
} catch (Exception e) {
throw new JsonParseException(e);
}
}
}
}

View File

@@ -1,12 +1,12 @@
package org.extism.sdk;
import com.sun.jna.Pointer;
import org.extism.sdk.manifest.Manifest;
import org.extism.sdk.manifest.MemoryOptions;
import org.extism.sdk.wasm.WasmSourceResolver;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Map;
import java.util.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.extism.sdk.TestWasmSources.CODE;
@@ -42,14 +42,12 @@ public class PluginTests {
assertThat(output).isEqualTo("{\"count\": 3}");
}
// TODO This test breaks on CI with error:
// data did not match any variant of untagged enum Wasm at line 8 column 3
// @Test
// public void shouldInvokeFunctionFromByteArrayWasmSource() {
// var manifest = new Manifest(CODE.byteArrayWasmSource());
// var output = Extism.invokeFunction(manifest, "count_vowels", "Hello World");
// assertThat(output).isEqualTo("{\"count\": 3}");
// }
@Test
public void shouldInvokeFunctionFromByteArrayWasmSource() {
var manifest = new Manifest(CODE.byteArrayWasmSource());
var output = Extism.invokeFunction(manifest, "count_vowels", "Hello World");
assertThat(output).isEqualTo("{\"count\": 3}");
}
@Test
public void shouldFailToInvokeUnknownFunction() {
@@ -80,7 +78,7 @@ public class PluginTests {
var input = "Hello World";
try (var ctx = new Context()) {
try (var plugin = ctx.newPlugin(manifest, false)) {
try (var plugin = ctx.newPlugin(manifest, false, null)) {
var output = plugin.call(functionName, input);
assertThat(output).isEqualTo("{\"count\": 3}");
}
@@ -94,7 +92,7 @@ public class PluginTests {
var input = "Hello World";
try (var ctx = new Context()) {
try (var plugin = ctx.newPlugin(manifest, false)) {
try (var plugin = ctx.newPlugin(manifest, false, null)) {
var output = plugin.call(functionName, input);
assertThat(output).isEqualTo("{\"count\": 3}");
@@ -104,4 +102,108 @@ public class PluginTests {
}
}
@Test
public void shouldAllowInvokeHostFunctionFromPDK() {
var parametersTypes = new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64};
var resultsTypes = new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64};
class MyUserData extends HostUserData {
private String data1;
private int data2;
public MyUserData(String data1, int data2) {
super();
this.data1 = data1;
this.data2 = data2;
}
}
ExtismFunction helloWorldFunction = (ExtismFunction<MyUserData>) (plugin, params, returns, data) -> {
System.out.println("Hello from Java Host Function!");
System.out.println(String.format("Input string received from plugin, %s", plugin.inputString(params[0])));
int offs = plugin.alloc(4);
Pointer mem = plugin.memory();
mem.write(offs, "test".getBytes(), 0, 4);
returns[0].v.i64 = offs;
data.ifPresent(d -> System.out.println(String.format("Host user data, %s, %d", d.data1, d.data2)));
};
HostFunction helloWorld = new HostFunction<>(
"hello_world",
parametersTypes,
resultsTypes,
helloWorldFunction,
Optional.of(new MyUserData("test", 2))
);
HostFunction[] functions = {helloWorld};
try (var ctx = new Context()) {
Manifest manifest = new Manifest(Arrays.asList(CODE.pathWasmFunctionsSource()));
String functionName = "count_vowels";
try (var plugin = ctx.newPlugin(manifest, true, functions)) {
var output = plugin.call(functionName, "this is a test");
assertThat(output).isEqualTo("test");
}
}
}
@Test
public void shouldAllowInvokeHostFunctionWithoutUserData() {
var parametersTypes = new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64};
var resultsTypes = new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64};
ExtismFunction helloWorldFunction = (plugin, params, returns, data) -> {
System.out.println("Hello from Java Host Function!");
System.out.println(String.format("Input string received from plugin, %s", plugin.inputString(params[0])));
int offs = plugin.alloc(4);
Pointer mem = plugin.memory();
mem.write(offs, "test".getBytes(), 0, 4);
returns[0].v.i64 = offs;
assertThat(data.isEmpty());
};
HostFunction helloWorld = new HostFunction<>(
"hello_world",
parametersTypes,
resultsTypes,
helloWorldFunction,
Optional.empty()
);
HostFunction[] functions = {helloWorld};
try (var ctx = new Context()) {
Manifest manifest = new Manifest(Arrays.asList(CODE.pathWasmFunctionsSource()));
String functionName = "count_vowels";
try (var plugin = ctx.newPlugin(manifest, true, functions)) {
var output = plugin.call(functionName, "this is a test");
assertThat(output).isEqualTo("test");
}
}
}
@Test
public void shouldFailToInvokeUnknownHostFunction() {
try (var ctx = new Context()) {
Manifest manifest = new Manifest(Arrays.asList(CODE.pathWasmFunctionsSource()));
String functionName = "count_vowels";
try {
var plugin = ctx.newPlugin(manifest, true, null);
plugin.call(functionName, "this is a test");
} catch (ExtismException e) {
assertThat(e.getMessage()).contains("unknown import: `env::hello_world` has not been defined");
}
}
}
}

View File

@@ -16,19 +16,28 @@ public enum TestWasmSources {
public Path getWasmFilePath() {
return Paths.get(WASM_LOCATION, "code.wasm");
}
public Path getWasmFunctionsFilePath() {
return Paths.get(WASM_LOCATION, "code-functions.wasm");
}
};
public static final String WASM_LOCATION = "src/test/resources";
public abstract Path getWasmFilePath();
public abstract Path getWasmFunctionsFilePath();
public PathWasmSource pathWasmSource() {
return resolvePathWasmSource(getWasmFilePath());
}
public PathWasmSource pathWasmFunctionsSource() {
return resolvePathWasmSource(getWasmFunctionsFilePath());
}
public ByteArrayWasmSource byteArrayWasmSource() {
try {
var wasmBytes = Files.readAllBytes(getWasmFilePath());
byte[] wasmBytes = Files.readAllBytes(getWasmFilePath());
return new WasmSourceResolver().resolve("wasm@" + Arrays.hashCode(wasmBytes), wasmBytes);
} catch (IOException ioe) {
throw new RuntimeException(ioe);

Binary file not shown.