From e644be4d3903e09bf72c827de60011e04038feb6 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 14 Feb 2024 16:00:03 +0100 Subject: [PATCH] Update utils. --- std/binary.asm | 6 +++--- std/shift.asm | 20 +++++++++++++------- std/split/split_bn254.asm | 13 +++++++++---- std/split/split_gl.asm | 13 +++++++++---- std/utils.asm | 12 ++++++------ 5 files changed, 40 insertions(+), 24 deletions(-) diff --git a/std/binary.asm b/std/binary.asm index 897d212d7..ff7b398ac 100644 --- a/std/binary.asm +++ b/std/binary.asm @@ -17,10 +17,10 @@ machine Binary(latch, operation_id) { col fixed FACTOR(i) { 1 << (((i + 1) % 4) * 8) }; // TOOD would be nice with destructuring assignment for arrays. - let inputs: (int -> int)[] = cross_product([3, 256, 256]); - let a = inputs[2]; + let inputs: (int -> int)[] = cross_product([256, 256, 3]); + let a = inputs[0]; let b = inputs[1]; - let op = inputs[0]; + let op = inputs[2]; col fixed P_A(i) { a(i) }; col fixed P_B(i) { b(i) }; col fixed P_operation(i) { op(i)}; diff --git a/std/shift.asm b/std/shift.asm index fc509155c..8f70cc929 100644 --- a/std/shift.asm +++ b/std/shift.asm @@ -1,4 +1,5 @@ use std::utils::unchanged_until; +use std::utils::cross_product; use std::convert::int; machine Shift(latch, operation_id) { @@ -14,14 +15,19 @@ machine Shift(latch, operation_id) { col fixed FACTOR_ROW(i) { (i + 1) % 4 }; col fixed FACTOR(i) { 1 << (((i + 1) % 4) * 8) }; - col fixed P_A(i) { i % 256 }; - col fixed P_B(i) { (i / 256) % 32 }; - col fixed P_ROW(i) { (i / (256 * 32)) % 4 }; - col fixed P_operation(i) { (i / (256 * 32 * 4)) % 2 }; + let inputs = cross_product([256, 32, 4, 2]); + let a: int -> int = inputs[0]; + let b: int -> int = inputs[1]; + let row: int -> int = inputs[2]; + let op: int -> int = inputs[3]; + let P_A: col = a; + let P_B: col = b; + let P_ROW: col = row; + let P_operation: col = op; col fixed P_C(i) { - match P_operation(i) { - 0 => (int(P_A(i)) << (int(P_B(i)) + (int(P_ROW(i)) * 8))), - 1 => (int(P_A(i)) << (int(P_ROW(i)) * 8)) >> int(P_B(i)), + match op(i) { + 0 => a(i) << (b(i) + (row(i) * 8)), + 1 => (a(i) << (row(i) * 8)) >> b(i), } & 0xffffffff }; diff --git a/std/split/split_bn254.asm b/std/split/split_bn254.asm index 86af665e0..d8e39683a 100644 --- a/std/split/split_bn254.asm +++ b/std/split/split_bn254.asm @@ -1,3 +1,5 @@ +use std::utils::cross_product; + // Splits an arbitrary field element into 8 u32s (in little endian order), on the BN254 field. machine SplitBN254(RESET, _) { @@ -60,10 +62,13 @@ machine SplitBN254(RESET, _) { col fixed BYTES_MAX = [0x00, 0x00, 0xf0, 0x93, 0xf5, 0xe1, 0x43, 0x91, 0x70, 0xb9, 0x79, 0x48, 0xe8, 0x33, 0x28, 0x5d, 0x58, 0x81, 0x81, 0xb6, 0x45, 0x50, 0xb8, 0x29, 0xa0, 0x31, 0xe1, 0x72, 0x4e, 0x64, 0x30, 0x00]*; // Byte comparison block machine - col fixed P_A(i) { i % 256 }; - col fixed P_B(i) { (i >> 8) % 256 }; - col fixed P_LT(i) { if std::convert::int(P_A(i)) < std::convert::int(P_B(i)) { 1 } else { 0 } }; - col fixed P_GT(i) { if std::convert::int(P_A(i)) > std::convert::int(P_B(i)) { 1 } else { 0 } }; + let compare_inputs = cross_product([256, 256]); + let a = compare_inputs[1]; + let b = compare_inputs[0]; + let P_A: col = a; + let P_B: col = b; + col fixed P_LT(i) { if a(i) < b(i) { 1 } else { 0 } }; + col fixed P_GT(i) { if a(i) > b(i) { 1 } else { 0 } }; // Compare the current byte with the corresponding byte of the maximum value. col witness lt; diff --git a/std/split/split_gl.asm b/std/split/split_gl.asm index 8752d6b31..57f77d1fa 100644 --- a/std/split/split_gl.asm +++ b/std/split/split_gl.asm @@ -1,3 +1,5 @@ +use std::utils::cross_product; + // Splits an arbitrary field element into two u32s, on the Goldilocks field. machine SplitGL(RESET, _) { @@ -56,10 +58,13 @@ machine SplitGL(RESET, _) { col fixed BYTES_MAX = [0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0]*; // Byte comparison block machine - col fixed P_A(i) { i % 256 }; - col fixed P_B(i) { (i >> 8) % 256 }; - col fixed P_LT(i) { if std::convert::int(P_A(i)) < std::convert::int(P_B(i)) { 1 } else { 0 } }; - col fixed P_GT(i) { if std::convert::int(P_A(i)) > std::convert::int(P_B(i)) { 1 } else { 0 } }; + let inputs = cross_product([256, 256]); + let a: int -> int = inputs[0]; + let b: int -> int = inputs[1]; + let P_A: col = a; + let P_B: col = b; + col fixed P_LT(i) { if a(i) < b(i) { 1 } else { 0 } }; + col fixed P_GT(i) { if a(i) > b(i) { 1 } else { 0 } }; // Compare the current byte with the corresponding byte of the maximum value. col witness lt; diff --git a/std/utils.asm b/std/utils.asm index 596107402..fbdff6205 100644 --- a/std/utils.asm +++ b/std/utils.asm @@ -28,17 +28,17 @@ let force_bool: expr -> constr = |c| c * (1 - c) = 0; /// first `size[i]` numbers (i.e. `0` until `size[i] - 1`, inclusive), such that all combinations /// of values of these functions appear as combined outputs. /// Each of the functions cycles through its values, advancing to the next number whenever the -/// next function has completed a cycle (or always advancing if it is the last function). +/// previous function has completed a cycle (or always advancing if it is the first function). /// This function is useful for combined range checks or building the inputs for function /// that is implemented in a lookup. /// See binary.asm for an example. -let cross_product: int[] -> (int -> int)[] = |sizes| cross_product_internal(1, std::array::len(sizes), sizes); +let cross_product: int[] -> (int -> int)[] = |sizes| cross_product_internal(1, 0, sizes); -let cross_product_internal: int, int[] -> (int -> int)[] = |cycle_len, len, sizes| - if len == 0 { +let cross_product_internal: int, int, int[] -> (int -> int)[] = |cycle_len, pos, sizes| + if pos >= std::array::len(sizes) { // We could assert here that the degree is at least `cycle_len` [] } else { - cross_product_internal(cycle_len * sizes[len - 1], len - 1, sizes) - + [|i| (i / cycle_len) % (sizes[len - 1])] + [|i| (i / cycle_len) % sizes[pos]] + + cross_product_internal(cycle_len * sizes[pos], pos + 1, sizes) };