mirror of
https://github.com/eth-act/ere.git
synced 2026-04-03 03:00:17 -04:00
Merge pull request #18 from eth-applied-research-group/kw/input-dyn-objects
feat: Modify Input to no longer use bincode implicitly
This commit is contained in:
18
Cargo.lock
generated
18
Cargo.lock
generated
@@ -2421,6 +2421,16 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||
|
||||
[[package]]
|
||||
name = "erased-serde"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e004d887f51fcb9fef17317a2f3525c887d8aa3f4f50fed920816a688284a5b7"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"typeid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ere-jolt"
|
||||
version = "0.1.0"
|
||||
@@ -9325,6 +9335,12 @@ dependencies = [
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typeid"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc7d623258602320d5c55d1bc22793b57daff0ec7efc270ea7d55ce1d5f5471c"
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.18.0"
|
||||
@@ -10163,7 +10179,9 @@ dependencies = [
|
||||
"anyhow",
|
||||
"auto_impl",
|
||||
"bincode",
|
||||
"erased-serde",
|
||||
"indexmap 2.9.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
@@ -40,21 +40,22 @@ pub fn preprocess_verifier(
|
||||
pub fn verify_generic(
|
||||
proof: jolt::JoltHyperKZGProof,
|
||||
// TODO: input should be private input
|
||||
inputs: Input,
|
||||
outputs: Input,
|
||||
_inputs: Input,
|
||||
_outputs: Input,
|
||||
preprocessing: jolt::JoltVerifierPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript>,
|
||||
) -> bool {
|
||||
use jolt::{Jolt, RV32IJoltVM, tracer};
|
||||
|
||||
let preprocessing = std::sync::Arc::new(preprocessing);
|
||||
let preprocessing = (*preprocessing).clone();
|
||||
let mut io_device = tracer::JoltDevice::new(
|
||||
let io_device = tracer::JoltDevice::new(
|
||||
preprocessing.memory_layout.max_input_size,
|
||||
preprocessing.memory_layout.max_output_size,
|
||||
);
|
||||
|
||||
io_device.inputs = inputs.bytes().to_vec();
|
||||
io_device.outputs = outputs.bytes().to_vec();
|
||||
// TODO: FIXME
|
||||
// io_device.inputs = inputs.bytes().to_vec();
|
||||
// io_device.outputs = outputs.bytes().to_vec();
|
||||
|
||||
RV32IJoltVM::verify(
|
||||
preprocessing,
|
||||
@@ -69,14 +70,15 @@ pub fn verify_generic(
|
||||
pub fn prove_generic(
|
||||
program: &jolt::host::Program,
|
||||
preprocessing: jolt::JoltProverPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript>,
|
||||
inputs: &Input,
|
||||
_inputs: &Input,
|
||||
) -> (Vec<u8>, jolt::JoltHyperKZGProof) {
|
||||
use jolt::{Jolt, RV32IJoltVM};
|
||||
|
||||
let mut program = program.clone();
|
||||
|
||||
// Convert inputs to a flat vector
|
||||
let input_bytes = inputs.bytes().to_vec();
|
||||
// TODO: FIXME
|
||||
let input_bytes = Vec::new();
|
||||
|
||||
let (io_device, trace) = program.trace(&input_bytes);
|
||||
|
||||
|
||||
@@ -46,27 +46,29 @@ impl EreJolt {
|
||||
program: <JOLT_TARGET as Compiler>::Program,
|
||||
_resource_type: ProverResourceType,
|
||||
) -> Self {
|
||||
EreJolt { program: program }
|
||||
EreJolt { program }
|
||||
}
|
||||
}
|
||||
impl zkVM for EreJolt {
|
||||
fn execute(
|
||||
&self,
|
||||
inputs: &zkvm_interface::Input,
|
||||
_inputs: &Input,
|
||||
) -> Result<zkvm_interface::ProgramExecutionReport, zkVMError> {
|
||||
// TODO: check ProgramSummary
|
||||
let summary = self
|
||||
.program
|
||||
.clone()
|
||||
.trace_analyze::<jolt::F>(inputs.bytes());
|
||||
let trace_len = summary.trace_len();
|
||||
// TODO: FIXME
|
||||
// let summary = self
|
||||
// .program
|
||||
// .clone()
|
||||
// .trace_analyze::<jolt::F>(inputs.bytes());
|
||||
// let trace_len = summary.trace_len();
|
||||
let trace_len = 0;
|
||||
|
||||
Ok(ProgramExecutionReport::new(trace_len as u64))
|
||||
}
|
||||
|
||||
fn prove(
|
||||
&self,
|
||||
inputs: &zkvm_interface::Input,
|
||||
inputs: &Input,
|
||||
) -> Result<(Vec<u8>, zkvm_interface::ProgramProvingReport), zkVMError> {
|
||||
// TODO: make this stateful and do in setup since its expensive and should be done once per program;
|
||||
let preprocessed_key = preprocess_prover(&self.program);
|
||||
@@ -88,7 +90,7 @@ impl zkVM for EreJolt {
|
||||
|
||||
let mut outputs = Input::new();
|
||||
assert!(public_inputs.is_empty());
|
||||
outputs.write(&public_inputs).unwrap();
|
||||
outputs.write(public_inputs);
|
||||
|
||||
// TODO: I don't think we should require the inputs when verifying
|
||||
let inputs = Input::new();
|
||||
@@ -97,7 +99,7 @@ impl zkVM for EreJolt {
|
||||
if valid {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(JoltError::ProofVerificationFailed).map_err(zkVMError::from)
|
||||
Err(zkVMError::from(JoltError::ProofVerificationFailed))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -134,7 +136,7 @@ mod tests {
|
||||
let test_guest_path = get_compile_test_guest_program_path();
|
||||
let program = JOLT_TARGET::compile(&test_guest_path).unwrap();
|
||||
let mut inputs = Input::new();
|
||||
inputs.write(&(1 as u32)).unwrap();
|
||||
inputs.write(1 as u32);
|
||||
|
||||
let zkvm = EreJolt::new(program, ProverResourceType::Cpu);
|
||||
let _execution = zkvm.execute(&inputs).unwrap();
|
||||
|
||||
@@ -24,7 +24,7 @@ pub(crate) fn package_name_from_manifest(manifest_path: &Path) -> Result<String,
|
||||
|
||||
/// Serializes the public input (as raw bytes) and proof into a single byte vector
|
||||
pub fn serialize_public_input_with_proof(
|
||||
public_input: &Vec<u8>,
|
||||
public_input: &[u8],
|
||||
proof: &JoltHyperKZGProof,
|
||||
) -> Result<Vec<u8>, SerializationError> {
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
@@ -12,7 +12,8 @@ use openvm_stark_sdk::config::{
|
||||
};
|
||||
use openvm_transpiler::elf::Elf;
|
||||
use zkvm_interface::{
|
||||
Compiler, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM, zkVMError,
|
||||
Compiler, Input, InputItem, ProgramExecutionReport, ProgramProvingReport,
|
||||
ProverResourceType, zkVM, zkVMError,
|
||||
};
|
||||
|
||||
mod error;
|
||||
@@ -58,7 +59,7 @@ impl EreOpenVM {
|
||||
impl zkVM for EreOpenVM {
|
||||
fn execute(
|
||||
&self,
|
||||
inputs: &zkvm_interface::Input,
|
||||
inputs: &Input,
|
||||
) -> Result<zkvm_interface::ProgramExecutionReport, zkVMError> {
|
||||
let sdk = Sdk::new();
|
||||
let vm_cfg = SdkVmConfig::builder()
|
||||
@@ -74,8 +75,11 @@ impl zkVM for EreOpenVM {
|
||||
.map_err(OpenVMError::from)?;
|
||||
|
||||
let mut stdin = StdIn::default();
|
||||
for input in inputs.chunked_iter() {
|
||||
stdin.write_bytes(input);
|
||||
for input in inputs.iter() {
|
||||
match input {
|
||||
InputItem::Object(serialize) => stdin.write(serialize),
|
||||
InputItem::Bytes(items) => stdin.write_bytes(items),
|
||||
}
|
||||
}
|
||||
|
||||
let _outputs = sdk
|
||||
@@ -88,7 +92,7 @@ impl zkVM for EreOpenVM {
|
||||
|
||||
fn prove(
|
||||
&self,
|
||||
inputs: &zkvm_interface::Input,
|
||||
inputs: &Input,
|
||||
) -> Result<(Vec<u8>, zkvm_interface::ProgramProvingReport), zkVMError> {
|
||||
// TODO: We need a stateful version in order to not spend a lot of time
|
||||
// TODO doing things like computing the pk and vk.
|
||||
@@ -107,8 +111,11 @@ impl zkVM for EreOpenVM {
|
||||
.map_err(OpenVMError::from)?;
|
||||
|
||||
let mut stdin = StdIn::default();
|
||||
for input in inputs.chunked_iter() {
|
||||
stdin.write_bytes(input);
|
||||
for input in inputs.iter() {
|
||||
match input {
|
||||
InputItem::Object(serialize) => stdin.write(serialize),
|
||||
InputItem::Bytes(items) => stdin.write_bytes(items),
|
||||
}
|
||||
}
|
||||
|
||||
let app_config = AppConfig::new(FriParameters::standard_fast(), vm_cfg);
|
||||
@@ -193,7 +200,7 @@ mod tests {
|
||||
// Panics because the program expects input arguments, but we supply none
|
||||
let test_guest_path = get_compile_test_guest_program_path();
|
||||
let elf = OPENVM_TARGET::compile(&test_guest_path).expect("compilation failed");
|
||||
let empty_input = zkvm_interface::Input::new();
|
||||
let empty_input = Input::new();
|
||||
let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu);
|
||||
|
||||
zkvm.execute(&empty_input).unwrap();
|
||||
@@ -203,8 +210,8 @@ mod tests {
|
||||
fn test_execute() {
|
||||
let test_guest_path = get_compile_test_guest_program_path();
|
||||
let elf = OPENVM_TARGET::compile(&test_guest_path).expect("compilation failed");
|
||||
let mut input = zkvm_interface::Input::new();
|
||||
input.write(&10u64).unwrap();
|
||||
let mut input = Input::new();
|
||||
input.write(10u64);
|
||||
|
||||
let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu);
|
||||
zkvm.execute(&input).unwrap();
|
||||
@@ -214,8 +221,8 @@ mod tests {
|
||||
fn test_prove_verify() {
|
||||
let test_guest_path = get_compile_test_guest_program_path();
|
||||
let elf = OPENVM_TARGET::compile(&test_guest_path).expect("compilation failed");
|
||||
let mut input = zkvm_interface::Input::new();
|
||||
input.write(&10u64).unwrap();
|
||||
let mut input = Input::new();
|
||||
input.write(10u64);
|
||||
|
||||
let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu);
|
||||
let (proof, _) = zkvm.prove(&input).unwrap();
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use pico_sdk::client::DefaultProverClient;
|
||||
use std::process::Command;
|
||||
use zkvm_interface::{Compiler, ProgramProvingReport, ProverResourceType, zkVM, zkVMError};
|
||||
use zkvm_interface::{
|
||||
Compiler, InputItem, ProgramProvingReport, ProverResourceType, zkVM, zkVMError,
|
||||
};
|
||||
|
||||
mod error;
|
||||
use error::PicoError;
|
||||
@@ -76,8 +78,11 @@ impl zkVM for ErePico {
|
||||
let client = DefaultProverClient::new(&self.program);
|
||||
|
||||
let mut stdin = client.new_stdin_builder();
|
||||
for input in inputs.chunked_iter() {
|
||||
stdin.write_slice(input);
|
||||
for input in inputs.iter() {
|
||||
match input {
|
||||
InputItem::Object(serialize) => stdin.write(serialize),
|
||||
InputItem::Bytes(items) => stdin.write_slice(items),
|
||||
}
|
||||
}
|
||||
let now = std::time::Instant::now();
|
||||
let meta_proof = client.prove(stdin).expect("Failed to generate proof");
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use compile::compile_risczero_program;
|
||||
use risc0_zkvm::{ExecutorEnv, Receipt, default_executor, default_prover};
|
||||
use zkvm_interface::{
|
||||
Compiler, Input, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM,
|
||||
zkVMError,
|
||||
Compiler, Input, InputItem, ProgramExecutionReport, ProgramProvingReport, ProverResourceType,
|
||||
zkVM, zkVMError,
|
||||
};
|
||||
|
||||
mod compile;
|
||||
@@ -38,16 +38,25 @@ impl EreRisc0 {
|
||||
|
||||
pub struct EreRisc0 {
|
||||
program: <RV32_IM_RISCZERO_ZKVM_ELF as Compiler>::Program,
|
||||
#[allow(dead_code)]
|
||||
resource_type: ProverResourceType,
|
||||
}
|
||||
|
||||
impl zkVM for EreRisc0 {
|
||||
fn execute(&self, inputs: &Input) -> Result<ProgramExecutionReport, zkVMError> {
|
||||
let executor = default_executor();
|
||||
let env = ExecutorEnv::builder()
|
||||
.write_slice(inputs.bytes())
|
||||
.build()
|
||||
.map_err(|err| zkVMError::Other(err.into()))?;
|
||||
let mut env = ExecutorEnv::builder();
|
||||
for input in inputs.iter() {
|
||||
match input {
|
||||
InputItem::Object(serialize) => {
|
||||
env.write(serialize).unwrap();
|
||||
}
|
||||
InputItem::Bytes(items) => {
|
||||
env.write_slice(&items);
|
||||
}
|
||||
}
|
||||
}
|
||||
let env = env.build().map_err(|err| zkVMError::Other(err.into()))?;
|
||||
|
||||
let session_info = executor
|
||||
.execute(env, &self.program.elf)
|
||||
@@ -60,10 +69,18 @@ impl zkVM for EreRisc0 {
|
||||
|
||||
fn prove(&self, inputs: &Input) -> Result<(Vec<u8>, ProgramProvingReport), zkVMError> {
|
||||
let prover = default_prover();
|
||||
let env = ExecutorEnv::builder()
|
||||
.write_slice(inputs.bytes())
|
||||
.build()
|
||||
.map_err(|err| zkVMError::Other(err.into()))?;
|
||||
let mut env = ExecutorEnv::builder();
|
||||
for input in inputs.iter() {
|
||||
match input {
|
||||
InputItem::Object(serialize) => {
|
||||
env.write(serialize).unwrap();
|
||||
}
|
||||
InputItem::Bytes(items) => {
|
||||
env.write_slice(&items);
|
||||
}
|
||||
}
|
||||
}
|
||||
let env = env.build().map_err(|err| zkVMError::Other(err.into()))?;
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
let prove_info = prover
|
||||
@@ -116,8 +133,8 @@ mod prove_tests {
|
||||
let mut input_builder = Input::new();
|
||||
let n: u32 = 42;
|
||||
let a: u16 = 42;
|
||||
input_builder.write(&n).unwrap();
|
||||
input_builder.write(&a).unwrap();
|
||||
input_builder.write(n);
|
||||
input_builder.write(a);
|
||||
|
||||
let zkvm = EreRisc0::new(program, ProverResourceType::Cpu);
|
||||
|
||||
@@ -180,8 +197,8 @@ mod execute_tests {
|
||||
let mut input_builder = Input::new();
|
||||
let n: u32 = 42;
|
||||
let a: u16 = 42;
|
||||
input_builder.write(&n).unwrap();
|
||||
input_builder.write(&a).unwrap();
|
||||
input_builder.write(n);
|
||||
input_builder.write(a);
|
||||
|
||||
let zkvm = EreRisc0::new(program, ProverResourceType::Cpu);
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ use sp1_sdk::{
|
||||
};
|
||||
use tracing::info;
|
||||
use zkvm_interface::{
|
||||
Compiler, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM, zkVMError,
|
||||
Compiler, Input, InputItem, ProgramExecutionReport, ProgramProvingReport,
|
||||
ProverResourceType, zkVM, zkVMError,
|
||||
};
|
||||
|
||||
mod compile;
|
||||
@@ -115,11 +116,14 @@ impl EreSP1 {
|
||||
impl zkVM for EreSP1 {
|
||||
fn execute(
|
||||
&self,
|
||||
inputs: &zkvm_interface::Input,
|
||||
inputs: &Input,
|
||||
) -> Result<zkvm_interface::ProgramExecutionReport, zkVMError> {
|
||||
let mut stdin = SP1Stdin::new();
|
||||
for input in inputs.chunked_iter() {
|
||||
stdin.write_slice(input);
|
||||
for input in inputs.iter() {
|
||||
match input {
|
||||
InputItem::Object(serialize) => stdin.write(serialize),
|
||||
InputItem::Bytes(items) => stdin.write_slice(items),
|
||||
}
|
||||
}
|
||||
|
||||
let (_, exec_report) = self.client.execute(&self.program, &stdin)?;
|
||||
@@ -140,8 +144,11 @@ impl zkVM for EreSP1 {
|
||||
info!("Generating proof…");
|
||||
|
||||
let mut stdin = SP1Stdin::new();
|
||||
for input in inputs.chunked_iter() {
|
||||
stdin.write_slice(input);
|
||||
for input in inputs.iter() {
|
||||
match input {
|
||||
InputItem::Object(serialize) => stdin.write(serialize),
|
||||
InputItem::Bytes(items) => stdin.write_slice(items),
|
||||
};
|
||||
}
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
@@ -197,8 +204,8 @@ mod execute_tests {
|
||||
let mut input_builder = Input::new();
|
||||
let n: u32 = 42;
|
||||
let a: u16 = 42;
|
||||
input_builder.write(&n).unwrap();
|
||||
input_builder.write(&a).unwrap();
|
||||
input_builder.write(n);
|
||||
input_builder.write(a);
|
||||
|
||||
let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu);
|
||||
|
||||
@@ -257,8 +264,8 @@ mod prove_tests {
|
||||
let mut input_builder = Input::new();
|
||||
let n: u32 = 42;
|
||||
let a: u16 = 42;
|
||||
input_builder.write(&n).unwrap();
|
||||
input_builder.write(&a).unwrap();
|
||||
input_builder.write(n);
|
||||
input_builder.write(a);
|
||||
|
||||
let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu);
|
||||
|
||||
|
||||
@@ -12,6 +12,10 @@ bincode = "1.3"
|
||||
indexmap = { version = "2.9.0", features = ["serde"] }
|
||||
thiserror = "2"
|
||||
auto_impl = "1.0"
|
||||
erased-serde = "0.4.6"
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = "1"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -1,129 +1,220 @@
|
||||
use erased_serde::Serialize as ErasedSerialize;
|
||||
use serde::Serialize;
|
||||
|
||||
pub enum InputItem {
|
||||
/// A serializable object stored as a trait object
|
||||
Object(Box<dyn ErasedSerialize>),
|
||||
/// Pre-serialized bytes (e.g., from bincode)
|
||||
Bytes(Vec<u8>),
|
||||
}
|
||||
|
||||
/// Represents a builder for input data to be passed to a ZKVM guest program.
|
||||
/// Values are serialized sequentially into an internal byte buffer.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Input {
|
||||
buf: Vec<u8>,
|
||||
ranges: Vec<(usize, usize)>,
|
||||
items: Vec<InputItem>,
|
||||
}
|
||||
impl Default for Input {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Input {
|
||||
/// Create an empty input buffer.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
Self {
|
||||
items: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Append a value, serializing it with `bincode`.
|
||||
pub fn write<T: Serialize>(&mut self, value: &T) -> Result<(), bincode::Error> {
|
||||
let start = self.buf.len();
|
||||
bincode::serialize_into(&mut self.buf, value)?;
|
||||
let end = self.buf.len();
|
||||
self.ranges.push((start, end - start));
|
||||
Ok(())
|
||||
/// Write a serializable value as a trait object
|
||||
pub fn write<T: Serialize + 'static>(&mut self, value: T) {
|
||||
self.items.push(InputItem::Object(Box::new(value)));
|
||||
}
|
||||
|
||||
pub fn write_slice(&mut self, slice: &[u8]) {
|
||||
let start = self.buf.len();
|
||||
self.buf.extend_from_slice(slice);
|
||||
let end = self.buf.len();
|
||||
self.ranges.push((start, end - start));
|
||||
/// Write pre-serialized bytes directly
|
||||
pub fn write_bytes(&mut self, bytes: Vec<u8>) {
|
||||
self.items.push(InputItem::Bytes(bytes));
|
||||
}
|
||||
|
||||
/// Number of elements written.
|
||||
/// Get the number of items stored
|
||||
pub fn len(&self) -> usize {
|
||||
self.ranges.len()
|
||||
self.items.len()
|
||||
}
|
||||
|
||||
/// Check if the buffer is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.ranges.is_empty()
|
||||
self.items.is_empty()
|
||||
}
|
||||
|
||||
/// Entire concatenated payload as one slice.
|
||||
pub fn bytes(&self) -> &[u8] {
|
||||
&self.buf
|
||||
/// Iterate over the items
|
||||
pub fn iter(&self) -> std::slice::Iter<InputItem> {
|
||||
self.items.iter()
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: Implement methods to work with the enum
|
||||
impl InputItem {
|
||||
/// Serialize this item to bytes using the specified serializer
|
||||
pub fn serialize_with<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
InputItem::Object(obj) => erased_serde::serialize(obj.as_ref(), serializer),
|
||||
InputItem::Bytes(bytes) => {
|
||||
// Serialize the bytes as a byte array
|
||||
bytes.serialize(serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator over individual chunks (the originally written objects).
|
||||
pub fn chunked_iter(&self) -> impl ExactSizeIterator<Item = &[u8]> + '_ {
|
||||
self.ranges.iter().map(|&(s, len)| &self.buf[s..s + len])
|
||||
}
|
||||
|
||||
/// Byte‑wise iterator (rarely needed).
|
||||
pub fn iter(&self) -> std::slice::Iter<'_, u8> {
|
||||
self.buf.iter()
|
||||
/// Get the item as bytes (serialize objects, return bytes directly)
|
||||
pub fn as_bytes(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
match self {
|
||||
InputItem::Object(obj) => {
|
||||
let mut buf = Vec::new();
|
||||
let mut serializer =
|
||||
bincode::Serializer::new(&mut buf, bincode::DefaultOptions::new());
|
||||
erased_serde::serialize(obj.as_ref(), &mut serializer)?;
|
||||
Ok(buf)
|
||||
}
|
||||
InputItem::Bytes(bytes) => Ok(bytes.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod input_erased_tests {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[test]
|
||||
fn input_empty() {
|
||||
let input = Input::new();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(input.len(), 0);
|
||||
assert!(input.bytes().is_empty());
|
||||
assert_eq!(input.chunked_iter().count(), 0);
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct Person {
|
||||
name: String,
|
||||
age: u32,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn input_write_and_read() {
|
||||
fn test_write_object() {
|
||||
let mut input = Input::new();
|
||||
let a: u32 = 42;
|
||||
let b: &str = "hello";
|
||||
|
||||
input.write(&a).unwrap();
|
||||
input.write(&b).unwrap();
|
||||
let person = Person {
|
||||
name: "Alice".to_string(),
|
||||
age: 30,
|
||||
};
|
||||
|
||||
// length bookkeeping
|
||||
assert_eq!(input.len(), 2);
|
||||
assert!(!input.is_empty());
|
||||
input.write(person);
|
||||
assert_eq!(input.len(), 1);
|
||||
|
||||
// chunk iteration and deserialization
|
||||
let chunks: Vec<&[u8]> = input.chunked_iter().collect();
|
||||
assert_eq!(chunks.len(), 2);
|
||||
let a_back: u32 = bincode::deserialize(chunks[0]).unwrap();
|
||||
assert_eq!(a_back, a);
|
||||
let b_back: String = bincode::deserialize(chunks[1]).unwrap();
|
||||
assert_eq!(b_back, b);
|
||||
|
||||
// contiguous bytes match manual serialization
|
||||
let mut expected = Vec::<u8>::new();
|
||||
bincode::serialize_into(&mut expected, &a).unwrap();
|
||||
bincode::serialize_into(&mut expected, &b).unwrap();
|
||||
assert_eq!(input.bytes(), expected.as_slice());
|
||||
|
||||
// iter() covers same length
|
||||
assert_eq!(input.iter().count(), expected.len());
|
||||
match &input.items[0] {
|
||||
InputItem::Object(_) => (), // Success
|
||||
InputItem::Bytes(_) => panic!("Expected Object, got Bytes"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn input_write_slice() {
|
||||
fn test_write_bytes() {
|
||||
let mut input = Input::new();
|
||||
|
||||
let slice1 = [1, 2, 3, 4];
|
||||
let slice2 = [5, 6, 7, 8, 9];
|
||||
let bytes = vec![1, 2, 3, 4, 5];
|
||||
input.write_bytes(bytes.clone());
|
||||
|
||||
input.write_slice(&slice1);
|
||||
input.write_slice(&slice2);
|
||||
assert_eq!(input.len(), 1);
|
||||
|
||||
assert_eq!(input.len(), 2);
|
||||
assert!(!input.is_empty());
|
||||
match &input.items[0] {
|
||||
InputItem::Bytes(stored_bytes) => assert_eq!(stored_bytes, &bytes),
|
||||
InputItem::Object(_) => panic!("Expected Bytes, got Object"),
|
||||
}
|
||||
}
|
||||
|
||||
// Check chunked iteration
|
||||
let chunks: Vec<&[u8]> = input.chunked_iter().collect();
|
||||
assert_eq!(chunks.len(), 2);
|
||||
assert_eq!(chunks[0], &slice1);
|
||||
assert_eq!(chunks[1], &slice2);
|
||||
#[test]
|
||||
fn test_write_serialized() {
|
||||
let mut input = Input::new();
|
||||
|
||||
// Check contiguous bytes
|
||||
let mut expected = Vec::<u8>::new();
|
||||
expected.extend_from_slice(&slice1);
|
||||
expected.extend_from_slice(&slice2);
|
||||
assert_eq!(input.bytes(), expected.as_slice());
|
||||
let person = Person {
|
||||
name: "Bob".to_string(),
|
||||
age: 25,
|
||||
};
|
||||
|
||||
assert_eq!(input.iter().count(), slice1.len() + slice2.len());
|
||||
// User serializes themselves and writes bytes
|
||||
let serialized = bincode::serialize(&person).unwrap();
|
||||
input.write_bytes(serialized);
|
||||
|
||||
assert_eq!(input.len(), 1);
|
||||
|
||||
match &input.items[0] {
|
||||
InputItem::Bytes(_) => (), // Success
|
||||
InputItem::Object(_) => panic!("Expected Bytes, got Object"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_usage() {
|
||||
let mut input = Input::new();
|
||||
|
||||
let person = Person {
|
||||
name: "Charlie".to_string(),
|
||||
age: 35,
|
||||
};
|
||||
|
||||
// Mix different write methods
|
||||
input.write(42i32); // Object
|
||||
let serialized = bincode::serialize(&person).unwrap();
|
||||
input.write_bytes(serialized); // Bytes (serialized)
|
||||
input.write_bytes(vec![10, 20, 30]); // Bytes (raw)
|
||||
input.write("hello".to_string()); // Object
|
||||
|
||||
assert_eq!(input.len(), 4);
|
||||
|
||||
// Verify types
|
||||
match &input.items[0] {
|
||||
InputItem::Object(_) => (),
|
||||
_ => panic!(),
|
||||
}
|
||||
match &input.items[1] {
|
||||
InputItem::Bytes(_) => (),
|
||||
_ => panic!(),
|
||||
}
|
||||
match &input.items[2] {
|
||||
InputItem::Bytes(_) => (),
|
||||
_ => panic!(),
|
||||
}
|
||||
match &input.items[3] {
|
||||
InputItem::Object(_) => (),
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_as_bytes() {
|
||||
let mut input = Input::new();
|
||||
|
||||
// Add an object
|
||||
input.write(42i32);
|
||||
|
||||
// Add raw bytes
|
||||
input.write_bytes(vec![1, 2, 3]);
|
||||
|
||||
// Convert both to bytes
|
||||
let obj_bytes = input.items[0].as_bytes().unwrap();
|
||||
let raw_bytes = input.items[1].as_bytes().unwrap();
|
||||
|
||||
// The object should be serialized to some bytes
|
||||
assert!(!obj_bytes.is_empty());
|
||||
|
||||
// The raw bytes should be returned as-is
|
||||
assert_eq!(raw_bytes, vec![1, 2, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iteration() {
|
||||
let mut input = Input::new();
|
||||
|
||||
input.write(1);
|
||||
input.write(2);
|
||||
input.write_bytes(vec![3, 4, 5]);
|
||||
|
||||
let count = input.iter().count();
|
||||
assert_eq!(count, 3);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::{path::Path, time::Duration};
|
||||
use thiserror::Error;
|
||||
|
||||
mod input;
|
||||
pub use input::Input;
|
||||
pub use input::{Input, InputItem};
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
/// Compiler trait for compiling programs into an opaque sequence of bytes.
|
||||
|
||||
Reference in New Issue
Block a user