From 1a140edbcb186567081bb4a4b705b17a72300a2b Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Fri, 15 Sep 2023 10:35:11 +0200 Subject: [PATCH] Change register updates to have degree 2, rewrite read-only update (#565) * change pc update to have degree 2, rewrite readonly update * use intermediate pols --------- Co-authored-by: Leo Alt --- asm_to_pil/src/vm_to_constrained.rs | 105 ++++++++++++++++++++++++---- linker/src/lib.rs | 19 +++-- 2 files changed, 103 insertions(+), 21 deletions(-) diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index 28fe9fa09..100b39b00 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -126,29 +126,50 @@ impl ASMPILConverter { reg.update_expression().map(|rhs| { let lhs = next_reference(name); use RegisterTy::*; - let (lhs, rhs) = match reg.ty { + match reg.ty { // Force pc to zero on first row. - Pc => ( - lhs, - build_mul( - build_sub(build_number(1u64), next_reference("first_step")), - rhs, - ), - ), + Pc => { + // introduce an intermediate witness polynomial to keep the degree of polynomial identities at 2 + // this may not be optimal for backends which support higher degree constraints + let pc_update_name = format!("{}_update", name); + + vec![ + PilStatement::PolynomialDefinition( + 0, + pc_update_name.to_string(), + rhs, + ), + PilStatement::PolynomialIdentity( + 0, + build_sub( + lhs, + build_mul( + build_sub( + build_number(1u64), + next_reference("first_step"), + ), + direct_reference(pc_update_name), + ), + ), + ), + ] + } // Unconstrain read-only registers when calling `_reset` ReadOnly => { let not_reset = build_sub(build_number(1u64), direct_reference("instr__reset")); - (build_mul(not_reset.clone(), lhs), build_mul(not_reset, rhs)) + vec![PilStatement::PolynomialIdentity( + 0, + build_mul(not_reset, build_sub(lhs, rhs)), + )] } - // - _ => (lhs, rhs), - }; - - PilStatement::PolynomialIdentity(0, build_sub(lhs, rhs)) + _ => { + vec![PilStatement::PolynomialIdentity(0, build_sub(lhs, rhs))] + } + } }) }) - .collect::>(), + .flatten(), ); for batch in rom.unwrap().statements.into_iter_batches() { @@ -369,6 +390,10 @@ impl ASMPILConverter { (Some(var), expr) => { let reference = direct_reference(&instruction_flag); + // reduce the update to linear by introducing intermediate variables + let expr = self + .linearize(&format!("{instruction_flag}_{var}_update"), expr); + self.registers .get_mut(&var) .unwrap() @@ -864,6 +889,56 @@ impl ASMPILConverter { fn return_instruction(&self) -> ast::asm_analysis::Instruction { return_instruction(self.output_count, self.pc_name.as_ref().unwrap()) } + + /// Return an expression of degree at most 1 whose value matches that of `expr` + /// Intermediate witness columns can be introduced, with names starting with `prefix` optionally followed by a suffix + /// Suffixes are defined as follows: "", "_1", "_2", "_3" etc + fn linearize(&mut self, prefix: &str, expr: Expression) -> Expression { + self.linearize_rec(prefix, 0, expr).1 + } + + fn linearize_rec( + &mut self, + prefix: &str, + counter: usize, + expr: Expression, + ) -> (usize, Expression) { + match expr { + Expression::BinaryOperation(left, operator, right) => match operator { + BinaryOperator::Add => { + let (counter, left) = self.linearize_rec(prefix, counter, *left); + let (counter, right) = self.linearize_rec(prefix, counter, *right); + (counter, build_add(left, right)) + } + BinaryOperator::Sub => { + let (counter, left) = self.linearize_rec(prefix, counter, *left); + let (counter, right) = self.linearize_rec(prefix, counter, *right); + (counter, build_sub(left, right)) + } + BinaryOperator::Mul => { + // if we have a quadratic term, we linearize each factor and introduce an intermediate variable for the product + let (counter, left) = self.linearize_rec(prefix, counter, *left); + let (counter, right) = self.linearize_rec(prefix, counter, *right); + let intermediate_name = format!( + "{prefix}{}", + if counter == 0 { + "".to_string() + } else { + format!("_{}", counter) + } + ); + self.pil.push(PilStatement::PolynomialDefinition( + 0, + intermediate_name.to_string(), + build_mul(left, right), + )); + (counter + 1, direct_reference(intermediate_name)) + } + op => unimplemented!("{op} is not supported when linearizing"), + }, + expr => (counter, expr), + } + } } struct Register { diff --git a/linker/src/lib.rs b/linker/src/lib.rs index 16f857e2b..29fbb3342 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -221,7 +221,8 @@ pol commit instr__reset; pol commit instr__loop; pol commit instr_return; pol constant first_step = [1] + [0]*; -pc' = ((1 - first_step') * ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1)))); +pol pc_update = ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1))); +pc' = ((1 - first_step') * pc_update); pol constant p_line = [0, 1, 2] + [2]*; pol constant p_instr__jump_to_operation = [0, 1, 0] + [0]*; pol constant p_instr__loop = [0, 0, 1] + [1]*; @@ -278,7 +279,8 @@ pol commit read_Y_pc; Y = ((((read_Y_A * A) + (read_Y_pc * pc)) + Y_const) + (Y_read_free * Y_free_value)); pol constant first_step = [1] + [0]*; A' = ((((reg_write_X_A * X) + (reg_write_Y_A * Y)) + (instr__reset * 0)) + ((1 - ((reg_write_X_A + reg_write_Y_A) + instr__reset)) * A)); -pc' = ((1 - first_step') * ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1)))); +pol pc_update = ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1))); +pc' = ((1 - first_step') * pc_update); pol constant p_line = [0, 1, 2, 3, 4] + [4]*; pol commit X_free_value(i) query match pc { }; pol commit Y_free_value(i) query match pc { }; @@ -328,8 +330,9 @@ pol commit read__output_0_pc; pol commit read__output_0__input_0; _output_0 = ((((read__output_0_pc * pc) + (read__output_0__input_0 * _input_0)) + _output_0_const) + (_output_0_read_free * _output_0_free_value)); pol constant first_step = [1] + [0]*; -((1 - instr__reset) * _input_0') = ((1 - instr__reset) * _input_0); -pc' = ((1 - first_step') * ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1)))); +((1 - instr__reset) * (_input_0' - _input_0)) = 0; +pol pc_update = ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1))); +pc' = ((1 - first_step') * pc_update); pol constant p_line = [0, 1, 2, 3, 4, 5] + [5]*; pol commit _output_0_free_value(i) query match pc { }; pol constant p__output_0_const = [0, 0, 0, 0, 1, 0] + [0]*; @@ -378,6 +381,8 @@ pol commit reg_write_X_CNT; pol commit CNT; pol commit instr_jmpz; pol commit instr_jmpz_param_l; +pol instr_jmpz_pc_update = (XIsZero * instr_jmpz_param_l); +pol instr_jmpz_pc_update_1 = ((1 - XIsZero) * (pc + 1)); pol commit instr_jmp; pol commit instr_jmp_param_l; pol commit instr_dec_CNT; @@ -396,7 +401,8 @@ X = (((((read_X_A * A) + (read_X_CNT * CNT)) + (read_X_pc * pc)) + X_const) + (X pol constant first_step = [1] + [0]*; A' = (((reg_write_X_A * X) + (instr__reset * 0)) + ((1 - (reg_write_X_A + instr__reset)) * A)); CNT' = ((((reg_write_X_CNT * X) + (instr_dec_CNT * (CNT - 1))) + (instr__reset * 0)) + ((1 - ((reg_write_X_CNT + instr_dec_CNT) + instr__reset)) * CNT)); -pc' = ((1 - first_step') * ((((((instr_jmpz * ((XIsZero * instr_jmpz_param_l) + ((1 - XIsZero) * (pc + 1)))) + (instr_jmp * instr_jmp_param_l)) + (instr__jump_to_operation * _operation_id)) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((((instr_jmpz + instr_jmp) + instr__jump_to_operation) + instr__loop) + instr_return)) * (pc + 1)))); +pol pc_update = ((((((instr_jmpz * (instr_jmpz_pc_update + instr_jmpz_pc_update_1)) + (instr_jmp * instr_jmp_param_l)) + (instr__jump_to_operation * _operation_id)) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((((instr_jmpz + instr_jmp) + instr__jump_to_operation) + instr__loop) + instr_return)) * (pc + 1))); +pc' = ((1 - first_step') * pc_update); pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + [10]*; pol commit X_free_value(i) query match pc { 2 => ("input", 1), 4 => ("input", (CNT + 1)), 7 => ("input", 0), }; pol constant p_X_const = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*; @@ -472,7 +478,8 @@ pol commit instr__loop; pol commit instr_return; pol constant first_step = [1] + [0]*; fp' = ((((instr_inc_fp * (fp + instr_inc_fp_param_amount)) + (instr_adjust_fp * (fp + instr_adjust_fp_param_amount))) + (instr__reset * 0)) + ((1 - ((instr_inc_fp + instr_adjust_fp) + instr__reset)) * fp)); -pc' = ((1 - first_step') * (((((instr_adjust_fp * label) + (instr__jump_to_operation * _operation_id)) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - (((instr_adjust_fp + instr__jump_to_operation) + instr__loop) + instr_return)) * (pc + 1)))); +pol pc_update = (((((instr_adjust_fp * label) + (instr__jump_to_operation * _operation_id)) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - (((instr_adjust_fp + instr__jump_to_operation) + instr__loop) + instr_return)) * (pc + 1))); +pc' = ((1 - first_step') * pc_update); pol constant p_line = [0, 1, 2, 3, 4] + [4]*; pol constant p_instr__jump_to_operation = [0, 1, 0, 0, 0] + [0]*; pol constant p_instr__loop = [0, 0, 0, 0, 1] + [1]*;