From ca3372909a22be1ec3ad7d9a300ebb21e7bc0eef Mon Sep 17 00:00:00 2001 From: LambdaClass <121504986+lambdaclass-user@users.noreply.github.com> Date: Fri, 6 Jan 2023 21:28:07 -0300 Subject: [PATCH] feat(cli): add more convenient SocketAddr argument parsing (#757) --- bin/reth/src/node/mod.rs | 3 ++- bin/reth/src/util/mod.rs | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/bin/reth/src/node/mod.rs b/bin/reth/src/node/mod.rs index fdc116671e..ff4f297ceb 100644 --- a/bin/reth/src/node/mod.rs +++ b/bin/reth/src/node/mod.rs @@ -8,6 +8,7 @@ use crate::{ util::{ chainspec::{chain_spec_value_parser, ChainSpecification}, init::{init_db, init_genesis}, + socketaddr_value_parser, }, NetworkOpts, }; @@ -65,7 +66,7 @@ pub struct Command { /// Enable Prometheus metrics. /// /// The metrics will be served at the given interface and port. - #[clap(long, value_name = "SOCKET")] + #[arg(long, value_name = "SOCKET", value_parser = socketaddr_value_parser)] metrics: Option, /// Set the chain tip manually for testing purposes. diff --git a/bin/reth/src/util/mod.rs b/bin/reth/src/util/mod.rs index 72a3847c9c..cade43f928 100644 --- a/bin/reth/src/util/mod.rs +++ b/bin/reth/src/util/mod.rs @@ -2,6 +2,7 @@ use reth_primitives::{BlockHashOrNumber, H256}; use std::{ env::VarError, + net::{SocketAddr, ToSocketAddrs}, path::{Path, PathBuf}, str::FromStr, }; @@ -37,6 +38,29 @@ pub(crate) fn hash_or_num_value_parser(value: &str) -> Result Result { + const DEFAULT_DOMAIN: &str = "localhost"; + const DEFAULT_PORT: u16 = 9000; + let value = if value.is_empty() || value == ":" { + format!("{DEFAULT_DOMAIN}:{DEFAULT_PORT}") + } else if value.starts_with(':') { + format!("{DEFAULT_DOMAIN}{value}") + } else if value.ends_with(':') { + format!("{value}{DEFAULT_PORT}") + } else if value.parse::().is_ok() { + format!("{DEFAULT_DOMAIN}:{value}") + } else if value.contains(':') { + value.to_string() + } else { + format!("{value}:{DEFAULT_PORT}") + }; + match value.to_socket_addrs() { + Ok(mut iter) => iter.next().ok_or(eyre::Error::msg(format!("\"{value}\""))), + Err(e) => Err(eyre::Error::from(e).wrap_err(format!("\"{value}\""))), + } +} + /// Tracing utility pub mod reth_tracing { use tracing::Subscriber; @@ -102,3 +126,19 @@ pub mod reth_tracing { .with(filter) } } + +#[cfg(test)] +mod tests { + use std::net::ToSocketAddrs; + + use super::socketaddr_value_parser; + + #[test] + fn parse_socketaddr_with_default() { + let expected = "localhost:9000".to_socket_addrs().unwrap().next().unwrap(); + let test_values = ["localhost:9000", ":9000", "9000", "localhost:", "localhost", ":", ""]; + for value in test_values { + assert_eq!(socketaddr_value_parser(value).expect("value_parser failed"), expected); + } + } +}