use futures-plex poll_get and poll_mut

This commit is contained in:
th4s
2025-12-17 15:52:23 +01:00
parent e11180259f
commit fd5ef24580
3 changed files with 48 additions and 56 deletions

3
Cargo.lock generated
View File

@@ -3204,9 +3204,8 @@ dependencies = [
[[package]]
name = "futures-plex"
version = "0.1.0"
source = "git+https://github.com/tlsnotary/tlsn-utils?rev=7ca0b13#7ca0b132787a31faa0e91c816d6603d40db7e526"
source = "git+https://github.com/tlsnotary/tlsn-utils?rev=c210f2f#c210f2fdd0a5d71c3e217fa03127c9f616314836"
dependencies = [
"bytes",
"futures",
]

View File

@@ -80,7 +80,7 @@ mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "7ca0b13" }
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "c210f2f" }
rangeset = { version = "0.4" }
serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }

View File

@@ -10,6 +10,7 @@ pub mod state;
pub use conn::TlsConnection;
pub use control::ProverControl;
pub use error::ProverError;
use futures_plex::DuplexStream;
pub use tlsn_core::ProverOutput;
use crate::{
@@ -327,13 +328,13 @@ where
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = Pin::new(&mut self.state).project();
let mut tmp_buf = [0_u8; BUF_CAP];
let (mut duplex_1, mut duplex_2) = futures_plex::duplex(BUF_CAP);
loop {
let mut progress = false;
progress |= Self::io_client_conn(&mut state, cx)?;
progress |= Self::io_client_server(&mut state, cx, &mut tmp_buf)?;
progress |= Self::io_client_server(&mut state, cx, &mut duplex_1, &mut duplex_2)?;
progress |= Self::io_client_verifier(&mut state, cx)?;
_ = state.mux_fut.poll_unpin(cx)?;
@@ -363,16 +364,13 @@ where
) -> Result<bool, ProverError> {
let mut progress = false;
if let Poll::Ready(read) = state.client_io.as_mut().poll_read_flex(cx, |_, buf| {
Poll::Ready(
state
.tls_client
.write(buf)
.map_err(|err| std::io::Error::other(err)),
)
})? {
if read > 0 {
// tls_conn -> tls_client
if let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_read(cx)
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
{
if buf.len() > 0 {
progress = true;
state.tls_client.write(buf)?;
} else if !*state.client_closed && !*state.server_closed {
progress = true;
*state.client_closed = true;
@@ -380,20 +378,13 @@ where
}
}
if let Poll::Ready(write) = state.client_io.as_mut().poll_write_flex(cx, |_, buf| {
Poll::Ready(
state
.tls_client
.read(buf)
.map_err(|err| std::io::Error::other(err)),
)
})? {
if write > 0 {
progress = true;
} else if !*state.client_closed
&& let Poll::Ready(()) = state.client_io.as_mut().poll_close(cx)?
{
// tls_client -> tls_conn
if let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_write(cx)
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
{
if buf.len() > 0 {
progress = true;
state.tls_client.read(buf)?;
}
}
@@ -403,14 +394,29 @@ where
fn io_client_server(
state: &mut ConnectedProj<S, T>,
cx: &mut Context,
tmp_buf: &mut [u8],
duplex_1: &mut DuplexStream,
duplex_2: &mut DuplexStream,
) -> Result<bool, ProverError> {
let mut progress = false;
let mut duplex_1 = Pin::new(duplex_1);
let mut duplex_2 = Pin::new(duplex_2);
if let Poll::Ready(read) = state.server_socket.as_mut().poll_read(cx, tmp_buf)? {
if read > 0 {
// server_socket -> duplex
if let Poll::Ready(write) = duplex_1.poll_write_from(cx, state.server_socket.as_mut())? {
if write > 0 {
progress = true;
state.server_to_client_buffer.extend(&tmp_buf[..read]);
} else if let Poll::Ready(()) = duplex_1.as_mut().poll_close(cx)? {
progress = true;
}
}
// duplex -> tls_client
if let Poll::Ready(mut simplex) = duplex_1.as_mut().poll_lock_read(cx)
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
{
if buf.len() > 0 {
progress = true;
state.tls_client.read_tls(buf)?;
} else if !*state.server_closed {
progress = true;
*state.server_closed = true;
@@ -418,38 +424,23 @@ where
}
}
if state.server_to_client_buffer.len() > 0 && state.tls_client.wants_read_tls() {
progress = true;
let write = state.tls_client.read_tls(&state.server_to_client_buffer)?;
state.server_to_client_buffer.drain(..write);
// tls_client -> duplex
if let Poll::Ready(mut simplex) = duplex_2.as_mut().poll_lock_write(cx)
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
{
if buf.len() > 0 {
progress = true;
state.tls_client.write_tls(buf)?;
}
}
if state.tls_client.wants_write_tls() {
let read = state.tls_client.write_tls(tmp_buf)?;
// duplex -> server_socket
if let Poll::Ready(read) = duplex_2.poll_read_to(cx, state.server_socket.as_mut())? {
if read > 0 {
progress = true;
state
.client_to_server_buffer
.extend_from_slice(&tmp_buf[..read]);
}
}
if *state.server_closed && !*state.client_closed {
progress = true;
if let Poll::Ready(()) = state.client_io.as_mut().poll_close(cx)? {
*state.client_closed = true;
}
}
if state.client_to_server_buffer.len() > 0
&& let Poll::Ready(write) = state
.server_socket
.as_mut()
.poll_write(cx, &state.client_to_server_buffer)?
&& write > 0
{
progress = true;
state.client_to_server_buffer.drain(..write);
}
Ok(progress)
}
@@ -465,12 +456,14 @@ where
.expect("verifier io should be available"),
);
// mux -> verifier_socket
if let Poll::Ready(read) = verifier_io.poll_read_to(cx, state.verifier_socket.as_mut())? {
if read > 0 {
progress = true;
}
}
// verifier_socket -> mux
if let Poll::Ready(write) =
verifier_io.poll_write_from(cx, state.verifier_socket.as_mut())?
{