diff --git a/src/client.rs b/src/client.rs index d960278..1fe21dd 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,7 +10,6 @@ use crate::auth; use crate::bandwidth::{self, BandwidthState}; use crate::protocol::*; -/// Returns (total_tx_bytes, total_rx_bytes, total_lost_packets, duration_secs). pub async fn run_client( host: &str, port: u16, @@ -21,7 +20,8 @@ pub async fn run_client( auth_user: Option, auth_pass: Option, nat_mode: bool, -) -> Result<(u64, u64, u64, u32)> { + shared_state: Arc, +) -> Result<()> { let addr = format!("{}:{}", host, port); tracing::info!("Connecting to {}...", addr); let mut stream = TcpStream::connect(&addr).await?; @@ -91,16 +91,15 @@ pub async fn run_client( ); if use_udp { - run_udp_test_client(&mut stream, host, &cmd, nat_mode).await + run_udp_test_client(&mut stream, host, &cmd, nat_mode, shared_state).await } else { - run_tcp_test_client(stream, cmd).await + run_tcp_test_client(stream, cmd, shared_state).await } } // --- TCP Test Client --- -async fn run_tcp_test_client(stream: TcpStream, cmd: Command) -> Result<(u64, u64, u64, u32)> { - let state = BandwidthState::new(); +async fn run_tcp_test_client(stream: TcpStream, cmd: Command, state: Arc) -> Result<()> { let tx_size = cmd.tx_size as usize; let client_should_tx = cmd.client_tx(); let client_should_rx = cmd.client_rx(); @@ -138,7 +137,7 @@ async fn run_tcp_test_client(stream: TcpStream, cmd: Command) -> Result<(u64, u6 state.running.store(false, Ordering::SeqCst); if let Some(h) = tx_handle { let _ = h.await; } if let Some(h) = rx_handle { let _ = h.await; } - Ok(state.summary()) + Ok(()) } async fn tcp_client_tx_loop( @@ -203,7 +202,8 @@ async fn run_udp_test_client( host: &str, cmd: &Command, nat_mode: bool, -) -> Result<(u64, u64, u64, u32)> { + state: Arc, +) -> Result<()> { let mut port_buf = [0u8; 2]; stream.read_exact(&mut port_buf).await?; let server_udp_port = u16::from_be_bytes(port_buf); @@ -234,7 +234,6 @@ async fn run_udp_test_client( udp.send(&[]).await?; } - let state = BandwidthState::new(); let tx_size = cmd.tx_size as usize; let client_should_tx = cmd.client_tx(); let client_should_rx = cmd.client_rx(); @@ -266,7 +265,7 @@ async fn run_udp_test_client( state.running.store(false, Ordering::SeqCst); if let Some(h) = tx_handle { let _ = h.await; } if let Some(h) = rx_handle { let _ = h.await; } - Ok(state.summary()) + Ok(()) } async fn udp_client_tx_loop( diff --git a/src/main.rs b/src/main.rs index f471d21..8ebae03 100644 --- a/src/main.rs +++ b/src/main.rs @@ -172,6 +172,9 @@ async fn main() -> anyhow::Result<()> { }; let proto_str = if cli.udp { "UDP" } else { "TCP" }; + // Create shared state that survives timeout cancellation + let shared_state = bandwidth::BandwidthState::new(); + // Log test start syslog_logger::test_start(&host, proto_str, dir_str, 0); @@ -187,24 +190,28 @@ async fn main() -> anyhow::Result<()> { cli.auth_user.clone(), cli.auth_pass.clone(), cli.nat, + shared_state.clone(), ); - let stats = if cli.duration > 0 { + if cli.duration > 0 { match tokio::time::timeout( std::time::Duration::from_secs(cli.duration), client_fut, ) .await { - Ok(result) => Some(result?), - Err(_) => None, // Timeout — stats not available from aborted future + Ok(result) => { let _ = result?; }, + Err(_) => { + // Timeout — signal stop + shared_state.running.store(false, std::sync::atomic::Ordering::SeqCst); + } } } else { - Some(client_fut.await?) - }; + let _ = client_fut.await?; + } let elapsed = start.elapsed().as_secs(); - let (total_tx, total_rx, total_lost, _intervals) = stats.unwrap_or((0, 0, 0, 0)); + let (total_tx, total_rx, total_lost, _intervals) = shared_state.summary(); // Log test end to syslog syslog_logger::test_end( diff --git a/tests/ecsrp5_test.rs b/tests/ecsrp5_test.rs index 21efa03..8296cbd 100644 --- a/tests/ecsrp5_test.rs +++ b/tests/ecsrp5_test.rs @@ -85,6 +85,7 @@ async fn test_ecsrp5_full_client_auth() { Some("testuser".into()), Some("testpass".into()), false, + btest_rs::bandwidth::BandwidthState::new(), ) .await }); @@ -109,6 +110,7 @@ async fn test_ecsrp5_wrong_password_fails() { Some("testuser".into()), Some("wrongpass".into()), false, + btest_rs::bandwidth::BandwidthState::new(), ) .await; @@ -131,6 +133,7 @@ async fn test_md5_auth_still_works() { Some("testuser".into()), Some("testpass".into()), false, + btest_rs::bandwidth::BandwidthState::new(), ) .await }); @@ -155,6 +158,7 @@ async fn test_noauth_still_works() { None, None, false, + btest_rs::bandwidth::BandwidthState::new(), ) .await }); @@ -179,6 +183,7 @@ async fn test_ecsrp5_udp_bidirectional() { Some("testuser".into()), Some("testpass".into()), false, + btest_rs::bandwidth::BandwidthState::new(), ) .await }); diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 5c397f2..cdc9834 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -153,6 +153,7 @@ async fn test_loopback_tcp_rx() { None, None, false, + btest_rs::bandwidth::BandwidthState::new(), ) .await }); @@ -177,6 +178,7 @@ async fn test_loopback_tcp_tx() { None, None, false, + btest_rs::bandwidth::BandwidthState::new(), ) .await }); @@ -201,6 +203,7 @@ async fn test_loopback_tcp_both() { None, None, false, + btest_rs::bandwidth::BandwidthState::new(), ) .await }); @@ -225,6 +228,7 @@ async fn test_loopback_tcp_with_auth() { Some("admin".into()), Some("secret".into()), false, + btest_rs::bandwidth::BandwidthState::new(), ) .await });