diff --git a/src/client.rs b/src/client.rs index 2e59e40..7366987 100644 --- a/src/client.rs +++ b/src/client.rs @@ -27,6 +27,11 @@ pub async fn run_client( let mut stream = TcpStream::connect(&addr).await?; stream.set_nodelay(true)?; + // Set TCP socket buffers to 4MB for high throughput + let sock_ref = socket2::SockRef::from(&stream); + let _ = sock_ref.set_send_buffer_size(4 * 1024 * 1024); + let _ = sock_ref.set_recv_buffer_size(4 * 1024 * 1024); + recv_hello(&mut stream).await?; tracing::info!("Connected to server"); @@ -154,15 +159,17 @@ async fn tcp_client_tx_loop( ) { tokio::time::sleep(Duration::from_millis(100)).await; - let packet = vec![0u8; tx_size]; // TCP data is all zeros let mut interval = bandwidth::calc_send_interval(tx_speed, tx_size as u16); + // Use larger writes when running unlimited to reduce syscall overhead + let effective_size = if interval.is_none() { tx_size.max(256 * 1024) } else { tx_size }; + let packet = vec![0u8; effective_size]; // TCP data is all zeros let mut next_send = Instant::now(); while state.running.load(Ordering::Relaxed) { if writer.write_all(&packet).await.is_err() { break; } - state.tx_bytes.fetch_add(tx_size as u64, Ordering::Relaxed); + state.tx_bytes.fetch_add(effective_size as u64, Ordering::Relaxed); if state.tx_speed_changed.load(Ordering::Relaxed) { state.tx_speed_changed.store(false, Ordering::Relaxed); @@ -189,7 +196,7 @@ async fn tcp_client_rx_loop( mut reader: tokio::net::tcp::OwnedReadHalf, state: Arc, ) { - let mut buf = vec![0u8; 65536]; + let mut buf = vec![0u8; 256 * 1024]; while state.running.load(Ordering::Relaxed) { match reader.read(&mut buf).await { Ok(0) | Err(_) => break, diff --git a/src/ecsrp5.rs b/src/ecsrp5.rs index d8d0e13..badc7d1 100644 --- a/src/ecsrp5.rs +++ b/src/ecsrp5.rs @@ -6,6 +6,8 @@ //! //! btest framing: `[len:1][payload]` (no 0x06 handler byte, unlike Winbox). 
+use std::sync::LazyLock; + use num_bigint::BigUint; use num_integer::Integer; use num_traits::{One, Zero}; @@ -14,31 +16,31 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::protocol::{BtestError, Result}; -// --- Curve25519 parameters in Weierstrass form --- +// --- Curve25519 parameters in Weierstrass form (cached, computed once) --- -fn p() -> BigUint { +static P: LazyLock<BigUint> = LazyLock::new(|| { BigUint::parse_bytes( b"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed", 16, ) .unwrap() -} +}); -fn curve_order() -> BigUint { +static CURVE_ORDER: LazyLock<BigUint> = LazyLock::new(|| { BigUint::parse_bytes( b"1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed", 16, ) .unwrap() -} +}); -fn weierstrass_a() -> BigUint { +static WEIERSTRASS_A: LazyLock<BigUint> = LazyLock::new(|| { BigUint::parse_bytes( b"2aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa984914a144", 16, ) .unwrap() -} +}); const MONT_A: u64 = 486662; @@ -50,10 +52,10 @@ fn modinv(a: &BigUint, modulus: &BigUint) -> BigUint { a.modpow(&exp, modulus) } -fn legendre_symbol(a: &BigUint, p_val: &BigUint) -> i32 { - let exp = (p_val - BigUint::one()) / BigUint::from(2u32); - let l = a.modpow(&exp, p_val); - if l == p_val - BigUint::one() { +fn legendre_symbol(a: &BigUint, p: &BigUint) -> i32 { + let exp = (p - BigUint::one()) / BigUint::from(2u32); + let l = a.modpow(&exp, p); + if l == p - BigUint::one() { -1 } else if l == BigUint::zero() { 0 @@ -166,7 +168,7 @@ impl Point { } fn add(&self, other: &Point) -> Point { - let p_val = p(); + let p_val = &*P; if self.infinity { return other.clone(); } @@ -179,44 +181,44 @@ impl Point { let lam = if self.x == other.x && self.y == other.y { // Point doubling - let three_x_sq = (BigUint::from(3u32) * &self.x * &self.x + &weierstrass_a()) % &p_val; - let two_y = (BigUint::from(2u32) * &self.y) % &p_val; - (three_x_sq * modinv(&two_y, &p_val)) % &p_val + let three_x_sq = (BigUint::from(3u32) * &self.x * &self.x + &*WEIERSTRASS_A) % p_val; +
let two_y = (BigUint::from(2u32) * &self.y) % p_val; + (three_x_sq * modinv(&two_y, p_val)) % p_val } else { // Point addition let dy = if other.y >= self.y { - (&other.y - &self.y) % &p_val + (&other.y - &self.y) % p_val } else { - (&p_val - (&self.y - &other.y) % &p_val) % &p_val + (p_val - (&self.y - &other.y) % p_val) % p_val }; let dx = if other.x >= self.x { - (&other.x - &self.x) % &p_val + (&other.x - &self.x) % p_val } else { - (&p_val - (&self.x - &other.x) % &p_val) % &p_val + (p_val - (&self.x - &other.x) % p_val) % p_val }; - (dy * modinv(&dx, &p_val)) % &p_val + (dy * modinv(&dx, p_val)) % p_val }; let x3 = { - let lam_sq = (&lam * &lam) % &p_val; - let sum_x = (&self.x + &other.x) % &p_val; + let lam_sq = (&lam * &lam) % p_val; + let sum_x = (&self.x + &other.x) % p_val; if lam_sq >= sum_x { - (lam_sq - sum_x) % &p_val + (lam_sq - sum_x) % p_val } else { - (&p_val - (sum_x - lam_sq) % &p_val) % &p_val + (p_val - (sum_x - lam_sq) % p_val) % p_val } }; let y3 = { let dx = if self.x >= x3 { - (&self.x - &x3) % &p_val + (&self.x - &x3) % p_val } else { - (&p_val - (&x3 - &self.x) % &p_val) % &p_val + (p_val - (&x3 - &self.x) % p_val) % p_val }; - let prod = (&lam * dx) % &p_val; + let prod = (&lam * dx) % p_val; if prod >= self.y { - (prod - &self.y) % &p_val + (prod - &self.y) % p_val } else { - (&p_val - (&self.y - prod) % &p_val) % &p_val + (p_val - (&self.y - prod) % p_val) % p_val } }; @@ -226,14 +228,13 @@ impl Point { fn scalar_mul(&self, scalar: &BigUint) -> Point { let mut result = Point::infinity(); let mut base = self.clone(); - let mut k = scalar.clone(); + let bits = scalar.bits(); - while !k.is_zero() { - if &k & &BigUint::one() == BigUint::one() { + for i in 0..bits { + if scalar.bit(i) { result = result.add(&base); } base = base.add(&base); - k >>= 1; } result } @@ -249,11 +250,11 @@ struct WCurve { impl WCurve { fn new() -> Self { - let p_val = p(); + let p_val = &*P; let mont_a = BigUint::from(MONT_A); - let three_inv = 
modinv(&BigUint::from(3u32), &p_val); - let conversion_from_m = (&mont_a * &three_inv) % &p_val; - let conversion_to_m = (&p_val - &conversion_from_m) % &p_val; + let three_inv = modinv(&BigUint::from(3u32), p_val); + let conversion_from_m = (&mont_a * &three_inv) % p_val; + let conversion_to_m = (p_val - &conversion_from_m) % p_val; let mut curve = WCurve { g: Point::infinity(), @@ -265,8 +266,8 @@ impl WCurve { } fn to_montgomery(&self, pt: &Point) -> ([u8; 32], u8) { - let p_val = p(); - let x = (&pt.x + &self.conversion_to_m) % &p_val; + let p_val = &*P; + let x = (&pt.x + &self.conversion_to_m) % p_val; let parity = if pt.y.bit(0) { 1u8 } else { 0u8 }; let mut bytes = [0u8; 32]; let x_bytes = x.to_bytes_be(); @@ -276,14 +277,14 @@ impl WCurve { } fn lift_x(&self, x_mont: &BigUint, parity: bool) -> Point { - let p_val = p(); - let x = x_mont % &p_val; + let p_val = &*P; + let x = x_mont % p_val; // y^2 = x^3 + Ax^2 + x (Montgomery) - let y_squared = (&x * &x * &x + BigUint::from(MONT_A) * &x * &x + &x) % &p_val; + let y_squared = (&x * &x * &x + BigUint::from(MONT_A) * &x * &x + &x) % p_val; // Convert x to Weierstrass - let x_w = (&x + &self.conversion_from_m) % &p_val; + let x_w = (&x + &self.conversion_from_m) % p_val; - if let Some((y1, y2)) = prime_mod_sqrt(&y_squared, &p_val) { + if let Some((y1, y2)) = prime_mod_sqrt(&y_squared, p_val) { let pt1 = Point::new(x_w.clone(), y1); let pt2 = Point::new(x_w, y2); if parity { @@ -323,7 +324,7 @@ impl WCurve { password: &str, salt: &[u8; 16], ) -> [u8; 32] { - let inner = sha256_bytes(&format!("{}:{}", username, password).as_bytes().to_vec()); + let inner = sha256_bytes(format!("{}:{}", username, password).as_bytes()); let mut input = Vec::with_capacity(16 + 32); input.extend_from_slice(salt); input.extend_from_slice(&inner); @@ -415,8 +416,8 @@ pub async fn client_authenticate( let i_int = BigUint::from_bytes_be(&i); let j_int = BigUint::from_bytes_be(&j); let s_a_int = BigUint::from_bytes_be(&s_a); - let order 
= curve_order(); - let scalar = ((&i_int * &j_int) + &s_a_int) % &order; + let order = &*CURVE_ORDER; + let scalar = ((&i_int * &j_int) + &s_a_int) % order; let z_point = w_b_unblinded.scalar_mul(&scalar); let (z, _) = w.to_montgomery(&z_point); diff --git a/src/server.rs b/src/server.rs index 5ab6951..b550551 100644 --- a/src/server.rs +++ b/src/server.rs @@ -135,6 +135,11 @@ async fn handle_client( ) -> Result<()> { stream.set_nodelay(true)?; + // Set TCP socket buffers to 4MB (matching UDP path) for high throughput + let sock_ref = socket2::SockRef::from(&stream); + let _ = sock_ref.set_send_buffer_size(4 * 1024 * 1024); + let _ = sock_ref.set_recv_buffer_size(4 * 1024 * 1024); + send_hello(&mut stream).await?; // Read 16-byte command (or whatever the client sends) @@ -575,8 +580,10 @@ async fn tcp_tx_loop_inner( ) { tokio::time::sleep(Duration::from_millis(100)).await; - let packet = vec![0u8; tx_size]; let mut interval = bandwidth::calc_send_interval(tx_speed, tx_size as u16); + // Use larger writes when running unlimited to reduce syscall overhead + let effective_size = if interval.is_none() { tx_size.max(256 * 1024) } else { tx_size }; + let packet = vec![0u8; effective_size]; let mut next_send = Instant::now(); let mut next_status = Instant::now() + Duration::from_secs(1); let mut status_seq: u32 = 0; @@ -599,14 +606,14 @@ async fn tcp_tx_loop_inner( next_status = Instant::now() + Duration::from_secs(1); } - if !state.spend_budget(tx_size as u64) { + if !state.spend_budget(effective_size as u64) { break; } if writer.write_all(&packet).await.is_err() { state.running.store(false, Ordering::SeqCst); break; } - state.tx_bytes.fetch_add(tx_size as u64, Ordering::Relaxed); + state.tx_bytes.fetch_add(effective_size as u64, Ordering::Relaxed); if state.tx_speed_changed.load(Ordering::Relaxed) { state.tx_speed_changed.store(false, Ordering::Relaxed); @@ -630,7 +637,7 @@ async fn tcp_tx_loop_inner( } async fn tcp_rx_loop(mut reader: tokio::net::tcp::OwnedReadHalf, state:
Arc) { - let mut buf = vec![0u8; 65536]; + let mut buf = vec![0u8; 256 * 1024]; while state.running.load(Ordering::Relaxed) { match reader.read(&mut buf).await { Ok(0) | Err(_) => { diff --git a/src/server_pro/ldap_auth.rs b/src/server_pro/ldap_auth.rs index d6f4b80..6387f50 100644 --- a/src/server_pro/ldap_auth.rs +++ b/src/server_pro/ldap_auth.rs @@ -15,6 +15,22 @@ pub struct LdapAuth { config: LdapConfig, } +/// Escape special characters in LDAP filter values per RFC 4515. +fn ldap_escape(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + for c in input.chars() { + match c { + '\\' => out.push_str("\\5c"), + '*' => out.push_str("\\2a"), + '(' => out.push_str("\\28"), + ')' => out.push_str("\\29"), + '\0' => out.push_str("\\00"), + _ => out.push(c), + } + } + out +} + impl LdapAuth { pub fn new(config: LdapConfig) -> Self { Self { config } @@ -26,6 +42,8 @@ impl LdapAuth { let (conn, mut ldap) = LdapConnAsync::new(&self.config.url).await?; ldap3::drive!(conn); + let safe_username = ldap_escape(username); + // If service account configured, bind first to search for user DN let user_dn = if let (Some(ref bind_dn), Some(ref bind_pass)) = (&self.config.bind_dn, &self.config.bind_pass) @@ -39,7 +57,7 @@ impl LdapAuth { // Search for the user let filter = format!( "(&(objectClass=person)(|(uid={})(sAMAccountName={})(cn={})))", - username, username, username + safe_username, safe_username, safe_username ); let (results, _) = ldap .search(&self.config.base_dn, Scope::Subtree, &filter, vec!["dn"]) @@ -51,11 +69,17 @@ impl LdapAuth { return Ok(false); } - let entry = SearchEntry::construct(results.into_iter().next().unwrap()); + let entry = match results.into_iter().next() { + Some(r) => SearchEntry::construct(r), + None => { + tracing::debug!("LDAP user not found: {}", username); + return Ok(false); + } + }; entry.dn } else { // No service account — construct DN directly - format!("uid={},{}", username, self.config.base_dn) + 
format!("uid={},{}", safe_username, self.config.base_dn) }; // Attempt user bind