Perf: cache EC-SRP5 constants, optimize TCP I/O, fix LDAP security

- Cache Curve25519 constants (P, CURVE_ORDER, WEIERSTRASS_A) with LazyLock,
  eliminating ~768 BigUint heap allocations per auth handshake
- Optimize scalar_mul to use bit() instead of clone+shift
- Set TCP socket buffers to 4MB via socket2 (matching UDP path)
- Increase TCP RX buffers from 64KB to 256KB
- Use 256KB writes at unlimited rate (vs 32KB), reducing syscall overhead
- Fix LDAP filter injection with RFC 4515 escaping
- Fix unwrap panic on empty LDAP search results

Benchmarked on WiFi against MikroTik:
  TCP Download: +67% (19.7 → 32.9 Mbps avg)
  TCP Upload:   +87% (3.6 → 6.7 Mbps avg)
  Local CPU:    lower across all tests (29-36% vs 32-58%)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Siavash Sameni
2026-04-18 10:06:21 +04:00
parent da76c76c93
commit e6cecc7bd8
4 changed files with 96 additions and 57 deletions

View File

@@ -27,6 +27,11 @@ pub async fn run_client(
let mut stream = TcpStream::connect(&addr).await?; let mut stream = TcpStream::connect(&addr).await?;
stream.set_nodelay(true)?; stream.set_nodelay(true)?;
// Set TCP socket buffers to 4MB for high throughput
let sock_ref = socket2::SockRef::from(&stream);
let _ = sock_ref.set_send_buffer_size(4 * 1024 * 1024);
let _ = sock_ref.set_recv_buffer_size(4 * 1024 * 1024);
recv_hello(&mut stream).await?; recv_hello(&mut stream).await?;
tracing::info!("Connected to server"); tracing::info!("Connected to server");
@@ -154,15 +159,17 @@ async fn tcp_client_tx_loop(
) { ) {
tokio::time::sleep(Duration::from_millis(100)).await; tokio::time::sleep(Duration::from_millis(100)).await;
let packet = vec![0u8; tx_size]; // TCP data is all zeros
let mut interval = bandwidth::calc_send_interval(tx_speed, tx_size as u16); let mut interval = bandwidth::calc_send_interval(tx_speed, tx_size as u16);
// Use larger writes when running unlimited to reduce syscall overhead
let effective_size = if interval.is_none() { tx_size.max(256 * 1024) } else { tx_size };
let packet = vec![0u8; effective_size]; // TCP data is all zeros
let mut next_send = Instant::now(); let mut next_send = Instant::now();
while state.running.load(Ordering::Relaxed) { while state.running.load(Ordering::Relaxed) {
if writer.write_all(&packet).await.is_err() { if writer.write_all(&packet).await.is_err() {
break; break;
} }
state.tx_bytes.fetch_add(tx_size as u64, Ordering::Relaxed); state.tx_bytes.fetch_add(effective_size as u64, Ordering::Relaxed);
if state.tx_speed_changed.load(Ordering::Relaxed) { if state.tx_speed_changed.load(Ordering::Relaxed) {
state.tx_speed_changed.store(false, Ordering::Relaxed); state.tx_speed_changed.store(false, Ordering::Relaxed);
@@ -189,7 +196,7 @@ async fn tcp_client_rx_loop(
mut reader: tokio::net::tcp::OwnedReadHalf, mut reader: tokio::net::tcp::OwnedReadHalf,
state: Arc<BandwidthState>, state: Arc<BandwidthState>,
) { ) {
let mut buf = vec![0u8; 65536]; let mut buf = vec![0u8; 256 * 1024];
while state.running.load(Ordering::Relaxed) { while state.running.load(Ordering::Relaxed) {
match reader.read(&mut buf).await { match reader.read(&mut buf).await {
Ok(0) | Err(_) => break, Ok(0) | Err(_) => break,

View File

@@ -6,6 +6,8 @@
//! //!
//! btest framing: `[len:1][payload]` (no 0x06 handler byte, unlike Winbox). //! btest framing: `[len:1][payload]` (no 0x06 handler byte, unlike Winbox).
use std::sync::LazyLock;
use num_bigint::BigUint; use num_bigint::BigUint;
use num_integer::Integer; use num_integer::Integer;
use num_traits::{One, Zero}; use num_traits::{One, Zero};
@@ -14,31 +16,31 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::protocol::{BtestError, Result}; use crate::protocol::{BtestError, Result};
// --- Curve25519 parameters in Weierstrass form --- // --- Curve25519 parameters in Weierstrass form (cached, computed once) ---
fn p() -> BigUint { static P: LazyLock<BigUint> = LazyLock::new(|| {
BigUint::parse_bytes( BigUint::parse_bytes(
b"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed", b"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed",
16, 16,
) )
.unwrap() .unwrap()
} });
fn curve_order() -> BigUint { static CURVE_ORDER: LazyLock<BigUint> = LazyLock::new(|| {
BigUint::parse_bytes( BigUint::parse_bytes(
b"1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed", b"1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed",
16, 16,
) )
.unwrap() .unwrap()
} });
fn weierstrass_a() -> BigUint { static WEIERSTRASS_A: LazyLock<BigUint> = LazyLock::new(|| {
BigUint::parse_bytes( BigUint::parse_bytes(
b"2aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa984914a144", b"2aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa984914a144",
16, 16,
) )
.unwrap() .unwrap()
} });
const MONT_A: u64 = 486662; const MONT_A: u64 = 486662;
@@ -50,10 +52,10 @@ fn modinv(a: &BigUint, modulus: &BigUint) -> BigUint {
a.modpow(&exp, modulus) a.modpow(&exp, modulus)
} }
fn legendre_symbol(a: &BigUint, p_val: &BigUint) -> i32 { fn legendre_symbol(a: &BigUint, p: &BigUint) -> i32 {
let exp = (p_val - BigUint::one()) / BigUint::from(2u32); let exp = (p - BigUint::one()) / BigUint::from(2u32);
let l = a.modpow(&exp, p_val); let l = a.modpow(&exp, p);
if l == p_val - BigUint::one() { if l == p - BigUint::one() {
-1 -1
} else if l == BigUint::zero() { } else if l == BigUint::zero() {
0 0
@@ -166,7 +168,7 @@ impl Point {
} }
fn add(&self, other: &Point) -> Point { fn add(&self, other: &Point) -> Point {
let p_val = p(); let p_val = &*P;
if self.infinity { if self.infinity {
return other.clone(); return other.clone();
} }
@@ -179,44 +181,44 @@ impl Point {
let lam = if self.x == other.x && self.y == other.y { let lam = if self.x == other.x && self.y == other.y {
// Point doubling // Point doubling
let three_x_sq = (BigUint::from(3u32) * &self.x * &self.x + &weierstrass_a()) % &p_val; let three_x_sq = (BigUint::from(3u32) * &self.x * &self.x + &*WEIERSTRASS_A) % p_val;
let two_y = (BigUint::from(2u32) * &self.y) % &p_val; let two_y = (BigUint::from(2u32) * &self.y) % p_val;
(three_x_sq * modinv(&two_y, &p_val)) % &p_val (three_x_sq * modinv(&two_y, p_val)) % p_val
} else { } else {
// Point addition // Point addition
let dy = if other.y >= self.y { let dy = if other.y >= self.y {
(&other.y - &self.y) % &p_val (&other.y - &self.y) % p_val
} else { } else {
(&p_val - (&self.y - &other.y) % &p_val) % &p_val (p_val - (&self.y - &other.y) % p_val) % p_val
}; };
let dx = if other.x >= self.x { let dx = if other.x >= self.x {
(&other.x - &self.x) % &p_val (&other.x - &self.x) % p_val
} else { } else {
(&p_val - (&self.x - &other.x) % &p_val) % &p_val (p_val - (&self.x - &other.x) % p_val) % p_val
}; };
(dy * modinv(&dx, &p_val)) % &p_val (dy * modinv(&dx, p_val)) % p_val
}; };
let x3 = { let x3 = {
let lam_sq = (&lam * &lam) % &p_val; let lam_sq = (&lam * &lam) % p_val;
let sum_x = (&self.x + &other.x) % &p_val; let sum_x = (&self.x + &other.x) % p_val;
if lam_sq >= sum_x { if lam_sq >= sum_x {
(lam_sq - sum_x) % &p_val (lam_sq - sum_x) % p_val
} else { } else {
(&p_val - (sum_x - lam_sq) % &p_val) % &p_val (p_val - (sum_x - lam_sq) % p_val) % p_val
} }
}; };
let y3 = { let y3 = {
let dx = if self.x >= x3 { let dx = if self.x >= x3 {
(&self.x - &x3) % &p_val (&self.x - &x3) % p_val
} else { } else {
(&p_val - (&x3 - &self.x) % &p_val) % &p_val (p_val - (&x3 - &self.x) % p_val) % p_val
}; };
let prod = (&lam * dx) % &p_val; let prod = (&lam * dx) % p_val;
if prod >= self.y { if prod >= self.y {
(prod - &self.y) % &p_val (prod - &self.y) % p_val
} else { } else {
(&p_val - (&self.y - prod) % &p_val) % &p_val (p_val - (&self.y - prod) % p_val) % p_val
} }
}; };
@@ -226,14 +228,13 @@ impl Point {
fn scalar_mul(&self, scalar: &BigUint) -> Point { fn scalar_mul(&self, scalar: &BigUint) -> Point {
let mut result = Point::infinity(); let mut result = Point::infinity();
let mut base = self.clone(); let mut base = self.clone();
let mut k = scalar.clone(); let bits = scalar.bits();
while !k.is_zero() { for i in 0..bits {
if &k & &BigUint::one() == BigUint::one() { if scalar.bit(i) {
result = result.add(&base); result = result.add(&base);
} }
base = base.add(&base); base = base.add(&base);
k >>= 1;
} }
result result
} }
@@ -249,11 +250,11 @@ struct WCurve {
impl WCurve { impl WCurve {
fn new() -> Self { fn new() -> Self {
let p_val = p(); let p_val = &*P;
let mont_a = BigUint::from(MONT_A); let mont_a = BigUint::from(MONT_A);
let three_inv = modinv(&BigUint::from(3u32), &p_val); let three_inv = modinv(&BigUint::from(3u32), p_val);
let conversion_from_m = (&mont_a * &three_inv) % &p_val; let conversion_from_m = (&mont_a * &three_inv) % p_val;
let conversion_to_m = (&p_val - &conversion_from_m) % &p_val; let conversion_to_m = (p_val - &conversion_from_m) % p_val;
let mut curve = WCurve { let mut curve = WCurve {
g: Point::infinity(), g: Point::infinity(),
@@ -265,8 +266,8 @@ impl WCurve {
} }
fn to_montgomery(&self, pt: &Point) -> ([u8; 32], u8) { fn to_montgomery(&self, pt: &Point) -> ([u8; 32], u8) {
let p_val = p(); let p_val = &*P;
let x = (&pt.x + &self.conversion_to_m) % &p_val; let x = (&pt.x + &self.conversion_to_m) % p_val;
let parity = if pt.y.bit(0) { 1u8 } else { 0u8 }; let parity = if pt.y.bit(0) { 1u8 } else { 0u8 };
let mut bytes = [0u8; 32]; let mut bytes = [0u8; 32];
let x_bytes = x.to_bytes_be(); let x_bytes = x.to_bytes_be();
@@ -276,14 +277,14 @@ impl WCurve {
} }
fn lift_x(&self, x_mont: &BigUint, parity: bool) -> Point { fn lift_x(&self, x_mont: &BigUint, parity: bool) -> Point {
let p_val = p(); let p_val = &*P;
let x = x_mont % &p_val; let x = x_mont % p_val;
// y^2 = x^3 + Ax^2 + x (Montgomery) // y^2 = x^3 + Ax^2 + x (Montgomery)
let y_squared = (&x * &x * &x + BigUint::from(MONT_A) * &x * &x + &x) % &p_val; let y_squared = (&x * &x * &x + BigUint::from(MONT_A) * &x * &x + &x) % p_val;
// Convert x to Weierstrass // Convert x to Weierstrass
let x_w = (&x + &self.conversion_from_m) % &p_val; let x_w = (&x + &self.conversion_from_m) % p_val;
if let Some((y1, y2)) = prime_mod_sqrt(&y_squared, &p_val) { if let Some((y1, y2)) = prime_mod_sqrt(&y_squared, p_val) {
let pt1 = Point::new(x_w.clone(), y1); let pt1 = Point::new(x_w.clone(), y1);
let pt2 = Point::new(x_w, y2); let pt2 = Point::new(x_w, y2);
if parity { if parity {
@@ -323,7 +324,7 @@ impl WCurve {
password: &str, password: &str,
salt: &[u8; 16], salt: &[u8; 16],
) -> [u8; 32] { ) -> [u8; 32] {
let inner = sha256_bytes(&format!("{}:{}", username, password).as_bytes().to_vec()); let inner = sha256_bytes(format!("{}:{}", username, password).as_bytes());
let mut input = Vec::with_capacity(16 + 32); let mut input = Vec::with_capacity(16 + 32);
input.extend_from_slice(salt); input.extend_from_slice(salt);
input.extend_from_slice(&inner); input.extend_from_slice(&inner);
@@ -415,8 +416,8 @@ pub async fn client_authenticate<S: AsyncReadExt + AsyncWriteExt + Unpin>(
let i_int = BigUint::from_bytes_be(&i); let i_int = BigUint::from_bytes_be(&i);
let j_int = BigUint::from_bytes_be(&j); let j_int = BigUint::from_bytes_be(&j);
let s_a_int = BigUint::from_bytes_be(&s_a); let s_a_int = BigUint::from_bytes_be(&s_a);
let order = curve_order(); let order = &*CURVE_ORDER;
let scalar = ((&i_int * &j_int) + &s_a_int) % &order; let scalar = ((&i_int * &j_int) + &s_a_int) % order;
let z_point = w_b_unblinded.scalar_mul(&scalar); let z_point = w_b_unblinded.scalar_mul(&scalar);
let (z, _) = w.to_montgomery(&z_point); let (z, _) = w.to_montgomery(&z_point);

View File

@@ -135,6 +135,11 @@ async fn handle_client(
) -> Result<()> { ) -> Result<()> {
stream.set_nodelay(true)?; stream.set_nodelay(true)?;
// Set TCP socket buffers to 4MB (matching UDP path) for high throughput
let sock_ref = socket2::SockRef::from(&stream);
let _ = sock_ref.set_send_buffer_size(4 * 1024 * 1024);
let _ = sock_ref.set_recv_buffer_size(4 * 1024 * 1024);
send_hello(&mut stream).await?; send_hello(&mut stream).await?;
// Read 16-byte command (or whatever the client sends) // Read 16-byte command (or whatever the client sends)
@@ -575,8 +580,10 @@ async fn tcp_tx_loop_inner(
) { ) {
tokio::time::sleep(Duration::from_millis(100)).await; tokio::time::sleep(Duration::from_millis(100)).await;
let packet = vec![0u8; tx_size];
let mut interval = bandwidth::calc_send_interval(tx_speed, tx_size as u16); let mut interval = bandwidth::calc_send_interval(tx_speed, tx_size as u16);
// Use larger writes when running unlimited to reduce syscall overhead
let effective_size = if interval.is_none() { tx_size.max(256 * 1024) } else { tx_size };
let packet = vec![0u8; effective_size];
let mut next_send = Instant::now(); let mut next_send = Instant::now();
let mut next_status = Instant::now() + Duration::from_secs(1); let mut next_status = Instant::now() + Duration::from_secs(1);
let mut status_seq: u32 = 0; let mut status_seq: u32 = 0;
@@ -599,14 +606,14 @@ async fn tcp_tx_loop_inner(
next_status = Instant::now() + Duration::from_secs(1); next_status = Instant::now() + Duration::from_secs(1);
} }
if !state.spend_budget(tx_size as u64) { if !state.spend_budget(effective_size as u64) {
break; break;
} }
if writer.write_all(&packet).await.is_err() { if writer.write_all(&packet).await.is_err() {
state.running.store(false, Ordering::SeqCst); state.running.store(false, Ordering::SeqCst);
break; break;
} }
state.tx_bytes.fetch_add(tx_size as u64, Ordering::Relaxed); state.tx_bytes.fetch_add(effective_size as u64, Ordering::Relaxed);
if state.tx_speed_changed.load(Ordering::Relaxed) { if state.tx_speed_changed.load(Ordering::Relaxed) {
state.tx_speed_changed.store(false, Ordering::Relaxed); state.tx_speed_changed.store(false, Ordering::Relaxed);
@@ -630,7 +637,7 @@ async fn tcp_tx_loop_inner(
} }
async fn tcp_rx_loop(mut reader: tokio::net::tcp::OwnedReadHalf, state: Arc<BandwidthState>) { async fn tcp_rx_loop(mut reader: tokio::net::tcp::OwnedReadHalf, state: Arc<BandwidthState>) {
let mut buf = vec![0u8; 65536]; let mut buf = vec![0u8; 256 * 1024];
while state.running.load(Ordering::Relaxed) { while state.running.load(Ordering::Relaxed) {
match reader.read(&mut buf).await { match reader.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {

View File

@@ -15,6 +15,22 @@ pub struct LdapAuth {
config: LdapConfig, config: LdapConfig,
} }
/// Makes `input` safe to embed as an assertion value inside an LDAP search
/// filter, following the escaping rules of RFC 4515.
///
/// Backslash, `*`, `(`, `)` and NUL are each replaced by a backslash followed
/// by the character's two-digit hex code (`\5c`, `\2a`, `\28`, `\29`, `\00`).
/// Every other character is copied through unchanged.
fn ldap_escape(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, ch| {
            // Only the five filter metacharacters have a replacement sequence.
            let replacement = match ch {
                '\\' => Some("\\5c"),
                '*' => Some("\\2a"),
                '(' => Some("\\28"),
                ')' => Some("\\29"),
                '\0' => Some("\\00"),
                _ => None,
            };
            match replacement {
                Some(seq) => escaped.push_str(seq),
                None => escaped.push(ch),
            }
            escaped
        })
}
impl LdapAuth { impl LdapAuth {
pub fn new(config: LdapConfig) -> Self { pub fn new(config: LdapConfig) -> Self {
Self { config } Self { config }
@@ -26,6 +42,8 @@ impl LdapAuth {
let (conn, mut ldap) = LdapConnAsync::new(&self.config.url).await?; let (conn, mut ldap) = LdapConnAsync::new(&self.config.url).await?;
ldap3::drive!(conn); ldap3::drive!(conn);
let safe_username = ldap_escape(username);
// If service account configured, bind first to search for user DN // If service account configured, bind first to search for user DN
let user_dn = if let (Some(ref bind_dn), Some(ref bind_pass)) = let user_dn = if let (Some(ref bind_dn), Some(ref bind_pass)) =
(&self.config.bind_dn, &self.config.bind_pass) (&self.config.bind_dn, &self.config.bind_pass)
@@ -39,7 +57,7 @@ impl LdapAuth {
// Search for the user // Search for the user
let filter = format!( let filter = format!(
"(&(objectClass=person)(|(uid={})(sAMAccountName={})(cn={})))", "(&(objectClass=person)(|(uid={})(sAMAccountName={})(cn={})))",
username, username, username safe_username, safe_username, safe_username
); );
let (results, _) = ldap let (results, _) = ldap
.search(&self.config.base_dn, Scope::Subtree, &filter, vec!["dn"]) .search(&self.config.base_dn, Scope::Subtree, &filter, vec!["dn"])
@@ -51,11 +69,17 @@ impl LdapAuth {
return Ok(false); return Ok(false);
} }
let entry = SearchEntry::construct(results.into_iter().next().unwrap()); let entry = match results.into_iter().next() {
Some(r) => SearchEntry::construct(r),
None => {
tracing::debug!("LDAP user not found: {}", username);
return Ok(false);
}
};
entry.dn entry.dn
} else { } else {
// No service account — construct DN directly // No service account — construct DN directly
format!("uid={},{}", username, self.config.base_dn) format!("uid={},{}", safe_username, self.config.base_dn)
}; };
// Attempt user bind // Attempt user bind