diff --git a/Cargo.toml b/Cargo.toml index 3aaff9d..05b4380 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,8 @@ num-traits = "0.2.19" num-integer = "0.1.46" sha2 = "0.11.0" hostname = "0.4.2" +chrono = "0.4" +memchr = "2" rusqlite = { version = "0.39.0", features = ["bundled"], optional = true } ldap3 = { version = "0.12.1", optional = true } axum = { version = "0.8.8", features = ["tokio"], optional = true } diff --git a/src/csv_output.rs b/src/csv_output.rs index e1993b7..036650b 100644 --- a/src/csv_output.rs +++ b/src/csv_output.rs @@ -9,7 +9,7 @@ use std::path::Path; use std::sync::Mutex; use std::time::SystemTime; -static CSV_FILE: Mutex> = Mutex::new(None); +static CSV_FILE: Mutex> = Mutex::new(None); static QUIET: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); const HEADER: &str = "timestamp,host,port,protocol,direction,duration_s,tx_avg_mbps,rx_avg_mbps,tx_bytes,rx_bytes,lost_packets,local_cpu_pct,remote_cpu_pct,auth_type"; @@ -18,12 +18,12 @@ const HEADER: &str = "timestamp,host,port,protocol,direction,duration_s,tx_avg_m pub fn init(path: &str) -> std::io::Result<()> { let needs_header = !Path::new(path).exists() || std::fs::metadata(path)?.len() == 0; + let mut f = OpenOptions::new().create(true).append(true).open(path)?; if needs_header { - let mut f = OpenOptions::new().create(true).write(true).open(path)?; writeln!(f, "{}", HEADER)?; } - *CSV_FILE.lock().unwrap() = Some(path.to_string()); + *CSV_FILE.lock().unwrap() = Some((path.to_string(), f)); Ok(()) } @@ -49,8 +49,8 @@ pub fn write_result( remote_cpu: u8, auth_type: &str, ) { - let guard = CSV_FILE.lock().unwrap(); - if let Some(ref path) = *guard { + let mut guard = CSV_FILE.lock().unwrap(); + if let Some((ref _path, ref mut file)) = *guard { let tx_mbps = if duration_secs > 0 { tx_bytes as f64 * 8.0 / duration_secs as f64 / 1_000_000.0 } else { @@ -74,9 +74,8 @@ pub fn write_result( local_cpu, remote_cpu, auth_type, ); - if let Ok(mut f) = OpenOptions::new().append(true).open(path) { - let _ = writeln!(f, "{}", row); - } + let _ = writeln!(file, "{}", row); + let _ = file.flush(); } } diff --git a/src/ecsrp5.rs b/src/ecsrp5.rs index badc7d1..a3f19d6 100644 --- a/src/ecsrp5.rs +++ b/src/ecsrp5.rs @@ -42,6 +42,8 @@ static WEIERSTRASS_A: LazyLock = LazyLock::new(|| { .unwrap() }); +static WCURVE: LazyLock = LazyLock::new(WCurve::new); + const MONT_A: u64 = 486662; // --- Modular arithmetic --- @@ -360,7 +362,7 @@ pub async fn client_authenticate( password: &str, ) -> Result<()> { tracing::info!("Starting EC-SRP5 authentication"); - let w = WCurve::new(); + let w = &*WCURVE; // Generate client ephemeral keypair let s_a: [u8; 32] = rand::random(); @@ -477,7 +479,7 @@ impl EcSrp5Credentials { /// Derive EC-SRP5 credentials from username/password (done once at startup). pub fn derive(username: &str, password: &str) -> Self { let salt: [u8; 16] = rand::random(); - let w = WCurve::new(); + let w = &*WCURVE; let i = w.gen_password_validator_priv(username, password, &salt); let (x_gamma, parity) = w.gen_public_key(&i); Self { @@ -496,7 +498,7 @@ pub async fn server_authenticate( creds: &EcSrp5Credentials, ) -> Result<()> { tracing::info!("Starting EC-SRP5 server authentication"); - let w = WCurve::new(); + let w = &*WCURVE; // MSG1: read [len][username\0][pubkey:32][parity:1] let mut len_buf = [0u8; 1]; @@ -599,7 +601,12 @@ pub async fn server_authenticate( mod hex { pub fn encode(data: &[u8]) -> String { - data.iter().map(|b| format!("{:02x}", b)).collect() + let mut s = String::with_capacity(data.len() * 2); + for b in data { + use std::fmt::Write; + let _ = write!(s, "{:02x}", b); + } + s } } diff --git a/src/server.rs b/src/server.rs index b550551..70a7ec5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,6 +4,8 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::{Duration, Instant}; +use tokio::sync::Notify; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::sync::Mutex; @@ -18,6 +20,7 @@ struct TcpSession { peer_ip: std::net::IpAddr, streams: Vec, expected: u8, + notify: Arc, } type SessionMap = Arc>>; @@ -169,6 +172,7 @@ async fn handle_client( stream.flush().await?; session.streams.push(stream); + session.notify.notify_one(); tracing::info!( "Secondary connection joined ({}/{})", session.streams.len() + 1, @@ -249,6 +253,7 @@ async fn handle_client( for (_t, s) in map.iter_mut() { if s.peer_ip == peer.ip() && s.streams.len() < s.expected as usize { s.streams.push(stream); + s.notify.notify_one(); return Ok(()); } } @@ -299,12 +304,14 @@ async fn handle_client( let conn_count = cmd.tcp_conn_count; // Register session for secondary connections to find + let notify = Arc::new(Notify::new()); { let mut map = sessions.lock().await; map.insert(session_token, TcpSession { peer_ip: peer.ip(), streams: Vec::new(), expected: conn_count, + notify: notify.clone(), }); } @@ -320,7 +327,8 @@ async fn handle_client( if count + 1 >= conn_count as usize { break; } - if Instant::now() > deadline { + let now = Instant::now(); + if now >= deadline { tracing::warn!( "Timeout waiting for TCP connections ({}/{}), proceeding", count + 1, @@ -328,7 +336,17 @@ async fn handle_client( ); break; } - tokio::time::sleep(Duration::from_millis(100)).await; + match tokio::time::timeout(deadline - now, notify.notified()).await { + Ok(()) => continue, + Err(_) => { + tracing::warn!( + "Timeout waiting for TCP connections ({}/{}), proceeding", + count + 1, + conn_count, + ); + break; + } + } } let extra_streams = { @@ -589,8 +607,10 @@ async fn tcp_tx_loop_inner( let mut status_seq: u32 = 0; while state.running.load(Ordering::Relaxed) { + let now = Instant::now(); + // Inject status message every ~1 second if in bidirectional mode - if send_status && Instant::now() >= next_status { + if send_status && now >= next_status { status_seq += 1; let rx_bytes = state.rx_bytes.swap(0, Ordering::Relaxed); let status = StatusMessage { cpu_load: crate::cpu::get(), @@ -603,7 +623,7 @@ async fn tcp_tx_loop_inner( } state.record_interval(0, rx_bytes, 0); bandwidth::print_status(status_seq, "RX", rx_bytes, Duration::from_secs(1), None); - next_status = Instant::now() + Duration::from_secs(1); + next_status = now + Duration::from_secs(1); } if !state.spend_budget(effective_size as u64) { @@ -619,12 +639,11 @@ async fn tcp_tx_loop_inner( state.tx_speed_changed.store(false, Ordering::Relaxed); let new_speed = state.tx_speed.load(Ordering::Relaxed); interval = bandwidth::calc_send_interval(new_speed, tx_size as u16); - next_send = Instant::now(); + next_send = now; } match interval { Some(iv) => { - let now = Instant::now(); if let Some(delay) = bandwidth::advance_next_send(&mut next_send, iv, now) { tokio::time::sleep(delay).await; } @@ -918,36 +937,43 @@ async fn udp_tx_loop( async fn udp_rx_loop(socket: &UdpSocket, state: Arc) { let mut buf = vec![0u8; 65536]; let mut last_seq: Option = None; + let mut timeout = tokio::time::sleep(Duration::from_secs(5)); + tokio::pin!(timeout); while state.running.load(Ordering::Relaxed) { - // Use recv_from to accept packets from any source port - // (multi-connection MikroTik sends from multiple ports) - match tokio::time::timeout(Duration::from_secs(5), socket.recv_from(&mut buf)).await { - Ok(Ok((n, _src))) if n >= 4 => { - if !state.spend_budget(n as u64) { - break; - } - state.rx_bytes.fetch_add(n as u64, Ordering::Relaxed); - state.rx_packets.fetch_add(1, Ordering::Relaxed); + tokio::select! { + biased; + res = socket.recv_from(&mut buf) => { + match res { + Ok((n, _src)) if n >= 4 => { + if !state.spend_budget(n as u64) { + return; + } + state.rx_bytes.fetch_add(n as u64, Ordering::Relaxed); + state.rx_packets.fetch_add(1, Ordering::Relaxed); - let seq = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]); - if let Some(last) = last_seq { - let expected = last.wrapping_add(1); - if seq > expected { - let lost = seq - expected; - state.rx_lost_packets.fetch_add(lost as u64, Ordering::Relaxed); + let seq = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]); + if let Some(last) = last_seq { + let expected = last.wrapping_add(1); + if seq > expected { + let lost = seq - expected; + state.rx_lost_packets.fetch_add(lost as u64, Ordering::Relaxed); + } + } + last_seq = Some(seq); + state.last_udp_seq.store(seq, Ordering::Relaxed); + } + Ok(_) => {} + Err(e) => { + tracing::debug!("UDP recv error: {}", e); + tokio::time::sleep(Duration::from_millis(10)).await; } } - last_seq = Some(seq); - state.last_udp_seq.store(seq, Ordering::Relaxed); + timeout.as_mut().reset(tokio::time::Instant::now() + Duration::from_secs(5)); } - Ok(Ok(_)) => {} - Ok(Err(e)) => { - tracing::debug!("UDP recv error: {}", e); - tokio::time::sleep(Duration::from_millis(10)).await; - } - Err(_) => { + _ = &mut timeout => { tracing::debug!("UDP RX timeout"); + timeout.as_mut().reset(tokio::time::Instant::now() + Duration::from_secs(5)); } } } diff --git a/src/server_pro/enforcer.rs b/src/server_pro/enforcer.rs index 2b05847..a2a7b93 100644 --- a/src/server_pro/enforcer.rs +++ b/src/server_pro/enforcer.rs @@ -10,7 +10,7 @@ use std::time::{Duration, Instant}; use btest_rs::bandwidth::BandwidthState; -use super::quota::{Direction, QuotaManager}; +use super::quota::{Direction, QuotaError, QuotaManager}; /// Enforces quotas during an active test session. /// Call `run()` as a spawned task — it will set `state.running = false` @@ -154,10 +154,10 @@ impl QuotaEnforcer { // The DB has usage from PREVIOUS sessions; we add current session bytes if let Err(e) = self.quota_mgr.check_user(&self.username) { // Already exceeded from previous sessions - return match format!("{}", e).as_str() { - s if s.contains("daily") => StopReason::UserDailyQuota, - s if s.contains("weekly") => StopReason::UserWeeklyQuota, - s if s.contains("monthly") => StopReason::UserMonthlyQuota, + return match e { + QuotaError::DailyExceeded { .. } => StopReason::UserDailyQuota, + QuotaError::WeeklyExceeded { .. } => StopReason::UserWeeklyQuota, + QuotaError::MonthlyExceeded { .. } => StopReason::UserMonthlyQuota, _ => StopReason::UserDailyQuota, }; } @@ -169,13 +169,13 @@ impl QuotaEnforcer { StopReason::Running } - fn check_ip_with_session(&self, ip_str: &str, session_tx: u64, session_rx: u64) -> StopReason { + fn check_ip_with_session(&self, _ip_str: &str, _session_tx: u64, _session_rx: u64) -> StopReason { if let Err(e) = self.quota_mgr.check_ip(&self.ip, Direction::Both) { - return match format!("{}", e).as_str() { - s if s.contains("IP daily") => StopReason::IpDailyQuota, - s if s.contains("IP weekly") => StopReason::IpWeeklyQuota, - s if s.contains("IP monthly") => StopReason::IpMonthlyQuota, - s if s.contains("connections") => StopReason::IpDailyQuota, // reuse + return match e { + QuotaError::IpDailyExceeded { .. } | QuotaError::IpInboundDailyExceeded { .. } | QuotaError::IpOutboundDailyExceeded { .. } => StopReason::IpDailyQuota, + QuotaError::IpWeeklyExceeded { .. } | QuotaError::IpInboundWeeklyExceeded { .. } | QuotaError::IpOutboundWeeklyExceeded { .. } => StopReason::IpWeeklyQuota, + QuotaError::IpMonthlyExceeded { .. } | QuotaError::IpInboundMonthlyExceeded { .. } | QuotaError::IpOutboundMonthlyExceeded { .. } => StopReason::IpMonthlyQuota, + QuotaError::TooManyConnections { .. } => StopReason::IpDailyQuota, // reuse _ => StopReason::IpDailyQuota, }; } @@ -183,13 +183,13 @@ impl QuotaEnforcer { } /// Flush session bytes to DB. Call periodically and at session end. - pub fn flush_to_db(&self) { + pub fn flush_to_db(&self, ip_str: &str) { let tx = self.state.total_tx_bytes.load(Ordering::Relaxed); let rx = self.state.total_rx_bytes.load(Ordering::Relaxed); // From server perspective: tx = outbound (we sent), rx = inbound (we received) self.quota_mgr.record_usage( &self.username, - &self.ip.to_string(), + ip_str, rx, // inbound = what we received from client tx, // outbound = what we sent to client ); @@ -330,7 +330,7 @@ mod tests { qm, "testuser".into(), "127.0.0.1".parse().unwrap(), state, 10, 0, ); - enforcer.flush_to_db(); + enforcer.flush_to_db("127.0.0.1"); // flush_to_db: total_tx=5000→outbound, total_rx=3000→inbound // quota_mgr.record_usage(inbound=3000, outbound=5000) diff --git a/src/server_pro/user_db.rs b/src/server_pro/user_db.rs index 661931e..2eca89e 100644 --- a/src/server_pro/user_db.rs +++ b/src/server_pro/user_db.rs @@ -609,32 +609,20 @@ impl UserDb { fn hash_password(username: &str, password: &str) -> String { use sha2::{Sha256, Digest}; let mut hasher = Sha256::new(); - hasher.update(format!("{}:{}", username, password).as_bytes()); + hasher.update(username.as_bytes()); + hasher.update(b":"); + hasher.update(password.as_bytes()); let result = hasher.finalize(); - result.iter().map(|b| format!("{:02x}", b)).collect() + let mut hex = String::with_capacity(64); + for b in result { + use std::fmt::Write; + let _ = write!(hex, "{:02x}", b); + } + hex } fn chrono_date_today() -> String { - // Simple date without chrono crate - use std::time::{SystemTime, UNIX_EPOCH}; - let secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs(); - let days = secs / 86400; - let mut y = 1970u64; - let mut remaining = days; - loop { - let leap = if y % 4 == 0 && (y % 100 != 0 || y % 400 == 0) { 366 } else { 365 }; - if remaining < leap { break; } - remaining -= leap; - y += 1; - } - let leap = y % 4 == 0 && (y % 100 != 0 || y % 400 == 0); - let days_in_months = [31u64, if leap { 29 } else { 28 }, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; - let mut m = 0usize; - for i in 0..12 { - if remaining < days_in_months[i] { m = i; break; } - remaining -= days_in_months[i]; - } - format!("{:04}-{:02}-{:02}", y, m + 1, remaining + 1) + chrono::Local::now().format("%Y-%m-%d").to_string() } // Re-export for use by rusqlite diff --git a/src/syslog_logger.rs b/src/syslog_logger.rs index c658fdc..9c12133 100644 --- a/src/syslog_logger.rs +++ b/src/syslog_logger.rs @@ -36,12 +36,12 @@ pub fn init(target: &str) -> std::io::Result<()> { /// Send a syslog message with the given severity and message. /// Severity: 6=info, 4=warning, 3=error fn send(severity: u8, msg: &str) { + // Format timestamp outside the lock to minimize contention + let priority = 128 + severity; + let timestamp = bsd_timestamp(); + let guard = SYSLOG.lock().unwrap(); if let Some(ref sender) = *guard { - // RFC 3164 (BSD syslog): Mon DD HH:MM:SS hostname program: message - // facility=16 (local0) * 8 + severity - let priority = 128 + severity; - let timestamp = bsd_timestamp(); let syslog_msg = format!( "<{}>{} {} btest-rs: {}", priority, timestamp, sender.hostname, msg, @@ -52,44 +52,7 @@ fn send(severity: u8, msg: &str) { fn bsd_timestamp() -> String { // RFC 3164 format: "Mon DD HH:MM:SS" (no year) - use std::time::{SystemTime, UNIX_EPOCH}; - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - - // Simple conversion — good enough for syslog - let secs_in_day = 86400u64; - let days = now / secs_in_day; - let time_of_day = now % secs_in_day; - let hours = time_of_day / 3600; - let minutes = (time_of_day % 3600) / 60; - let seconds = time_of_day % 60; - - // Day of year calculation (approximate months) - let months = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]; - let days_in_months = [31u64,28,31,30,31,30,31,31,30,31,30,31]; - - // Days since epoch to year/month/day - let mut y = 1970u64; - let mut remaining = days; - loop { - let leap = if y % 4 == 0 && (y % 100 != 0 || y % 400 == 0) { 366 } else { 365 }; - if remaining < leap { break; } - remaining -= leap; - y += 1; - } - let leap = y % 4 == 0 && (y % 100 != 0 || y % 400 == 0); - let mut m = 0usize; - for i in 0..12 { - let mut d = days_in_months[i]; - if i == 1 && leap { d += 1; } - if remaining < d { m = i; break; } - remaining -= d; - } - let day = remaining + 1; - - format!("{} {:2} {:02}:{:02}:{:02}", months[m], day, hours, minutes, seconds) + chrono::Local::now().format("%b %e %H:%M:%S").to_string() } // --- Public logging functions ---