diff --git a/src/server.rs b/src/server.rs index 778772c..57d2ea2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -366,8 +366,22 @@ async fn handle_client( // --- TCP Test Server --- +/// Run a TCP bandwidth test on an already-authenticated stream. +/// Public API for use by server_pro. +pub async fn run_tcp_test( + stream: TcpStream, + cmd: Command, + state: Arc, +) -> Result<(u64, u64, u64, u32)> { + run_tcp_test_inner(stream, cmd, state).await +} + async fn run_tcp_test_server(stream: TcpStream, cmd: Command) -> Result<(u64, u64, u64, u32)> { let state = BandwidthState::new(); + run_tcp_test_inner(stream, cmd, state).await +} + +async fn run_tcp_test_inner(stream: TcpStream, cmd: Command, state: Arc) -> Result<(u64, u64, u64, u32)> { let tx_size = cmd.tx_size as usize; let server_should_tx = cmd.server_tx(); let server_should_rx = cmd.server_rx(); @@ -633,6 +647,18 @@ async fn tcp_status_sender( // --- UDP Test Server --- +/// Run a UDP bandwidth test on an already-authenticated stream. +/// Public API for use by server_pro. Caller provides the UDP port offset. +pub async fn run_udp_test( + stream: &mut TcpStream, + peer: SocketAddr, + cmd: &Command, + state: Arc, + udp_port_start: u16, +) -> Result<(u64, u64, u64, u32)> { + run_udp_test_inner(stream, peer, cmd, state, udp_port_start).await +} + async fn run_udp_test_server( stream: &mut TcpStream, peer: SocketAddr, @@ -640,7 +666,17 @@ async fn run_udp_test_server( udp_port_offset: Arc, ) -> Result<(u64, u64, u64, u32)> { let offset = udp_port_offset.fetch_add(1, Ordering::SeqCst); - let server_udp_port = BTEST_UDP_PORT_START + offset; + let state = BandwidthState::new(); + run_udp_test_inner(stream, peer, cmd, state, BTEST_UDP_PORT_START + offset).await +} + +async fn run_udp_test_inner( + stream: &mut TcpStream, + peer: SocketAddr, + cmd: &Command, + state: Arc, + server_udp_port: u16, +) -> Result<(u64, u64, u64, u32)> { let client_udp_port = server_udp_port + BTEST_PORT_CLIENT_OFFSET; stream.write_all(&server_udp_port.to_be_bytes()).await?; @@ -707,7 +743,6 @@ async fn run_udp_test_server( if use_unconnected { "unconnected" } else { "connected" }, ); - let state = BandwidthState::new(); let tx_size = cmd.tx_size as usize; let server_should_tx = cmd.server_tx(); let server_should_rx = cmd.server_rx(); diff --git a/src/server_pro/server_loop.rs b/src/server_pro/server_loop.rs index 0f4605d..3823b83 100644 --- a/src/server_pro/server_loop.rs +++ b/src/server_pro/server_loop.rs @@ -232,24 +232,49 @@ async fn handle_pro_client( quota_mgr.max_duration(), ); + // Spawn quota enforcer — runs alongside the test + let enforcer_state = state.clone(); let enforcer_handle = tokio::spawn(async move { enforcer.run().await }); - // Run the actual bandwidth test using the standard server - // For now, delegate to the standard TCP/UDP handlers - // by using the existing btest_rs::server internals. - // The state's `running` flag will be set to false by the enforcer - // when quota is exceeded, which will stop the TX/RX loops. + // Run the actual bandwidth test using the standard server handlers. + // The enforcer runs concurrently and will set state.running = false + // if any quota is exceeded, which gracefully stops the TX/RX loops. + static UDP_PORT_OFFSET: std::sync::atomic::AtomicU16 = std::sync::atomic::AtomicU16::new(0); - // TODO: Integrate more deeply with btest_rs::server to pass the shared state - // For now, we simulate by waiting for the enforcer to finish + let test_result = if cmd.is_udp() { + let offset = UDP_PORT_OFFSET.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let udp_port = btest_rs::protocol::BTEST_UDP_PORT_START + offset; + btest_rs::server::run_udp_test( + &mut stream, peer, &cmd, state.clone(), udp_port, + ).await + } else { + btest_rs::server::run_tcp_test(stream, cmd.clone(), state.clone()).await + }; + + // Test finished — stop the enforcer if still running + enforcer_state.running.store(false, std::sync::atomic::Ordering::SeqCst); let stop_reason = enforcer_handle.await.unwrap_or(StopReason::ClientDisconnected); + // Determine final stop reason + let final_reason = match &test_result { + Ok(_) => { + if stop_reason == StopReason::ClientDisconnected { + StopReason::ClientDisconnected + } else { + stop_reason // quota or duration exceeded + } + } + Err(_) => StopReason::ClientDisconnected, + }; + // Record final usage let (total_tx, total_rx, _, _) = state.summary(); + + // Flush to DB quota_mgr.record_usage(&username, &peer.ip().to_string(), total_tx, total_rx); db.end_session(session_id, total_tx, total_rx)?; - Ok((username, stop_reason, total_tx, total_rx)) + Ok((username, final_reason, total_tx, total_rx)) }