Perf: cache EC-SRP5 constants, optimize TCP I/O, fix LDAP security

- Cache Curve25519 constants (P, CURVE_ORDER, WEIERSTRASS_A) with LazyLock,
  eliminating ~768 BigUint heap allocations per auth handshake
- Optimize scalar_mul to use bit() instead of clone+shift
- Set TCP socket buffers to 4MB via socket2 (matching UDP path)
- Increase TCP RX buffers from 64KB to 256KB
- Use 256KB writes at unlimited rate (vs 32KB), reducing syscall overhead
- Fix LDAP filter injection with RFC 4515 escaping
- Fix unwrap panic on empty LDAP search results

Benchmarked on WiFi against MikroTik:
  TCP Download: +67% (19.7 → 32.9 Mbps avg)
  TCP Upload:   +87% (3.6 → 6.7 Mbps avg)
  Local CPU:    lower across all tests (29-36% vs 32-58%)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Siavash Sameni
2026-04-18 10:06:21 +04:00
parent da76c76c93
commit e6cecc7bd8
4 changed files with 96 additions and 57 deletions

View File

@@ -27,6 +27,11 @@ pub async fn run_client(
let mut stream = TcpStream::connect(&addr).await?;
stream.set_nodelay(true)?;
// Set TCP socket buffers to 4MB for high throughput
let sock_ref = socket2::SockRef::from(&stream);
let _ = sock_ref.set_send_buffer_size(4 * 1024 * 1024);
let _ = sock_ref.set_recv_buffer_size(4 * 1024 * 1024);
recv_hello(&mut stream).await?;
tracing::info!("Connected to server");
@@ -154,15 +159,17 @@ async fn tcp_client_tx_loop(
) {
tokio::time::sleep(Duration::from_millis(100)).await;
let packet = vec![0u8; tx_size]; // TCP data is all zeros
let mut interval = bandwidth::calc_send_interval(tx_speed, tx_size as u16);
// Use larger writes when running unlimited to reduce syscall overhead
let effective_size = if interval.is_none() { tx_size.max(256 * 1024) } else { tx_size };
let packet = vec![0u8; effective_size]; // TCP data is all zeros
let mut next_send = Instant::now();
while state.running.load(Ordering::Relaxed) {
if writer.write_all(&packet).await.is_err() {
break;
}
state.tx_bytes.fetch_add(tx_size as u64, Ordering::Relaxed);
state.tx_bytes.fetch_add(effective_size as u64, Ordering::Relaxed);
if state.tx_speed_changed.load(Ordering::Relaxed) {
state.tx_speed_changed.store(false, Ordering::Relaxed);
@@ -189,7 +196,7 @@ async fn tcp_client_rx_loop(
mut reader: tokio::net::tcp::OwnedReadHalf,
state: Arc<BandwidthState>,
) {
let mut buf = vec![0u8; 65536];
let mut buf = vec![0u8; 256 * 1024];
while state.running.load(Ordering::Relaxed) {
match reader.read(&mut buf).await {
Ok(0) | Err(_) => break,

View File

@@ -6,6 +6,8 @@
//!
//! btest framing: `[len:1][payload]` (no 0x06 handler byte, unlike Winbox).
use std::sync::LazyLock;
use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::{One, Zero};
@@ -14,31 +16,31 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::protocol::{BtestError, Result};
// --- Curve25519 parameters in Weierstrass form ---
// --- Curve25519 parameters in Weierstrass form (cached, computed once) ---
fn p() -> BigUint {
static P: LazyLock<BigUint> = LazyLock::new(|| {
BigUint::parse_bytes(
b"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed",
16,
)
.unwrap()
}
});
fn curve_order() -> BigUint {
static CURVE_ORDER: LazyLock<BigUint> = LazyLock::new(|| {
BigUint::parse_bytes(
b"1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed",
16,
)
.unwrap()
}
});
fn weierstrass_a() -> BigUint {
static WEIERSTRASS_A: LazyLock<BigUint> = LazyLock::new(|| {
BigUint::parse_bytes(
b"2aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa984914a144",
16,
)
.unwrap()
}
});
const MONT_A: u64 = 486662;
@@ -50,10 +52,10 @@ fn modinv(a: &BigUint, modulus: &BigUint) -> BigUint {
a.modpow(&exp, modulus)
}
fn legendre_symbol(a: &BigUint, p_val: &BigUint) -> i32 {
let exp = (p_val - BigUint::one()) / BigUint::from(2u32);
let l = a.modpow(&exp, p_val);
if l == p_val - BigUint::one() {
fn legendre_symbol(a: &BigUint, p: &BigUint) -> i32 {
let exp = (p - BigUint::one()) / BigUint::from(2u32);
let l = a.modpow(&exp, p);
if l == p - BigUint::one() {
-1
} else if l == BigUint::zero() {
0
@@ -166,7 +168,7 @@ impl Point {
}
fn add(&self, other: &Point) -> Point {
let p_val = p();
let p_val = &*P;
if self.infinity {
return other.clone();
}
@@ -179,44 +181,44 @@ impl Point {
let lam = if self.x == other.x && self.y == other.y {
// Point doubling
let three_x_sq = (BigUint::from(3u32) * &self.x * &self.x + &weierstrass_a()) % &p_val;
let two_y = (BigUint::from(2u32) * &self.y) % &p_val;
(three_x_sq * modinv(&two_y, &p_val)) % &p_val
let three_x_sq = (BigUint::from(3u32) * &self.x * &self.x + &*WEIERSTRASS_A) % p_val;
let two_y = (BigUint::from(2u32) * &self.y) % p_val;
(three_x_sq * modinv(&two_y, p_val)) % p_val
} else {
// Point addition
let dy = if other.y >= self.y {
(&other.y - &self.y) % &p_val
(&other.y - &self.y) % p_val
} else {
(&p_val - (&self.y - &other.y) % &p_val) % &p_val
(p_val - (&self.y - &other.y) % p_val) % p_val
};
let dx = if other.x >= self.x {
(&other.x - &self.x) % &p_val
(&other.x - &self.x) % p_val
} else {
(&p_val - (&self.x - &other.x) % &p_val) % &p_val
(p_val - (&self.x - &other.x) % p_val) % p_val
};
(dy * modinv(&dx, &p_val)) % &p_val
(dy * modinv(&dx, p_val)) % p_val
};
let x3 = {
let lam_sq = (&lam * &lam) % &p_val;
let sum_x = (&self.x + &other.x) % &p_val;
let lam_sq = (&lam * &lam) % p_val;
let sum_x = (&self.x + &other.x) % p_val;
if lam_sq >= sum_x {
(lam_sq - sum_x) % &p_val
(lam_sq - sum_x) % p_val
} else {
(&p_val - (sum_x - lam_sq) % &p_val) % &p_val
(p_val - (sum_x - lam_sq) % p_val) % p_val
}
};
let y3 = {
let dx = if self.x >= x3 {
(&self.x - &x3) % &p_val
(&self.x - &x3) % p_val
} else {
(&p_val - (&x3 - &self.x) % &p_val) % &p_val
(p_val - (&x3 - &self.x) % p_val) % p_val
};
let prod = (&lam * dx) % &p_val;
let prod = (&lam * dx) % p_val;
if prod >= self.y {
(prod - &self.y) % &p_val
(prod - &self.y) % p_val
} else {
(&p_val - (&self.y - prod) % &p_val) % &p_val
(p_val - (&self.y - prod) % p_val) % p_val
}
};
@@ -226,14 +228,13 @@ impl Point {
fn scalar_mul(&self, scalar: &BigUint) -> Point {
let mut result = Point::infinity();
let mut base = self.clone();
let mut k = scalar.clone();
let bits = scalar.bits();
while !k.is_zero() {
if &k & &BigUint::one() == BigUint::one() {
for i in 0..bits {
if scalar.bit(i) {
result = result.add(&base);
}
base = base.add(&base);
k >>= 1;
}
result
}
@@ -249,11 +250,11 @@ struct WCurve {
impl WCurve {
fn new() -> Self {
let p_val = p();
let p_val = &*P;
let mont_a = BigUint::from(MONT_A);
let three_inv = modinv(&BigUint::from(3u32), &p_val);
let conversion_from_m = (&mont_a * &three_inv) % &p_val;
let conversion_to_m = (&p_val - &conversion_from_m) % &p_val;
let three_inv = modinv(&BigUint::from(3u32), p_val);
let conversion_from_m = (&mont_a * &three_inv) % p_val;
let conversion_to_m = (p_val - &conversion_from_m) % p_val;
let mut curve = WCurve {
g: Point::infinity(),
@@ -265,8 +266,8 @@ impl WCurve {
}
fn to_montgomery(&self, pt: &Point) -> ([u8; 32], u8) {
let p_val = p();
let x = (&pt.x + &self.conversion_to_m) % &p_val;
let p_val = &*P;
let x = (&pt.x + &self.conversion_to_m) % p_val;
let parity = if pt.y.bit(0) { 1u8 } else { 0u8 };
let mut bytes = [0u8; 32];
let x_bytes = x.to_bytes_be();
@@ -276,14 +277,14 @@ impl WCurve {
}
fn lift_x(&self, x_mont: &BigUint, parity: bool) -> Point {
let p_val = p();
let x = x_mont % &p_val;
let p_val = &*P;
let x = x_mont % p_val;
// y^2 = x^3 + Ax^2 + x (Montgomery)
let y_squared = (&x * &x * &x + BigUint::from(MONT_A) * &x * &x + &x) % &p_val;
let y_squared = (&x * &x * &x + BigUint::from(MONT_A) * &x * &x + &x) % p_val;
// Convert x to Weierstrass
let x_w = (&x + &self.conversion_from_m) % &p_val;
let x_w = (&x + &self.conversion_from_m) % p_val;
if let Some((y1, y2)) = prime_mod_sqrt(&y_squared, &p_val) {
if let Some((y1, y2)) = prime_mod_sqrt(&y_squared, p_val) {
let pt1 = Point::new(x_w.clone(), y1);
let pt2 = Point::new(x_w, y2);
if parity {
@@ -323,7 +324,7 @@ impl WCurve {
password: &str,
salt: &[u8; 16],
) -> [u8; 32] {
let inner = sha256_bytes(&format!("{}:{}", username, password).as_bytes().to_vec());
let inner = sha256_bytes(format!("{}:{}", username, password).as_bytes());
let mut input = Vec::with_capacity(16 + 32);
input.extend_from_slice(salt);
input.extend_from_slice(&inner);
@@ -415,8 +416,8 @@ pub async fn client_authenticate<S: AsyncReadExt + AsyncWriteExt + Unpin>(
let i_int = BigUint::from_bytes_be(&i);
let j_int = BigUint::from_bytes_be(&j);
let s_a_int = BigUint::from_bytes_be(&s_a);
let order = curve_order();
let scalar = ((&i_int * &j_int) + &s_a_int) % &order;
let order = &*CURVE_ORDER;
let scalar = ((&i_int * &j_int) + &s_a_int) % order;
let z_point = w_b_unblinded.scalar_mul(&scalar);
let (z, _) = w.to_montgomery(&z_point);

View File

@@ -135,6 +135,11 @@ async fn handle_client(
) -> Result<()> {
stream.set_nodelay(true)?;
// Set TCP socket buffers to 4MB (matching UDP path) for high throughput
let sock_ref = socket2::SockRef::from(&stream);
let _ = sock_ref.set_send_buffer_size(4 * 1024 * 1024);
let _ = sock_ref.set_recv_buffer_size(4 * 1024 * 1024);
send_hello(&mut stream).await?;
// Read 16-byte command (or whatever the client sends)
@@ -575,8 +580,10 @@ async fn tcp_tx_loop_inner(
) {
tokio::time::sleep(Duration::from_millis(100)).await;
let packet = vec![0u8; tx_size];
let mut interval = bandwidth::calc_send_interval(tx_speed, tx_size as u16);
// Use larger writes when running unlimited to reduce syscall overhead
let effective_size = if interval.is_none() { tx_size.max(256 * 1024) } else { tx_size };
let packet = vec![0u8; effective_size];
let mut next_send = Instant::now();
let mut next_status = Instant::now() + Duration::from_secs(1);
let mut status_seq: u32 = 0;
@@ -599,14 +606,14 @@ async fn tcp_tx_loop_inner(
next_status = Instant::now() + Duration::from_secs(1);
}
if !state.spend_budget(tx_size as u64) {
if !state.spend_budget(effective_size as u64) {
break;
}
if writer.write_all(&packet).await.is_err() {
state.running.store(false, Ordering::SeqCst);
break;
}
state.tx_bytes.fetch_add(tx_size as u64, Ordering::Relaxed);
state.tx_bytes.fetch_add(effective_size as u64, Ordering::Relaxed);
if state.tx_speed_changed.load(Ordering::Relaxed) {
state.tx_speed_changed.store(false, Ordering::Relaxed);
@@ -630,7 +637,7 @@ async fn tcp_tx_loop_inner(
}
async fn tcp_rx_loop(mut reader: tokio::net::tcp::OwnedReadHalf, state: Arc<BandwidthState>) {
let mut buf = vec![0u8; 65536];
let mut buf = vec![0u8; 256 * 1024];
while state.running.load(Ordering::Relaxed) {
match reader.read(&mut buf).await {
Ok(0) | Err(_) => {

View File

@@ -15,6 +15,22 @@ pub struct LdapAuth {
config: LdapConfig,
}
/// Returns a copy of `input` that can be embedded as a value inside an LDAP
/// search filter, following the escaping rules of RFC 4515.
///
/// Backslash, `*`, `(`, `)` and the NUL character are each replaced by a
/// backslash followed by the character's two-digit lowercase hex code; every
/// other character is copied through unchanged.
fn ldap_escape(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, ch| {
            // Look up the escape sequence for filter metacharacters, if any.
            let replacement = match ch {
                '\\' => Some("\\5c"),
                '*' => Some("\\2a"),
                '(' => Some("\\28"),
                ')' => Some("\\29"),
                '\0' => Some("\\00"),
                _ => None,
            };
            match replacement {
                Some(seq) => escaped.push_str(seq),
                None => escaped.push(ch),
            }
            escaped
        })
}
impl LdapAuth {
pub fn new(config: LdapConfig) -> Self {
Self { config }
@@ -26,6 +42,8 @@ impl LdapAuth {
let (conn, mut ldap) = LdapConnAsync::new(&self.config.url).await?;
ldap3::drive!(conn);
let safe_username = ldap_escape(username);
// If service account configured, bind first to search for user DN
let user_dn = if let (Some(ref bind_dn), Some(ref bind_pass)) =
(&self.config.bind_dn, &self.config.bind_pass)
@@ -39,7 +57,7 @@ impl LdapAuth {
// Search for the user
let filter = format!(
"(&(objectClass=person)(|(uid={})(sAMAccountName={})(cn={})))",
username, username, username
safe_username, safe_username, safe_username
);
let (results, _) = ldap
.search(&self.config.base_dn, Scope::Subtree, &filter, vec!["dn"])
@@ -51,11 +69,17 @@ impl LdapAuth {
return Ok(false);
}
let entry = SearchEntry::construct(results.into_iter().next().unwrap());
let entry = match results.into_iter().next() {
Some(r) => SearchEntry::construct(r),
None => {
tracing::debug!("LDAP user not found: {}", username);
return Ok(false);
}
};
entry.dn
} else {
// No service account — construct DN directly
format!("uid={},{}", username, self.config.base_dn)
format!("uid={},{}", safe_username, self.config.base_dn)
};
// Attempt user bind