diff --git a/src/zkas/constants.rs b/src/zkas/constants.rs index a02fc70c8..970e2d5a6 100644 --- a/src/zkas/constants.rs +++ b/src/zkas/constants.rs @@ -17,22 +17,46 @@ */ /// Maximum allowed k param (circuit rows = 2^k) -pub const MAX_K: u32 = 16; +pub(super) const MAX_K: u32 = 16; /// Maximum allowed namespace length in bytes -pub const MAX_NS_LEN: usize = 32; +pub(super) const MAX_NS_LEN: usize = 32; /// Minimum size allowed for a syntactically valid ZkBinary /// MAGIC_BYTES.length = 4; /// `k = ##;` = 6 (because the current upper-limit for k is a two-digit number /// Therefore 4 + 6 = 10 is the minimum size -pub const MIN_BIN_SIZE: usize = 10; +pub(super) const MIN_BIN_SIZE: usize = 10; + +/// Maximum allowed binary size (1 MiB) +pub(super) const MAX_BIN_SIZE: usize = 1024 * 1024; + +/// Maximum number of constants allowed +pub(super) const MAX_CONSTANTS: usize = 1024; + +/// Maximum number of literals allowed +pub(super) const MAX_LITERALS: usize = 4096; + +/// Maximum number of witnesses allowed +pub(super) const MAX_WITNESSES: usize = 4096; + +/// Maximum number of opcodes allowed +pub(super) const MAX_OPCODES: usize = 4096; + +/// Maximum number of arguments per opcode +pub(super) const MAX_ARGS_PER_OPCODE: usize = 256; + +/// Maximum total heap size (constants + witnesses + assigned variables) +pub(super) const MAX_HEAP_SIZE: usize = MAX_CONSTANTS + MAX_WITNESSES + MAX_OPCODES; + +/// Maximum string length for names +pub(super) const MAX_STRING_LEN: usize = 1024; /// Allowed fields for proofs -pub const ALLOWED_FIELDS: [&str; 1] = ["pallas"]; +pub(super) const ALLOWED_FIELDS: [&str; 1] = ["pallas"]; /// Maximum recursion depth for nested function calls -pub const MAX_RECURSION_DEPTH: usize = 16; +pub(super) const MAX_RECURSION_DEPTH: usize = 16; // Section markers in the binary format pub(super) const SECTION_CONSTANT: &[u8] = b".constant"; diff --git a/src/zkas/decoder.rs b/src/zkas/decoder.rs index d2f406ddc..70f6540bc 100644 --- a/src/zkas/decoder.rs +++
b/src/zkas/decoder.rs @@ -16,13 +16,14 @@ * along with this program. If not, see . */ -use darkfi_serial::{deserialize_partial, VarInt}; +use darkfi_serial::{deserialize_limited_partial, deserialize_partial, VarInt}; use super::{ compiler::MAGIC_BYTES, constants::{ - MAX_K, MAX_NS_LEN, MIN_BIN_SIZE, SECTION_CIRCUIT, SECTION_CONSTANT, SECTION_DEBUG, - SECTION_LITERAL, SECTION_WITNESS, + MAX_ARGS_PER_OPCODE, MAX_BIN_SIZE, MAX_CONSTANTS, MAX_HEAP_SIZE, MAX_K, MAX_LITERALS, + MAX_NS_LEN, MAX_OPCODES, MAX_STRING_LEN, MAX_WITNESSES, MIN_BIN_SIZE, SECTION_CIRCUIT, + SECTION_CONSTANT, SECTION_DEBUG, SECTION_LITERAL, SECTION_WITNESS, }, types::HeapType, LitType, Opcode, VarType, @@ -75,6 +76,33 @@ fn find_section(bytes: &[u8], section: &[u8]) -> Result { }) } +/// Validate that a count is within limits and reasonable for the remaining bytes +fn validate_count( + count: u64, + max: usize, + remaining_bytes: usize, + item_name: &str, +) -> Result { + let count = count as usize; + + if count > max { + return Err(ZkasErr(format!( + "{} count {} exceeds maximum allowed {}", + item_name, count, max + ))); + } + + // Sanity check: each item needs at least 1 byte + if count > remaining_bytes { + return Err(ZkasErr(format!( + "{} count {} exceeds remaining bytes {}", + item_name, count, remaining_bytes + ))); + } + + Ok(count) +} + struct SectionOffsets { constant: usize, literal: usize, @@ -153,6 +181,16 @@ impl ZkBinary { if bytes.len() < MIN_BIN_SIZE { return Err(ZkasErr("Not enough bytes".to_string())) } + + // Check max size to prevent decoding maliciously large binaries + if bytes.len() > MAX_BIN_SIZE { + return Err(ZkasErr(format!( + "Binary size {} exceeds maximum allowed {}", + bytes.len(), + MAX_BIN_SIZE + ))) + } + let magic_bytes = &bytes[0..4]; if magic_bytes != MAGIC_BYTES { return Err(ZkasErr("Magic bytes are incorrect".to_string())) @@ -169,12 +207,7 @@ impl ZkBinary { } // After the binary version and k, we're supposed to have the witness namespace - let (namespace, 
_): (String, _) = deserialize_partial(&bytes[9..])?; - - // Enforce a limit on the namespace string length - if namespace.len() > MAX_NS_LEN { - return Err(ZkasErr("Namespace too long".to_string())) - } + let (namespace, _) = deserialize_limited_partial::(&bytes[9..], MAX_NS_LEN)?; // =============== // Section parsing @@ -194,7 +227,97 @@ impl ZkBinary { }; } - Ok(Self { namespace, k, constants, literals, witnesses, opcodes, debug_info }) + let binary = Self { namespace, k, constants, literals, witnesses, opcodes, debug_info }; + + // Validate cross-references between sections + binary.validate()?; + + Ok(binary) + } + + /// Validate cross-references and consistency between sections. + /// This catches malicious binaries that pass individual section + /// parsing but have invalid references. + fn validate(&self) -> Result<()> { + // Calculate actual heap size: constants + witnesses + assigned vars + // Each opcode that produces a result adds one entry to the heap + let num_assignments = self + .opcodes + .iter() + .filter(|(op, _)| { + let (ret_types, _) = op.arg_types(); + !ret_types.is_empty() + }) + .count(); + + let heap_size = self.constants.len() + self.witnesses.len() + num_assignments; + + // Validate all heap references in opcodes + for (op_idx, (opcode, args)) in self.opcodes.iter().enumerate() { + // Calculate heap size at this point in execution + // (constants + witnesses + results from previous opcodes) + let prev_assignments = self.opcodes[..op_idx] + .iter() + .filter(|(op, _)| { + let (ret_types, _) = op.arg_types(); + !ret_types.is_empty() + }) + .count(); + let available_heap = self.constants.len() + self.witnesses.len() + prev_assignments; + + for (heap_type, heap_idx) in args { + match heap_type { + HeapType::Var => { + if *heap_idx >= available_heap { + return Err(ZkasErr(format!( + "Opcode {} references heap idx {} but only {} entries available", + opcode.name(), + heap_idx, + available_heap + ))); + } + } + HeapType::Lit => { + if *heap_idx 
>= self.literals.len() { + return Err(ZkasErr(format!( + "Opcode {} references literal idx {} but only {} literals exist", + opcode.name(), + heap_idx, + self.literals.len() + ))); + } + } + } + } + } + // Validate debug info consistency if present + if let Some(ref debug) = self.debug_info { + if debug.opcode_locations.len() != self.opcodes.len() { + return Err(ZkasErr(format!( + "Debug info has {} opcode locations but circuit has {} opcodes", + debug.opcode_locations.len(), + self.opcodes.len() + ))); + } + + if debug.heap_names.len() != heap_size { + return Err(ZkasErr(format!( + "Debug info has {} heap names but heap has {} entries", + debug.heap_names.len(), + heap_size + ))); + } + + if debug.literal_names.len() != self.literals.len() { + return Err(ZkasErr(format!( + "Debug info has {} literal names but {} literals exist", + debug.literal_names.len(), + self.literals.len() + ))); + } + } + + Ok(()) } fn parse_constants(bytes: &[u8]) -> Result> { @@ -202,12 +325,20 @@ impl ZkBinary { let mut offset = 0; while offset < bytes.len() { + // Check we haven't exceeded the limit + if constants.len() >= MAX_CONSTANTS { + return Err(ZkasErr(format!( + "Too many constants, maximum allowed is {MAX_CONSTANTS}" + ))) + } + let c_type = VarType::from_repr(bytes[offset]).ok_or_else(|| { ZkasErr(format!("Could not decode constant VarType from {}", bytes[offset])) })?; offset += 1; - let (name, len) = deserialize_partial::(&bytes[offset..])?; + let (name, len) = + deserialize_limited_partial::(&bytes[offset..], MAX_STRING_LEN)?; offset += len; constants.push((c_type, name)); @@ -221,12 +352,20 @@ impl ZkBinary { let mut offset = 0; while offset < bytes.len() { + // Check we haven't exceeded the limit + if literals.len() >= MAX_LITERALS { + return Err(ZkasErr(format!( + "Too many literals, maximum allowed is {MAX_LITERALS}" + ))); + } + let l_type = LitType::from_repr(bytes[offset]).ok_or_else(|| { ZkasErr(format!("Could not decode literal LitType from {}", bytes[offset])) 
})?; offset += 1; - let (name, len) = deserialize_partial::(&bytes[offset..])?; + let (name, len) = + deserialize_limited_partial::(&bytes[offset..], MAX_STRING_LEN)?; offset += len; literals.push((l_type, name)); @@ -236,7 +375,16 @@ impl ZkBinary { } fn parse_witnesses(bytes: &[u8]) -> Result> { - let mut witnesses = vec![]; + // Check count before allocating + if bytes.len() > MAX_WITNESSES { + return Err(ZkasErr(format!( + "Too many witnesses ({}), maximum allowed is {}", + bytes.len(), + MAX_WITNESSES + ))); + } + + let mut witnesses = Vec::with_capacity(bytes.len()); for &byte in bytes { let w_type = VarType::from_repr(byte).ok_or_else(|| { @@ -255,6 +403,11 @@ impl ZkBinary { let mut offset = 0; while offset < bytes.len() { + // Check opcode count limit + if opcodes.len() >= MAX_OPCODES { + return Err(ZkasErr(format!("Too many opcodes, maximum allowed is {MAX_OPCODES}"))) + } + let opcode = Opcode::from_repr(bytes[offset]).ok_or_else(|| { ZkasErr(format!("Could not decode Opcode from {}", bytes[offset])) })?; @@ -266,9 +419,13 @@ impl ZkBinary { let (arg_count, len) = deserialize_partial::(&bytes[offset..])?; offset += len; + // Validate argument count + let arg_count = + validate_count(arg_count.0, MAX_ARGS_PER_OPCODE, bytes.len() - offset, "Argument")?; + // Parse arguments - let mut args = vec![]; - for _ in 0..arg_count.0 { + let mut args = Vec::with_capacity(arg_count); + for _ in 0..arg_count { // Check bounds to prevent panics if offset >= bytes.len() { return Err(ZkasErr(format!( @@ -296,6 +453,15 @@ impl ZkBinary { ZkasErr(format!("Could not decode HeapType from {}", heap_type_byte)) })?; + // Validate heap index is reasonable + let heap_idx = heap_index.0 as usize; + if heap_idx > MAX_HEAP_SIZE { + return Err(ZkasErr(format!( + "Heap index {} exceeds maximum allowed {}", + heap_idx, MAX_HEAP_SIZE + ))); + } + args.push((heap_type, heap_index.0 as usize)); } @@ -312,8 +478,11 @@ impl ZkBinary { let (num_opcodes, len) =
deserialize_partial::(&bytes[offset..])?; offset += len; - let mut opcode_locations = Vec::with_capacity(num_opcodes.0 as usize); - for _ in 0..num_opcodes.0 { + let num_opcodes = + validate_count(num_opcodes.0, MAX_OPCODES, bytes.len() - offset, "Debug opcode")?; + + let mut opcode_locations = Vec::with_capacity(num_opcodes); + for _ in 0..num_opcodes { let (line, len) = deserialize_partial::(&bytes[offset..])?; offset += len; let (column, len) = deserialize_partial::(&bytes[offset..])?; @@ -325,9 +494,13 @@ impl ZkBinary { let (heap_size, len) = deserialize_partial::(&bytes[offset..])?; offset += len; - let mut heap_names = Vec::with_capacity(heap_size.0 as usize); - for _ in 0..heap_size.0 { - let (name, len) = deserialize_partial::(&bytes[offset..])?; + let heap_size = + validate_count(heap_size.0, MAX_HEAP_SIZE, bytes.len() - offset, "Debug heap")?; + + let mut heap_names = Vec::with_capacity(heap_size); + for _ in 0..heap_size { + let (name, len) = + deserialize_limited_partial::(&bytes[offset..], MAX_STRING_LEN)?; offset += len; heap_names.push(name); } @@ -336,9 +509,13 @@ impl ZkBinary { let (num_literals, len) = deserialize_partial::(&bytes[offset..])?; offset += len; - let mut literal_names = Vec::with_capacity(num_literals.0 as usize); - for _ in 0..num_literals.0 { - let (name, len) = deserialize_partial::(&bytes[offset..])?; + let num_literals = + validate_count(num_literals.0, MAX_LITERALS, bytes.len() - offset, "Debug literal")?; + + let mut literal_names = Vec::with_capacity(num_literals); + for _ in 0..num_literals { + let (name, len) = + deserialize_limited_partial::(&bytes[offset..], MAX_STRING_LEN)?; offset += len; literal_names.push(name); }