diff --git a/std/binary.asm b/std/binary.asm index 8dba7677c..897d212d7 100644 --- a/std/binary.asm +++ b/std/binary.asm @@ -1,4 +1,5 @@ use std::convert::int; +use std::utils::cross_product; machine Binary(latch, operation_id) { @@ -15,15 +16,20 @@ machine Binary(latch, operation_id) { col fixed latch(i) { if (i % 4) == 3 { 1 } else { 0 } }; col fixed FACTOR(i) { 1 << (((i + 1) % 4) * 8) }; - col fixed P_A(i) { i % 256 }; - col fixed P_B(i) { (i >> 8) % 256 }; - col fixed P_operation(i) { (i / (256 * 256)) % 3 }; + // TOOD would be nice with destructuring assignment for arrays. + let inputs: (int -> int)[] = cross_product([3, 256, 256]); + let a = inputs[2]; + let b = inputs[1]; + let op = inputs[0]; + col fixed P_A(i) { a(i) }; + col fixed P_B(i) { b(i) }; + col fixed P_operation(i) { op(i)}; col fixed P_C(i) { - match P_operation(i) { - 0 => int(P_A(i)) & int(P_B(i)), - 1 => int(P_A(i)) | int(P_B(i)), - 2 => int(P_A(i)) ^ int(P_B(i)), - } & 0xff + match op(i) { + 0 => a(i) & b(i), + 1 => a(i) | b(i), + 2 => a(i) ^ b(i), + } }; col witness A_byte; diff --git a/std/utils.asm b/std/utils.asm index 0a54aa2f9..596107402 100644 --- a/std/utils.asm +++ b/std/utils.asm @@ -23,3 +23,22 @@ let unchanged_until = |c, latch| (c' - c) * (1 - latch) = 0; /// Evaluates to a constraint that forces `c` to be either 0 or 1. let force_bool: expr -> constr = |c| c * (1 - c) = 0; + +/// Returns an array of functions such that the range of the `i`th function is exactly the +/// first `size[i]` numbers (i.e. `0` until `size[i] - 1`, inclusive), such that all combinations +/// of values of these functions appear as combined outputs. +/// Each of the functions cycles through its values, advancing to the next number whenever the +/// next function has completed a cycle (or always advancing if it is the last function). +/// This function is useful for combined range checks or building the inputs for function +/// that is implemented in a lookup. +/// See binary.asm for an example. +let cross_product: int[] -> (int -> int)[] = |sizes| cross_product_internal(1, std::array::len(sizes), sizes); + +let cross_product_internal: int, int[] -> (int -> int)[] = |cycle_len, len, sizes| + if len == 0 { + // We could assert here that the degree is at least `cycle_len` + [] + } else { + cross_product_internal(cycle_len * sizes[len - 1], len - 1, sizes) + + [|i| (i / cycle_len) % (sizes[len - 1])] + };