Print macro and panic handler safety.

This commit is contained in:
Lucas Clemente Vella
2023-05-19 14:36:00 +01:00
parent b4ff289d8a
commit 1ce602d4af
5 changed files with 65 additions and 92 deletions

40
riscv/runtime/src/fmt.rs Normal file
View File

@@ -0,0 +1,40 @@
use core::arch::asm;
use core::fmt;
#[macro_export]
macro_rules! print {
($($arg:tt)+) => {{
$crate::fmt::print_args(format_args!( $($arg)+));
}};
}
pub fn print_args(args: fmt::Arguments) {
fmt::write(&mut ProverWriter {}, args).unwrap();
}
struct ProverWriter {}
impl fmt::Write for ProverWriter {
fn write_str(&mut self, s: &str) -> fmt::Result {
print_str(s);
Ok(())
}
}
pub fn print_str(s: &str) {
// DEAR DEV, please don't allow this function to panic.
//
// This is called from the panic handler.
for b in s.bytes() {
print_prover_char(b)
}
}
#[inline]
fn print_prover_char(c: u8) {
let mut value = c as u32;
#[allow(unused_assignments)]
unsafe {
asm!("ebreak", lateout("a0") value, in("a0") value);
}
}

View File

@@ -1,19 +1,32 @@
#![no_std]
#![feature(start, alloc_error_handler)]
#![feature(
start,
alloc_error_handler,
maybe_uninit_write_slice,
round_char_boundary
)]
use core::arch::asm;
use core::panic::PanicInfo;
use crate::fmt::print_str;
mod allocator;
mod print;
pub use print::print;
pub mod fmt;
#[panic_handler]
fn panic(panic: &PanicInfo<'_>) -> ! {
print(format_args!("{panic}"));
unsafe {
asm!("unimp");
unsafe fn panic(panic: &PanicInfo<'_>) -> ! {
static mut IS_PANICKING: bool = false;
if !IS_PANICKING {
IS_PANICKING = true;
print!("{panic}\n");
} else {
print_str("Panic handler has panicked! Things are very dire indeed...\n");
}
asm!("unimp");
loop {}
}

View File

@@ -1,81 +0,0 @@
use core::arch::asm;
use core::fmt;
use core::mem::MaybeUninit;
// #[macro_export]
// macro_rules! print {
// ($($arg:tt)+) => (print_args(format_args!( $($arg)+)))
// }
// TODO turn this into a macro
pub fn print(args: fmt::Arguments) {
const BUF_SIZE: usize = 1024;
let mut buf = unsafe { MaybeUninit::<[MaybeUninit<u8>; BUF_SIZE]>::uninit().assume_init() };
let _s: &str = buf_formatter::format(&mut buf, args).unwrap();
print_prover(_s);
}
#[inline]
fn print_prover(s: &str) {
for b in s.bytes() {
print_prover_char(b)
}
}
#[inline]
fn print_prover_char(c: u8) {
let mut value = c as u32;
#[allow(unused_assignments)]
unsafe {
asm!("ebreak", lateout("a0") value, in("a0") value);
}
}
mod buf_formatter {
use core::fmt;
use core::mem::MaybeUninit;
pub struct BufFormatter<'a> {
buffer: &'a mut [MaybeUninit<u8>],
used: usize,
}
impl<'a> BufFormatter<'a> {
pub fn new(buffer: &'a mut [MaybeUninit<u8>]) -> Self {
Self { buffer, used: 0 }
}
pub fn as_str(&self) -> &'a str {
unsafe {
// This is safe because everything until self.used has been initialized.
let buf =
&*(&self.buffer[..self.used] as *const [MaybeUninit<u8>] as *const [u8]);
// we only concatenate str, so the result must be valid utf8 as well.
core::str::from_utf8_unchecked(buf)
}
}
}
impl<'a> fmt::Write for BufFormatter<'a> {
fn write_str(&mut self, s: &str) -> fmt::Result {
let raw_s = s.as_bytes();
let write_len = raw_s.len();
if write_len > self.buffer.len() - self.used {
return Err(fmt::Error);
}
let s_uninit: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(&raw_s[..write_len]) };
self.buffer[self.used..][..write_len].copy_from_slice(s_uninit);
self.used += write_len;
Ok(())
}
}
pub fn format<'a>(
buffer: &'a mut [MaybeUninit<u8>],
args: fmt::Arguments,
) -> Result<&'a str, fmt::Error> {
let mut w = BufFormatter::new(buffer);
fmt::write(&mut w, args)?;
Ok(w.as_str())
}
}

View File

@@ -164,9 +164,9 @@ runtime = {{ path = "./runtime" }}
)
.unwrap();
let mut print_file = runtime_file.clone();
print_file.push("print.rs");
fs::write(print_file, include_bytes!("../runtime/src/print.rs")).unwrap();
let mut fmt_file = runtime_file.clone();
fmt_file.push("fmt.rs");
fs::write(fmt_file, include_bytes!("../runtime/src/fmt.rs")).unwrap();
compile_rust_crate_to_riscv_asm(cargo_file.to_str().unwrap())
}

View File

@@ -16,7 +16,7 @@
extern crate alloc;
use alloc::vec::Vec;
use runtime::get_prover_input;
use runtime::{get_prover_input, print};
#[no_mangle]
fn main() {
@@ -33,5 +33,6 @@ fn main() {
(vec[half - 1] + vec[half]) / 2
};
print!("Found median of {median}\n");
assert_eq!(median, expected);
}