Query input/output rework (#1828)

This PR does a few things:
- unify `Input/DataIdentifier` into a single `Input` query that takes
(channel,idx) and returns a field element
- Output query takes (channel, fe)
- change the related prover functions to reflect this
This interface is used by, but not the same, as the input/output for the
riscv machine.
How to use these to input/output bytes or serialized data is a job of
the runtime implementation.
This commit is contained in:
Leandro Pacheco
2024-09-24 10:58:38 -03:00
committed by GitHub
parent 0a5299e3d1
commit 718c333ff9
25 changed files with 93 additions and 151 deletions

View File

@@ -17,10 +17,10 @@ powdr pil output/sum.asm -o output -f -i 10,2,4,6
The example Rust code verifies that a supplied list of integers sums up to a specified value.
```rust
{{#include ../../../riscv/tests/riscv_data/sum/src/lib.rs}}
{{#include ../../../riscv/tests/riscv_data/sum/src/main.rs}}
```
The function `get_prover_input` reads a number from the list supplied with `-i`.
The function `read_u32` reads a number from the list supplied with `-i`.
This is just a first mechanism to provide access to the outside world.
The plan is to be able to call arbitrary user-defined `ffi` functions that will translate to prover queries,

View File

@@ -13,7 +13,7 @@ see [this](https://github.com/powdr-labs/powdr/issues/814).
Let's use as example test `many chunks` from the `riscv` crate:
```rust
{{#include ../../../riscv/tests/riscv_data/many_chunks/src/lib.rs}}
{{#include ../../../riscv/tests/riscv_data/many_chunks/src/main.rs}}
```
First we need to compile the Rust code to powdr-asm:

View File

@@ -48,7 +48,7 @@ mod vm_processor;
static OUTER_CODE_NAME: &str = "witgen (outer code)";
// TODO change this so that it has functions
// get_input, get_input_from_channel, output_byte
// input_from_channel, output_to_channel
// instead of processing strings.
// but we can only do that once we have fully removed the old query functions.
pub trait QueryCallback<T>: Fn(&str) -> Result<Option<T>, String> + Send + Sync {}

View File

@@ -298,22 +298,12 @@ impl<'a, 'b, 'c, T: FieldElement, QueryCallback: super::QueryCallback<T>> Symbol
Ok(())
}
fn get_input(&mut self, index: usize) -> Result<Arc<Value<'a, T>>, EvalError> {
if let Some(v) =
(self.query_callback)(&format!("Input({index})")).map_err(EvalError::ProverError)?
{
Ok(Value::FieldElement(v).into())
} else {
Err(EvalError::DataNotAvailable)
}
}
fn get_input_from_channel(
fn input_from_channel(
&mut self,
channel: u32,
index: usize,
) -> Result<Arc<Value<'a, T>>, EvalError> {
if let Some(v) = (self.query_callback)(&format!("DataIdentifier({channel},{index})"))
if let Some(v) = (self.query_callback)(&format!("Input({channel},{index})"))
.map_err(EvalError::ProverError)?
{
Ok(Value::FieldElement(v).into())
@@ -322,8 +312,8 @@ impl<'a, 'b, 'c, T: FieldElement, QueryCallback: super::QueryCallback<T>> Symbol
}
}
fn output_byte(&mut self, fd: u32, byte: u8) -> Result<(), EvalError> {
if ((self.query_callback)(&format!("Output({fd},{byte})"))
fn output_to_channel(&mut self, fd: u32, elem: T) -> Result<(), EvalError> {
if ((self.query_callback)(&format!("Output({fd},{elem})"))
.map_err(EvalError::ProverError)?)
.is_some()
{

View File

@@ -479,9 +479,9 @@ namespace main_sub__rom(16);
pc' = (1 - first_step') * pc_update;
pol commit X_free_value;
query |__i| std::prover::handle_query(X_free_value, __i, match std::prover::eval(pc) {
2 => std::prelude::Query::Input(1),
4 => std::prelude::Query::Input(std::convert::int(std::prover::eval(CNT) + 1)),
7 => std::prelude::Query::Input(0),
2 => std::prelude::Query::Input(0, 2),
4 => std::prelude::Query::Input(0, std::convert::int(std::prover::eval(CNT) + 2)),
7 => std::prelude::Query::Input(0, 1),
_ => std::prelude::Query::None,
});
1 $ [0, pc, reg_write_X_A, reg_write_X_CNT, instr_jmpz, instr_jmpz_param_l, instr_jmp, instr_jmp_param_l, instr_dec_CNT, instr_assert_zero, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_CNT, read_X_pc] in main__rom::latch $ [main__rom::operation_id, main__rom::p_line, main__rom::p_reg_write_X_A, main__rom::p_reg_write_X_CNT, main__rom::p_instr_jmpz, main__rom::p_instr_jmpz_param_l, main__rom::p_instr_jmp, main__rom::p_instr_jmp_param_l, main__rom::p_instr_dec_CNT, main__rom::p_instr_assert_zero, main__rom::p_instr__jump_to_operation, main__rom::p_instr__reset, main__rom::p_instr__loop, main__rom::p_instr_return, main__rom::p_X_const, main__rom::p_X_read_free, main__rom::p_read_X_A, main__rom::p_read_X_CNT, main__rom::p_read_X_pc];

View File

@@ -399,7 +399,7 @@ fn none_value<'a, T>() -> Value<'a, T> {
})
}
const BUILTINS: [(&str, BuiltinFunction); 21] = [
const BUILTINS: [(&str, BuiltinFunction); 20] = [
("std::array::len", BuiltinFunction::ArrayLen),
("std::check::panic", BuiltinFunction::Panic),
("std::convert::expr", BuiltinFunction::ToExpr),
@@ -424,12 +424,14 @@ const BUILTINS: [(&str, BuiltinFunction); 21] = [
("std::prover::degree", BuiltinFunction::Degree),
("std::prover::eval", BuiltinFunction::Eval),
("std::prover::try_eval", BuiltinFunction::TryEval),
("std::prover::get_input", BuiltinFunction::GetInput),
(
"std::prover::get_input_from_channel",
BuiltinFunction::GetInputFromChannel,
"std::prover::input_from_channel",
BuiltinFunction::InputFromChannel,
),
(
"std::prover::output_to_channel",
BuiltinFunction::OutputToChannel,
),
("std::prover::output_byte", BuiltinFunction::OutputByte),
];
#[derive(Clone, Copy, Debug)]
@@ -475,12 +477,10 @@ pub enum BuiltinFunction {
Eval,
/// std::prover::try_eval: expr -> std::prelude::Option<fe>, evaluates an expression on the current row
TryEval,
/// std::prover::get_input: int -> fe, returns the value of a prover-provided and uncommitted input
GetInput,
/// std::prover::get_input_from_channel: int, int -> fe, returns the value of a prover-provided and uncommitted input from a certain channel
GetInputFromChannel,
/// std::prover::output_byte: int, int -> (), outputs a byte to a file descriptor
OutputByte,
/// std::prover::input_from_channel: int, int -> fe, returns the value of a prover-provided and uncommitted input from a certain channel
InputFromChannel,
/// std::prover::output_to_channel: int, fe -> (), outputs a field element to an output channel
OutputToChannel,
}
impl<'a, T: Display> Display for Value<'a, T> {
@@ -761,13 +761,7 @@ pub trait SymbolLookup<'a, T: FieldElement> {
))
}
fn get_input(&mut self, _index: usize) -> Result<Arc<Value<'a, T>>, EvalError> {
Err(EvalError::Unsupported(
"Tried to get input outside of prover function.".to_string(),
))
}
fn get_input_from_channel(
fn input_from_channel(
&mut self,
_channel: u32,
_index: usize,
@@ -777,9 +771,9 @@ pub trait SymbolLookup<'a, T: FieldElement> {
))
}
fn output_byte(&mut self, _fd: u32, _byte: u8) -> Result<(), EvalError> {
fn output_to_channel(&mut self, _channel: u32, _elem: T) -> Result<(), EvalError> {
Err(EvalError::Unsupported(
"Tried to output byte outside of prover function.".to_string(),
"Tried to output to channel outside of prover function.".to_string(),
))
}
}
@@ -1431,9 +1425,8 @@ fn evaluate_builtin_function<'a, T: FieldElement>(
BuiltinFunction::AtNextStage => 1,
BuiltinFunction::Eval => 1,
BuiltinFunction::TryEval => 1,
BuiltinFunction::GetInput => 1,
BuiltinFunction::GetInputFromChannel => 2,
BuiltinFunction::OutputByte => 2,
BuiltinFunction::InputFromChannel => 2,
BuiltinFunction::OutputToChannel => 2,
};
if arguments.len() != params {
@@ -1520,14 +1513,7 @@ fn evaluate_builtin_function<'a, T: FieldElement>(
symbols.provide_value(col, row, value)?;
Value::Tuple(vec![]).into()
}
BuiltinFunction::GetInput => {
let index = arguments.pop().unwrap();
let Value::Integer(index) = index.as_ref() else {
panic!()
};
symbols.get_input(usize::try_from(index).unwrap())?
}
BuiltinFunction::GetInputFromChannel => {
BuiltinFunction::InputFromChannel => {
let index = arguments.pop().unwrap();
let channel = arguments.pop().unwrap();
let Value::Integer(index) = index.as_ref() else {
@@ -1536,18 +1522,21 @@ fn evaluate_builtin_function<'a, T: FieldElement>(
let Value::Integer(channel) = channel.as_ref() else {
panic!()
};
symbols.get_input_from_channel(
symbols.input_from_channel(
u32::try_from(channel).unwrap(),
usize::try_from(index).unwrap(),
)?
}
BuiltinFunction::OutputByte => {
let byte = arguments.pop().unwrap();
let fd = arguments.pop().unwrap();
let (Value::Integer(fd), Value::Integer(byte)) = (fd.as_ref(), byte.as_ref()) else {
BuiltinFunction::OutputToChannel => {
let elem = arguments.pop().unwrap();
let channel = arguments.pop().unwrap();
let Value::Integer(channel) = channel.as_ref() else {
panic!()
};
symbols.output_byte(u32::try_from(fd).unwrap(), u8::try_from(byte).unwrap())?;
symbols.output_to_channel(
u32::try_from(channel).unwrap(),
elem.try_to_field_element().unwrap(),
)?;
Value::Tuple(vec![]).into()
}
BuiltinFunction::SetHint => {

View File

@@ -162,9 +162,8 @@ lazy_static! {
("std::prelude::set_hint", FunctionKind::Constr),
("std::prover::eval", FunctionKind::Query),
("std::prover::try_eval", FunctionKind::Query),
("std::prover::get_input", FunctionKind::Query),
("std::prover::get_input_from_channel", FunctionKind::Query),
("std::prover::output_byte", FunctionKind::Query),
("std::prover::input_from_channel", FunctionKind::Query),
("std::prover::output_to_channel", FunctionKind::Query),
]
.into_iter()
.collect();

View File

@@ -64,12 +64,8 @@ lazy_static! {
("", "expr -> std::prelude::Option<fe>")
),
("std::prover::provide_value", ("", "expr, int, fe -> ()")),
("std::prover::get_input", ("", "int -> fe")),
(
"std::prover::get_input_from_channel",
("", "int, int -> fe")
),
("std::prover::output_byte", ("", "int, int -> ()"))
("std::prover::input_from_channel", ("", "int, int -> fe")),
("std::prover::output_to_channel", ("", "int, fe -> ()"))
]
.into_iter()
.map(|(name, (vars, ty))| { (name.to_string(), parse_type_scheme(vars, ty)) })

View File

@@ -61,9 +61,9 @@ namespace T(65536);
T::A' = T::first_step' * 0 + T::reg_write_X_A * T::X + (1 - (T::first_step' + T::reg_write_X_A)) * T::A;
col witness X_free_value;
std::prelude::set_hint(T::X_free_value, query |_| match std::prover::eval(T::pc) {
0 => std::prelude::Query::Input(1),
3 => std::prelude::Query::Input(std::convert::int::<fe>(std::prover::eval(T::CNT) + 1)),
7 => std::prelude::Query::Input(0),
0 => std::prelude::Query::Input(0, 2),
3 => std::prelude::Query::Input(0, std::convert::int::<fe>(std::prover::eval(T::CNT) + 2)),
7 => std::prelude::Query::Input(0, 1),
_ => std::prelude::Query::None,
});
col fixed p_X_const = [0, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*;

View File

@@ -98,8 +98,8 @@ pub fn serde_data_to_query_callback<T: FieldElement>(
move |query: &str| -> Result<Option<T>, String> {
let (id, data) = parse_query(query)?;
match id {
"DataIdentifier" => {
let [index, cb_channel] = data[..] else {
"Input" => {
let [cb_channel, index] = data[..] else {
panic!()
};
let cb_channel = cb_channel
@@ -131,8 +131,8 @@ pub fn dict_data_to_query_callback<T: FieldElement>(
move |query: &str| -> Result<Option<T>, String> {
let (id, data) = parse_query(query)?;
match id {
"DataIdentifier" => {
let [index, cb_channel] = data[..] else {
"Input" => {
let [cb_channel, index] = data[..] else {
panic!()
};
let cb_channel = cb_channel
@@ -153,22 +153,6 @@ pub fn dict_data_to_query_callback<T: FieldElement>(
index => elems[index - 1],
}))
}
"Input" => {
assert_eq!(data.len(), 1);
let index = data[0]
.parse::<usize>()
.map_err(|e| format!("Error parsing index: {e})"))?;
let Some(elems) = dict.get(&0) else {
return Err("No prover inputs given".to_string());
};
elems
.get(index)
.cloned()
.map(Some)
.ok_or_else(|| format!("Index out of bounds: {index}"))
}
_ => Err(format!("Unsupported query: {query}")),
}
}

View File

@@ -7,11 +7,11 @@ use powdr_riscv_syscalls::Syscall;
use alloc::vec;
use alloc::vec::Vec;
/// Reads a single u32 from the file descriptor fd.
pub fn read_u32(fd: u32) -> u32 {
/// A single u32 from input channel 0.
pub fn read_u32(idx: u32) -> u32 {
let mut value: u32;
unsafe {
asm!("ecall", lateout("a0") value, in("a0") fd, in("t0") u32::from(Syscall::Input));
asm!("ecall", lateout("a0") value, in("a0") 0, in("a1") idx + 1, in("t0") u32::from(Syscall::Input));
}
value
}
@@ -20,7 +20,7 @@ pub fn read_u32(fd: u32) -> u32 {
pub fn read_slice(fd: u32, data: &mut [u32]) {
for (i, d) in data.iter_mut().enumerate() {
unsafe {
asm!("ecall", lateout("a0") *d, in("a0") fd, in("a1") (i+1) as u32, in("t0") u32::from(Syscall::DataIdentifier))
asm!("ecall", lateout("a0") *d, in("a0") fd, in("a1") (i+1) as u32, in("t0") u32::from(Syscall::Input))
};
}
}
@@ -29,7 +29,7 @@ pub fn read_slice(fd: u32, data: &mut [u32]) {
pub fn read_data_len(fd: u32) -> usize {
let mut out: u32;
unsafe {
asm!("ecall", lateout("a0") out, in("a0") fd, in("a1") 0, in("t0") u32::from(Syscall::DataIdentifier))
asm!("ecall", lateout("a0") out, in("a0") fd, in("a1") 0, in("t0") u32::from(Syscall::Input))
};
out as usize
}

View File

@@ -46,8 +46,7 @@ macro_rules! syscalls {
// Generate `Syscall` enum with supported syscalls and their numbers.
syscalls!(
(0, Input, "input"),
(1, DataIdentifier, "data_identifier"),
(1, Input, "input"),
(2, Output, "output"),
(3, PoseidonGL, "poseidon_gl"),
(4, Affine256, "affine_256"),

View File

@@ -234,21 +234,13 @@ impl Runtime {
// Base syscalls
r.add_syscall(
// TODO this is a quite inefficient way of getting prover inputs.
// We need to be able to access the register memory within PIL functions.
Syscall::Input,
[
// TODO this is a quite inefficient way of getting prover inputs.
// We need to be able to access the register memory within PIL functions.
"query_arg_1 <== get_reg(10);",
"set_reg 10, ${ std::prelude::Query::Input(std::convert::int(std::prover::eval(query_arg_1))) };",
],
);
r.add_syscall(
Syscall::DataIdentifier,
[
"query_arg_1 <== get_reg(10);",
"query_arg_2 <== get_reg(11);",
"set_reg 10, ${ std::prelude::Query::DataIdentifier(std::convert::int(std::prover::eval(query_arg_2)), std::convert::int(std::prover::eval(query_arg_1))) };",
"set_reg 10, ${ std::prelude::Query::Input(std::convert::int(std::prover::eval(query_arg_1)), std::convert::int(std::prover::eval(query_arg_2))) };",
]
);
@@ -259,7 +251,7 @@ impl Runtime {
[
"query_arg_1 <== get_reg(10);",
"query_arg_2 <== get_reg(11);",
"set_reg 0, ${ std::prelude::Query::Output(std::convert::int(std::prover::eval(query_arg_1)), std::convert::int(std::prover::eval(query_arg_2))) };"
"set_reg 0, ${ std::prelude::Query::Output(std::convert::int(std::prover::eval(query_arg_1)), std::prover::eval(query_arg_2)) };"
]
);

View File

@@ -36,15 +36,13 @@ enum SelectedExprs {
/// The return type of a prover query function.
enum Query {
/// Query a prover input element by index.
Input(int),
/// Writes a byte (second argument) to a file descriptor (first argument).
/// fd 1 is stdout, fd 2 is stderr.
Output(int, int),
/// Generate a hint to fill a witness column with.
Hint(fe),
/// Query a prover input element by index and data id.
DataIdentifier(int, int),
/// Query a prover input (field element) by channel id and index.
Input(int, int),
/// Writes a field element (second argument) to an output channel (first argument).
/// Channel 1 is stdout, 2 is stderr.
Output(int, fe),
/// This value is not (additionally) constrained by the query.
None,
}

View File

@@ -20,23 +20,19 @@ let provide_if_unknown: expr, int, (-> fe) -> () = query |column, row, f| match
_ => (),
};
/// Retrieves a byte from a prover-provided (untrusted and not committed) input.
/// The parameter is the index of the byte.
let get_input: int -> fe = [];
/// Retrieves a field element from a prover-provided (untrusted and not committed) input channel.
/// The parameters are the channel id and the index in the channel.
/// Index zero is the length of the channel (number of bytes) and index 1 is the first element.
let input_from_channel: int, int -> fe = [];
/// Retrieves a byte from a prover-provided (untrusted and not committed) input channel.
/// The parameters are the index of the channel and the index in the channel.
/// Index zero is the length of the channel (number of bytes) and index 1 is the first byte.
let get_input_from_channel: int, int -> fe = [];
/// Outputs a byte to a file descriptor.
let output_byte: int, int -> () = [];
/// Writes a field element to the given output channel.
/// The first parameter is the channel id, the second is the element to write.
let output_to_channel: int, fe -> () = [];
let handle_query: expr, int, std::prelude::Query -> () = query |column, row, v| match v {
Query::Hint(h) => provide_if_unknown(column, row, || h),
Query::Input(i) => provide_if_unknown(column, row, || get_input(i)),
Query::DataIdentifier(i, j) => provide_if_unknown(column, row, || get_input_from_channel(i, j)),
Query::Output(fd, b) => provide_if_unknown(column, row, || { output_byte(fd, b); 0 }),
Query::Input(i, j) => provide_if_unknown(column, row, || input_from_channel(i, j)),
Query::Output(channel, e) => provide_if_unknown(column, row, || { output_to_channel(channel, e); 0 }),
Query::None => (),
};

View File

@@ -25,9 +25,9 @@ machine BitAccess with degree: 32 {
instr assert_zero X { XIsZero = 1 }
function main {
B <=X= ${ std::prelude::Query::Input(0) };
B <=X= ${ std::prelude::Query::Input(0, 1) };
wrap B + 0xffffffec, A;
assert_zero A;
return;
}
}
}

View File

@@ -20,7 +20,7 @@ machine HelloWorld with degree: 8 {
// the main function assigns the first prover input to A, increments it, decrements it, and loops forever
function main {
A <=X= ${ std::prelude::Query::Input(0) };
A <=X= ${ std::prelude::Query::Input(0, 1) };
A <== incr(A);
A <== decr(A);
assert_zero A;

View File

@@ -26,7 +26,7 @@ machine FunctionalInstructions with degree: 32 {
instr assert_zero X { XIsZero = 1 }
function main {
B <=X= ${ std::prelude::Query::Input(0) };
B <=X= ${ std::prelude::Query::Input(0, 1) };
A <=X= wrap(B + 0xffffffec);
assert_zero A;
return;

View File

@@ -13,9 +13,9 @@ machine MultiAssign with degree: 8 {
instr assert_zero X { XIsZero = 1 }
function main {
A <=X= ${ std::prelude::Query::Input(0) };
A <=X= ${ std::prelude::Query::Input(0, 1) };
A <=Y= A - 7;
assert_zero A;
return;
}
}
}

View File

@@ -40,14 +40,14 @@ machine Palindrome with degree: 32 {
function main {
// TOOD somehow this is not properly resolved here without "std::prover::"
CNT <=X= ${ Query::Input(0) };
CNT <=X= ${ Query::Input(0, 1) };
ADDR <=X= 0;
mstore CNT;
store_values:
jmpz CNT, check_start;
ADDR <=X= CNT;
mstore ${ Query::Input(int(std::prover::eval(CNT))) };
mstore ${ Query::Input(0, int(std::prover::eval(CNT)) + 1) };
CNT <=X= CNT - 1;
jmp store_values;

View File

@@ -35,18 +35,18 @@ machine Main with degree: 16 {
instr assert_zero X { XIsZero = 1 }
function main {
CNT <=X= ${ Query::Input(1) };
CNT <=X= ${ Query::Input(0, 2) };
start:
jmpz CNT, check;
A <=X= A + ${ Query::Input(std::convert::int(std::prover::eval(CNT) + 1)) };
A <=X= A + ${ Query::Input(0, std::convert::int(std::prover::eval(CNT) + 2)) };
// Could use "CNT <=X= CNT - 1", but that would need X.
dec_CNT;
jmp start;
check:
A <=X= A - ${ Query::Input(0) };
A <=X= A - ${ Query::Input(0, 1) };
assert_zero A;
return;
}
}
}

View File

@@ -12,7 +12,7 @@ machine Square with degree: 8 {
}
function main {
A <=X= ${ std::prelude::Query::Input(0) };
A <=X= ${ std::prelude::Query::Input(0, 1) };
A <== square(A);
}
}

View File

@@ -2,7 +2,7 @@ let N: int = 4;
namespace std::prover(N);
enum Query {
Input(int),
Input(int, int),
None,
}
@@ -15,9 +15,9 @@ namespace Sum(N);
col witness input(i) query match i {
// A non-exhaustive match statement is the only way to return "None"
0 => std::prelude::Query::Input(0),
1 => std::prelude::Query::Input(1),
2 => std::prelude::Query::Input(2),
0 => std::prelude::Query::Input(0, 1),
1 => std::prelude::Query::Input(0, 2),
2 => std::prelude::Query::Input(0, 3),
// No response in the case of i == 3
_ => std::prelude::Query::None,
};

View File

@@ -3,10 +3,9 @@
namespace std::prover;
enum Query {
Input(int),
Output(int, int),
Hint(fe),
DataIdentifier(int, int),
Input(int, int),
Output(int, int),
None,
}
namespace main(256);

View File

@@ -5,14 +5,14 @@ namespace std::convert(16);
namespace std::prover(16);
enum Query {
Input(int)
Input(int, int),
}
namespace Quad(N);
col fixed id(i) { i };
col fixed double(i) { i * 2 };
col witness input(i) query std::prelude::Query::Input(i);
col witness input(i) query std::prelude::Query::Input(0, i + 1);
col witness wdouble;
col witness quadruple;