Make rv64 label detection more idiomatic (#3072)

Use iterators instead of while loop and globa state.
This commit is contained in:
Thibaut Schaeffer
2025-07-22 13:54:22 +02:00
committed by GitHub
parent 4446c7298f
commit d308ae166e

View File

@@ -123,16 +123,31 @@ fn scan_for_jump_targets(
jumpdests_with_debug_info: &mut BTreeMap<u64, Vec<JumpDest>>,
label_addrs: &BTreeSet<u64>,
) {
let mut addr = base_addr;
let mut remaining = data;
let mut last_was_auipc = false;
data.chunks(4)
// Cast to [u8; 4]
.map(|data| data.try_into().unwrap())
.inspect(|data: &[u8; 4]| {
assert!(data[0] & 0b11 == 0b11, "Expected 32-bit instruction");
})
.map(u32::from_le_bytes)
// Decode the instruction bytes
.map(|insn_bytes| {
insn_bytes
.decode(Isa::Rv64)
.expect("Failed to decode instruction")
})
// Remember the `rs1` and `imm` of the previous instruction if it was AUIPC, used to propagate it to the next JALR
.scan(None, |previous_if_auipc, insn| {
let previous_auipc_rs1 = std::mem::replace(
previous_if_auipc,
matches!(insn.opc, Op::AUIPC).then_some((insn.rs1, insn.imm)),
);
Some((insn, previous_auipc_rs1))
})
.enumerate()
.for_each(|(instruction_index, (insn, previous_if_auipc))| {
let addr = base_addr + (instruction_index * 4) as u64;
while remaining.len() >= 4 {
// Assert that we have a 32-bit instruction.
assert!(remaining[0] & 0b11 == 0b11);
let insn_bytes = u32::from_le_bytes(remaining[0..4].try_into().unwrap());
if let Ok(insn) = insn_bytes.decode(Isa::Rv64) {
// Check for jump/branch instructions
match insn.opc {
Op::JAL => {
@@ -189,40 +204,30 @@ fn scan_for_jump_targets(
}
}
}
Op::AUIPC => {
// AUIPC is often followed by JALR for function calls and long jumps
// In statically linked binaries, these usually target known symbols
if remaining.len() >= 8 {
let next_insn_bytes =
u32::from_le_bytes(remaining[4..8].try_into().unwrap());
if let Ok(next_insn) = next_insn_bytes.decode(Isa::Rv64) {
if matches!(next_insn.opc, Op::JALR) && insn.rd == next_insn.rs1 {
// This is an AUIPC+JALR pair
if let (Some(auipc_imm), Some(jalr_imm)) = (insn.imm, next_insn.imm)
{
let target =
(addr as i64 + auipc_imm as i64 + jalr_imm as i64) as u64;
jumpdests.insert(target);
Op::JALR => {
if let Some((rs1, imm)) = previous_if_auipc {
// JALR with a preceding AUIPC
if insn.rd == rs1 {
// This is an AUIPC+JALR pair, we can resolve it statically
if let (Some(auipc_imm), Some(jalr_imm)) = (imm, insn.imm) {
let target =
(addr as i64 + auipc_imm as i64 + jalr_imm as i64) as u64;
jumpdests.insert(target);
// Track non-symbol jumpdests
if !label_addrs.contains(&target) {
let jump_info = JumpDest {
from_addr: addr,
instruction: format!("auipc+jalr -> 0x{target:x}"),
};
jumpdests_with_debug_info
.entry(target)
.or_default()
.push(jump_info);
}
// Track non-symbol jumpdests
if !label_addrs.contains(&target) {
let jump_info = JumpDest {
from_addr: addr,
instruction: format!("auipc+jalr -> 0x{target:x}"),
};
jumpdests_with_debug_info
.entry(target)
.or_default()
.push(jump_info);
}
}
}
}
}
Op::JALR => {
// Only process if this JALR is not part of an AUIPC+JALR pair
if !last_was_auipc {
} else {
// Standalone JALR without preceding AUIPC
// These are dynamic jumps we can't resolve statically:
// - Return instructions (jalr x0, x1, 0)
@@ -249,14 +254,6 @@ fn scan_for_jump_targets(
}
}
_ => {}
}
// Update for next iteration
last_was_auipc = matches!(insn.opc, Op::AUIPC);
} else {
panic!("Could not decode instruction")
}
addr += 4;
remaining = &remaining[4..];
}
};
});
}