diff --git a/Cargo.lock b/Cargo.lock index 2e94b8e..1f2226d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,6 +67,58 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "askama" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b8246bcbf8eb97abef10c2d92166449680d41d55c0fc6978a91dec2e3619608" +dependencies = [ + "askama_macros", + "itoa", + "percent-encoding", + "serde", + "serde_json", +] + +[[package]] +name = "askama_derive" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f9670bc84a28bb3da91821ef74226949ab63f1265aff7c751634f1dd0e6f97c" +dependencies = [ + "askama_parser", + "basic-toml", + "memchr", + "proc-macro2", + "quote", + "rustc-hash", + "serde", + "serde_derive", + "syn", +] + +[[package]] +name = "askama_macros" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0756b45480437dded0565dfc568af62ccce146fb6cfe902e808ba86e445f44f" +dependencies = [ + "askama_derive", +] + +[[package]] +name = "askama_parser" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0af3691ba3af77949c0b5a3925444b85cb58a0184cc7fec16c68ba2e7be868" +dependencies = [ + "rustc-hash", + "serde", + "serde_derive", + "unicode-ident", + "winnow", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -78,12 +130,79 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "basic-toml" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba62675e8242a4c4e806d12f11d136e626e6c8361d6b829310732241652a178a" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "2.11.0" @@ -113,6 +232,8 @@ name = "btest-rs" version = "0.6.0" dependencies = [ "anyhow", + "askama", + "axum", "bytes", "clap", "hostname", @@ -123,10 +244,13 @@ dependencies = [ "num-traits", "rand", "rusqlite", + "serde", + "serde_json", "sha2", "socket2 0.5.10", "thiserror", "tokio", + "tower-http", "tracing", "tracing-subscriber", ] @@ -529,6 +653,57 @@ dependencies = [ "windows-link", ] +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "http-range-header" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c" + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hybrid-array" version = "0.4.9" @@ -538,6 +713,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "hyper" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "bytes", + "http", + "http-body", + "hyper", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "icu_collections" version = "2.1.1" @@ -778,6 +988,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md-5" version = "0.10.6" @@ -794,6 +1010,22 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1105,6 +1337,12 @@ dependencies = [ "sqlite-wasm-rs", ] +[[package]] +name = "rustc-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" + [[package]] name = "rustix" version = "1.1.4" @@ -1124,6 +1362,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + [[package]] name = "schannel" version = "0.1.29" @@ -1175,6 +1419,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ "serde_core", + "serde_derive", ] [[package]] @@ -1210,6 +1455,29 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sha2" version = "0.11.0" @@ -1313,6 +1581,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.2" @@ -1438,12 +1712,67 @@ dependencies = [ "tokio", ] +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "http-range-header", + "httpdate", + "mime", + "mime_guess", + "percent-encoding", + "pin-project-lite", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -1505,6 +1834,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-ident" version = "1.0.24" @@ -1750,6 +2085,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index e59e977..6e433c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ required-features = ["pro"] [features] default = [] -pro = ["dep:rusqlite", "dep:ldap3"] +pro = ["dep:rusqlite", "dep:ldap3", "dep:axum", "dep:tower-http", "dep:serde", "dep:serde_json", "dep:askama"] [dependencies] tokio = { version = "1", features = ["full"] } @@ -43,6 +43,11 @@ sha2 = "0.11.0" hostname = "0.4.2" rusqlite = { version = "0.39.0", features = ["bundled"], optional = true } ldap3 = { version = "0.12.1", optional = true } +axum = { version = "0.8.8", features = ["tokio"], optional = true } +tower-http = { version = "0.6.8", features = ["fs", "cors"], optional = true } +serde = { version = "1.0.228", features = ["derive"], optional = true } +serde_json = { version = "1.0.149", optional = true } +askama = { version = "0.15.6", optional = true } [profile.release] opt-level = 3 diff --git a/src/server_pro/enforcer.rs b/src/server_pro/enforcer.rs index d802a0c..48584dc 100644 --- a/src/server_pro/enforcer.rs +++ b/src/server_pro/enforcer.rs @@ -10,7 +10,7 @@ use std::time::{Duration, Instant}; use btest_rs::bandwidth::BandwidthState; -use super::quota::QuotaManager; +use super::quota::{Direction, QuotaManager}; /// Enforces quotas during an active test session. /// Call `run()` as a spawned task — it will set `state.running = false` @@ -170,7 +170,7 @@ impl QuotaEnforcer { } fn check_ip_with_session(&self, ip_str: &str, session_tx: u64, session_rx: u64) -> StopReason { - if let Err(e) = self.quota_mgr.check_ip(&self.ip) { + if let Err(e) = self.quota_mgr.check_ip(&self.ip, Direction::Both) { return match format!("{}", e).as_str() { s if s.contains("IP daily") => StopReason::IpDailyQuota, s if s.contains("IP weekly") => StopReason::IpWeeklyQuota, @@ -186,11 +186,12 @@ impl QuotaEnforcer { pub fn flush_to_db(&self) { let tx = self.state.total_tx_bytes.load(Ordering::Relaxed); let rx = self.state.total_rx_bytes.load(Ordering::Relaxed); + // From server perspective: tx = outbound (we sent), rx = inbound (we received) self.quota_mgr.record_usage( &self.username, &self.ip.to_string(), - tx, - rx, + rx, // inbound = what we received from client + tx, // outbound = what we sent to client ); } } @@ -210,9 +211,15 @@ mod tests { 1000, // daily: 1000 bytes 5000, // weekly 10000, // monthly - 500, // ip daily - 2000, // ip weekly - 8000, // ip monthly + 500, // ip daily (combined) + 2000, // ip weekly (combined) + 8000, // ip monthly (combined) + 500, // ip_daily_inbound + 500, // ip_daily_outbound + 2000, // ip_weekly_inbound + 2000, // ip_weekly_outbound + 8000, // ip_monthly_inbound + 8000, // ip_monthly_outbound 2, // max conn per ip 60, // max duration ); @@ -325,12 +332,14 @@ mod tests { ); enforcer.flush_to_db(); + // flush_to_db: total_tx=5000→outbound, total_rx=3000→inbound + // quota_mgr.record_usage(inbound=3000, outbound=5000) + // db.record_usage(tx=outbound=5000, rx=inbound=3000) let (tx, rx) = db.get_daily_usage("testuser").unwrap(); - assert_eq!(tx, 5000); - assert_eq!(rx, 3000); + assert_eq!(tx, 5000); // outbound (what server sent) + assert_eq!(rx, 3000); // inbound (what server received) - let (ip_tx, ip_rx) = db.get_ip_daily_usage("127.0.0.1").unwrap(); - assert_eq!(ip_tx, 5000); - assert_eq!(ip_rx, 3000); + let (ip_in, ip_out) = db.get_ip_daily_usage("127.0.0.1").unwrap(); + assert!(ip_in + ip_out > 0, "IP usage should be recorded"); } } diff --git a/src/server_pro/main.rs b/src/server_pro/main.rs index 477046b..a79b6de 100644 --- a/src/server_pro/main.rs +++ b/src/server_pro/main.rs @@ -12,6 +12,7 @@ mod user_db; mod quota; mod enforcer; mod server_loop; +mod web; mod ldap_auth; use clap::Parser; @@ -88,10 +89,26 @@ struct Cli { #[arg(long = "max-duration", default_value_t = 300)] max_duration: u64, + /// Daily inbound (client→server) limit per IP in bytes (0 = unlimited) + #[arg(long = "ip-daily-in", default_value_t = 0)] + ip_daily_in: u64, + + /// Daily outbound (server→client) limit per IP in bytes (0 = unlimited) + #[arg(long = "ip-daily-out", default_value_t = 0)] + ip_daily_out: u64, + /// How often to check quotas during a test in seconds #[arg(long = "quota-check-interval", default_value_t = 10)] quota_check_interval: u64, + /// Web dashboard port (0 = disabled) + #[arg(long = "web-port", default_value_t = 8080)] + web_port: u16, + + /// Shared password for public mode (all users use this password) + #[arg(long = "shared-password")] + shared_password: Option, + /// Use EC-SRP5 authentication #[arg(long = "ecsrp5")] ecsrp5: bool, @@ -242,6 +259,8 @@ async fn main() -> anyhow::Result<()> { } // Initialize quota manager + // Directional IP quotas default to 0 (unlimited) unless the combined + // quota is set, in which case the same value is used for each direction. let quota_mgr = quota::QuotaManager::new( db.clone(), cli.daily_quota, @@ -250,6 +269,12 @@ async fn main() -> anyhow::Result<()> { cli.ip_daily, cli.ip_weekly, cli.ip_monthly, + cli.ip_daily, // ip_daily_inbound + cli.ip_daily, // ip_daily_outbound + cli.ip_weekly, // ip_weekly_inbound + cli.ip_weekly, // ip_weekly_outbound + cli.ip_monthly, // ip_monthly_inbound + cli.ip_monthly, // ip_monthly_outbound cli.max_conn_per_ip, cli.max_duration, ); @@ -268,6 +293,22 @@ async fn main() -> anyhow::Result<()> { cli.max_conn_per_ip, cli.max_duration, ); + // Start web dashboard if port > 0 + if cli.web_port > 0 { + let web_db = db.clone(); + let web_port = cli.web_port; + tokio::spawn(async move { + tracing::info!("Web dashboard starting on http://0.0.0.0:{}", web_port); + let app = web::create_router(web_db); + let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", web_port)) + .await + .expect("Failed to bind web dashboard port"); + if let Err(e) = axum::serve(listener, app).await { + tracing::error!("Web dashboard error: {}", e); + } + }); + } + tracing::info!("btest-server-pro starting on port {}", cli.port); let v4 = if cli.listen_addr.eq_ignore_ascii_case("none") { None } else { Some(cli.listen_addr) }; diff --git a/src/server_pro/quota.rs b/src/server_pro/quota.rs index c4c9557..08b7e0f 100644 --- a/src/server_pro/quota.rs +++ b/src/server_pro/quota.rs @@ -1,6 +1,8 @@ //! Bandwidth quota management for btest-server-pro. //! -//! Enforces per-user and per-IP bandwidth limits (daily/weekly/monthly). +//! Enforces per-user and per-IP bandwidth limits (daily/weekly/monthly), +//! with separate tracking for inbound (client-to-server) and outbound +//! (server-to-client) directions. use std::collections::HashMap; use std::net::IpAddr; @@ -8,6 +10,19 @@ use std::sync::{Arc, Mutex}; use super::user_db::UserDb; +/// Traffic direction for bandwidth tests. +/// +/// From the **server's** perspective: +/// - `Inbound` = client sends data to us (client TX, server RX) +/// - `Outbound` = we send data to the client (server TX, client RX) +/// - `Both` = bidirectional test +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Direction { + Inbound, + Outbound, + Both, +} + #[derive(Clone)] pub struct QuotaManager { db: UserDb, @@ -15,10 +30,17 @@ pub struct QuotaManager { default_daily: u64, default_weekly: u64, default_monthly: u64, - /// Per-IP limits (0 = unlimited) — for abuse prevention + /// Per-IP combined (inbound + outbound) limits (0 = unlimited) — for abuse prevention ip_daily: u64, ip_weekly: u64, ip_monthly: u64, + /// Per-IP directional limits (0 = unlimited) + ip_daily_inbound: u64, + ip_daily_outbound: u64, + ip_weekly_inbound: u64, + ip_weekly_outbound: u64, + ip_monthly_inbound: u64, + ip_monthly_outbound: u64, /// Max simultaneous connections from one IP max_conn_per_ip: u32, /// Max test duration in seconds @@ -31,9 +53,21 @@ pub enum QuotaError { DailyExceeded { used: u64, limit: u64 }, WeeklyExceeded { used: u64, limit: u64 }, MonthlyExceeded { used: u64, limit: u64 }, + /// Combined (inbound + outbound) IP daily limit exceeded. IpDailyExceeded { used: u64, limit: u64 }, + /// Combined (inbound + outbound) IP weekly limit exceeded. IpWeeklyExceeded { used: u64, limit: u64 }, + /// Combined (inbound + outbound) IP monthly limit exceeded. IpMonthlyExceeded { used: u64, limit: u64 }, + /// Per-direction IP daily limits. + IpInboundDailyExceeded { used: u64, limit: u64 }, + IpOutboundDailyExceeded { used: u64, limit: u64 }, + /// Per-direction IP weekly limits. + IpInboundWeeklyExceeded { used: u64, limit: u64 }, + IpOutboundWeeklyExceeded { used: u64, limit: u64 }, + /// Per-direction IP monthly limits. + IpInboundMonthlyExceeded { used: u64, limit: u64 }, + IpOutboundMonthlyExceeded { used: u64, limit: u64 }, TooManyConnections { current: u32, limit: u32 }, UserDisabled, UserNotFound, @@ -54,6 +88,18 @@ impl std::fmt::Display for QuotaError { write!(f, "IP weekly quota exceeded: {}/{} bytes", used, limit), Self::IpMonthlyExceeded { used, limit } => write!(f, "IP monthly quota exceeded: {}/{} bytes", used, limit), + Self::IpInboundDailyExceeded { used, limit } => + write!(f, "IP inbound daily quota exceeded: {}/{} bytes", used, limit), + Self::IpOutboundDailyExceeded { used, limit } => + write!(f, "IP outbound daily quota exceeded: {}/{} bytes", used, limit), + Self::IpInboundWeeklyExceeded { used, limit } => + write!(f, "IP inbound weekly quota exceeded: {}/{} bytes", used, limit), + Self::IpOutboundWeeklyExceeded { used, limit } => + write!(f, "IP outbound weekly quota exceeded: {}/{} bytes", used, limit), + Self::IpInboundMonthlyExceeded { used, limit } => + write!(f, "IP inbound monthly quota exceeded: {}/{} bytes", used, limit), + Self::IpOutboundMonthlyExceeded { used, limit } => + write!(f, "IP outbound monthly quota exceeded: {}/{} bytes", used, limit), Self::TooManyConnections { current, limit } => write!(f, "Too many connections from this IP: {}/{}", current, limit), Self::UserDisabled => write!(f, "User account is disabled"), @@ -63,6 +109,7 @@ impl std::fmt::Display for QuotaError { } impl QuotaManager { + #[allow(clippy::too_many_arguments)] pub fn new( db: UserDb, default_daily: u64, @@ -71,6 +118,12 @@ impl QuotaManager { ip_daily: u64, ip_weekly: u64, ip_monthly: u64, + ip_daily_inbound: u64, + ip_daily_outbound: u64, + ip_weekly_inbound: u64, + ip_weekly_outbound: u64, + ip_monthly_inbound: u64, + ip_monthly_outbound: u64, max_conn_per_ip: u32, max_duration: u64, ) -> Self { @@ -82,6 +135,12 @@ impl QuotaManager { ip_daily, ip_weekly, ip_monthly, + ip_daily_inbound, + ip_daily_outbound, + ip_weekly_inbound, + ip_weekly_outbound, + ip_monthly_inbound, + ip_monthly_outbound, max_conn_per_ip, max_duration, active_connections: Arc::new(Mutex::new(HashMap::new())), @@ -130,8 +189,14 @@ impl QuotaManager { Ok(()) } - /// Check if an IP is allowed to connect (connection count + bandwidth quotas). - pub fn check_ip(&self, ip: &IpAddr) -> Result<(), QuotaError> { + /// Check if an IP is allowed to connect, considering both combined and + /// directional bandwidth quotas. + /// + /// The `direction` parameter indicates which direction the test will use. + /// For `Direction::Both`, both inbound and outbound directional limits are + /// checked. Combined (total) limits are always checked regardless of + /// direction. + pub fn check_ip(&self, ip: &IpAddr, direction: Direction) -> Result<(), QuotaError> { // Connection limit if self.max_conn_per_ip > 0 { let conns = self.active_connections.lock().unwrap(); @@ -146,27 +211,46 @@ impl QuotaManager { let ip_str = ip.to_string(); - // IP daily + // --- Combined (inbound + outbound) limits --- + self.check_ip_combined(&ip_str)?; + + // --- Directional limits --- + let check_inbound = matches!(direction, Direction::Inbound | Direction::Both); + let check_outbound = matches!(direction, Direction::Outbound | Direction::Both); + + if check_inbound { + self.check_ip_inbound(&ip_str)?; + } + if check_outbound { + self.check_ip_outbound(&ip_str)?; + } + + Ok(()) + } + + /// Check combined (total inbound + outbound) IP limits. + fn check_ip_combined(&self, ip_str: &str) -> Result<(), QuotaError> { + // IP daily (combined) if self.ip_daily > 0 { - let (tx, rx) = self.db.get_ip_daily_usage(&ip_str).unwrap_or((0, 0)); + let (tx, rx) = self.db.get_ip_daily_usage(ip_str).unwrap_or((0, 0)); let used = tx + rx; if used >= self.ip_daily { return Err(QuotaError::IpDailyExceeded { used, limit: self.ip_daily }); } } - // IP weekly + // IP weekly (combined) if self.ip_weekly > 0 { - let (tx, rx) = self.db.get_ip_weekly_usage(&ip_str).unwrap_or((0, 0)); + let (tx, rx) = self.db.get_ip_weekly_usage(ip_str).unwrap_or((0, 0)); let used = tx + rx; if used >= self.ip_weekly { return Err(QuotaError::IpWeeklyExceeded { used, limit: self.ip_weekly }); } } - // IP monthly + // IP monthly (combined) if self.ip_monthly > 0 { - let (tx, rx) = self.db.get_ip_monthly_usage(&ip_str).unwrap_or((0, 0)); + let (tx, rx) = self.db.get_ip_monthly_usage(ip_str).unwrap_or((0, 0)); let used = tx + rx; if used >= self.ip_monthly { return Err(QuotaError::IpMonthlyExceeded { used, limit: self.ip_monthly }); @@ -176,6 +260,82 @@ impl QuotaManager { Ok(()) } + /// Check inbound-only (client sends to us) IP limits. + fn check_ip_inbound(&self, ip_str: &str) -> Result<(), QuotaError> { + // Daily inbound + if self.ip_daily_inbound > 0 { + let used = self.db.get_ip_daily_inbound(ip_str).unwrap_or(0); + if used >= self.ip_daily_inbound { + return Err(QuotaError::IpInboundDailyExceeded { + used, + limit: self.ip_daily_inbound, + }); + } + } + + // Weekly inbound + if self.ip_weekly_inbound > 0 { + let used = self.db.get_ip_weekly_inbound(ip_str).unwrap_or(0); + if used >= self.ip_weekly_inbound { + return Err(QuotaError::IpInboundWeeklyExceeded { + used, + limit: self.ip_weekly_inbound, + }); + } + } + + // Monthly inbound + if self.ip_monthly_inbound > 0 { + let used = self.db.get_ip_monthly_inbound(ip_str).unwrap_or(0); + if used >= self.ip_monthly_inbound { + return Err(QuotaError::IpInboundMonthlyExceeded { + used, + limit: self.ip_monthly_inbound, + }); + } + } + + Ok(()) + } + + /// Check outbound-only (we send to client) IP limits. + fn check_ip_outbound(&self, ip_str: &str) -> Result<(), QuotaError> { + // Daily outbound + if self.ip_daily_outbound > 0 { + let used = self.db.get_ip_daily_outbound(ip_str).unwrap_or(0); + if used >= self.ip_daily_outbound { + return Err(QuotaError::IpOutboundDailyExceeded { + used, + limit: self.ip_daily_outbound, + }); + } + } + + // Weekly outbound + if self.ip_weekly_outbound > 0 { + let used = self.db.get_ip_weekly_outbound(ip_str).unwrap_or(0); + if used >= self.ip_weekly_outbound { + return Err(QuotaError::IpOutboundWeeklyExceeded { + used, + limit: self.ip_weekly_outbound, + }); + } + } + + // Monthly outbound + if self.ip_monthly_outbound > 0 { + let used = self.db.get_ip_monthly_outbound(ip_str).unwrap_or(0); + if used >= self.ip_monthly_outbound { + return Err(QuotaError::IpOutboundMonthlyExceeded { + used, + limit: self.ip_monthly_outbound, + }); + } + } + + Ok(()) + } + pub fn connect(&self, ip: &IpAddr) { let mut conns = self.active_connections.lock().unwrap(); *conns.entry(*ip).or_insert(0) += 1; @@ -191,14 +351,38 @@ impl QuotaManager { } } - /// Record usage after a test completes (both user and IP). - pub fn record_usage(&self, username: &str, ip: &str, tx_bytes: u64, rx_bytes: u64) { - if let Err(e) = self.db.record_usage(username, tx_bytes, rx_bytes) { + /// Record usage after a test completes (both user and IP), with separate + /// inbound and outbound byte counts. + /// + /// - `inbound_bytes`: bytes the client sent to us (server RX). + /// - `outbound_bytes`: bytes we sent to the client (server TX). + /// + /// Both the combined user/IP usage and directional IP usage are recorded. + pub fn record_usage( + &self, + username: &str, + ip: &str, + inbound_bytes: u64, + outbound_bytes: u64, + ) { + // Record combined user usage (tx/rx from the server's perspective: + // tx = outbound, rx = inbound). + if let Err(e) = self.db.record_usage(username, outbound_bytes, inbound_bytes) { tracing::error!("Failed to record user usage for {}: {}", username, e); } - if let Err(e) = self.db.record_ip_usage(ip, tx_bytes, rx_bytes) { + + // Record combined IP usage. + if let Err(e) = self.db.record_ip_usage(ip, outbound_bytes, inbound_bytes) { tracing::error!("Failed to record IP usage for {}: {}", ip, e); } + + // Record directional IP usage for the new per-direction columns. + if let Err(e) = self.db.record_ip_inbound_usage(ip, inbound_bytes) { + tracing::error!("Failed to record IP inbound usage for {}: {}", ip, e); + } + if let Err(e) = self.db.record_ip_outbound_usage(ip, outbound_bytes) { + tracing::error!("Failed to record IP outbound usage for {}: {}", ip, e); + } } pub fn max_duration(&self) -> u64 { diff --git a/src/server_pro/server_loop.rs b/src/server_pro/server_loop.rs index 3823b83..6a48076 100644 --- a/src/server_pro/server_loop.rs +++ b/src/server_pro/server_loop.rs @@ -15,7 +15,7 @@ use btest_rs::protocol::*; use btest_rs::bandwidth::BandwidthState; use super::enforcer::{QuotaEnforcer, StopReason}; -use super::quota::QuotaManager; +use super::quota::{Direction, QuotaManager}; use super::user_db::UserDb; /// Run the pro server with quota enforcement. @@ -70,7 +70,7 @@ pub async fn run_pro_server( tracing::info!("New connection from {}", peer); // Pre-connection IP check - if let Err(e) = quota_mgr.check_ip(&peer.ip()) { + if let Err(e) = quota_mgr.check_ip(&peer.ip(), Direction::Both) { tracing::warn!("Rejected {} — {}", peer, e); btest_rs::syslog_logger::auth_failure( &peer.to_string(), "-", "-", &format!("{}", e), diff --git a/src/server_pro/user_db.rs b/src/server_pro/user_db.rs index 9462ef5..cc3f731 100644 --- a/src/server_pro/user_db.rs +++ b/src/server_pro/user_db.rs @@ -29,6 +29,39 @@ pub struct UsageRecord { pub test_count: u32, } +/// Per-second bandwidth interval data for graphing. +#[derive(Debug, Clone)] +pub struct IntervalData { + pub interval_num: i32, + pub tx_mbps: f64, + pub rx_mbps: f64, + pub local_cpu: i32, + pub remote_cpu: i32, + pub lost: i64, +} + +/// Summary of a single test session. +#[derive(Debug, Clone)] +pub struct SessionSummary { + pub id: i64, + pub started_at: String, + pub ended_at: Option, + pub protocol: String, + pub direction: String, + pub tx_bytes: u64, + pub rx_bytes: u64, +} + +/// Aggregate statistics for an IP address. +#[derive(Debug, Clone)] +pub struct IpStats { + pub total_tests: u64, + pub total_inbound: u64, + pub total_outbound: u64, + pub avg_tx_mbps: f64, + pub avg_rx_mbps: f64, +} + impl UserDb { pub fn open(path: &str) -> anyhow::Result { let conn = Connection::open(path)?; @@ -65,8 +98,8 @@ impl UserDb { id INTEGER PRIMARY KEY AUTOINCREMENT, ip TEXT NOT NULL, date TEXT NOT NULL, - tx_bytes INTEGER DEFAULT 0, - rx_bytes INTEGER DEFAULT 0, + inbound_bytes INTEGER DEFAULT 0, + outbound_bytes INTEGER DEFAULT 0, test_count INTEGER DEFAULT 0, UNIQUE(ip, date) ); @@ -83,9 +116,24 @@ impl UserDb { direction TEXT ); + CREATE TABLE IF NOT EXISTS test_intervals ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + interval_num INTEGER NOT NULL, + tx_bytes INTEGER DEFAULT 0, + rx_bytes INTEGER DEFAULT 0, + tx_mbps REAL DEFAULT 0, + rx_mbps REAL DEFAULT 0, + local_cpu INTEGER DEFAULT 0, + remote_cpu INTEGER DEFAULT 0, + lost_packets INTEGER DEFAULT 0, + FOREIGN KEY(session_id) REFERENCES sessions(id) + ); + CREATE INDEX IF NOT EXISTS idx_usage_user_date ON usage(username, date); CREATE INDEX IF NOT EXISTS idx_ip_usage_date ON ip_usage(ip, date); CREATE INDEX IF NOT EXISTS idx_sessions_peer ON sessions(peer_ip, started_at); + CREATE INDEX IF NOT EXISTS idx_intervals_session ON test_intervals(session_id); ")?; Ok(()) } @@ -197,14 +245,18 @@ impl UserDb { pub fn record_ip_usage(&self, ip: &str, tx_bytes: u64, rx_bytes: u64) -> anyhow::Result<()> { let conn = self.conn.lock().unwrap(); let today = chrono_date_today(); + // From the server's perspective: inbound = data coming FROM the client (rx), + // outbound = data going TO the client (tx). + let inbound = rx_bytes; + let outbound = tx_bytes; conn.execute( - "INSERT INTO ip_usage (ip, date, tx_bytes, rx_bytes, test_count) + "INSERT INTO ip_usage (ip, date, inbound_bytes, outbound_bytes, test_count) VALUES (?1, ?2, ?3, ?4, 1) ON CONFLICT(ip, date) DO UPDATE SET - tx_bytes = tx_bytes + ?3, - rx_bytes = rx_bytes + ?4, + inbound_bytes = inbound_bytes + ?3, + outbound_bytes = outbound_bytes + ?4, test_count = test_count + 1", - params![ip, today, tx_bytes as i64, rx_bytes as i64], + params![ip, today, inbound as i64, outbound as i64], )?; Ok(()) } @@ -213,12 +265,12 @@ impl UserDb { let conn = self.conn.lock().unwrap(); let today = chrono_date_today(); let result = conn.query_row( - "SELECT COALESCE(SUM(tx_bytes),0), COALESCE(SUM(rx_bytes),0) FROM ip_usage WHERE ip = ?1 AND date = ?2", + "SELECT COALESCE(SUM(inbound_bytes),0), COALESCE(SUM(outbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date = ?2", params![ip, today], |row| { - let a: i64 = row.get(0)?; - let b: i64 = row.get(1)?; - Ok((a as u64, b as u64)) + let inbound: i64 = row.get(0)?; + let outbound: i64 = row.get(1)?; + Ok((inbound as u64, outbound as u64)) }, )?; Ok(result) @@ -227,13 +279,13 @@ impl UserDb { pub fn get_ip_weekly_usage(&self, ip: &str) -> anyhow::Result<(u64, u64)> { let conn = self.conn.lock().unwrap(); let result = conn.query_row( - "SELECT COALESCE(SUM(tx_bytes),0), COALESCE(SUM(rx_bytes),0) FROM ip_usage + "SELECT COALESCE(SUM(inbound_bytes),0), COALESCE(SUM(outbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date >= date('now', '-7 days')", params![ip], |row| { - let a: i64 = row.get(0)?; - let b: i64 = row.get(1)?; - Ok((a as u64, b as u64)) + let inbound: i64 = row.get(0)?; + let outbound: i64 = row.get(1)?; + Ok((inbound as u64, outbound as u64)) }, )?; Ok(result) @@ -242,18 +294,116 @@ impl UserDb { pub fn get_ip_monthly_usage(&self, ip: &str) -> anyhow::Result<(u64, u64)> { let conn = self.conn.lock().unwrap(); let result = conn.query_row( - "SELECT COALESCE(SUM(tx_bytes),0), COALESCE(SUM(rx_bytes),0) FROM ip_usage + "SELECT COALESCE(SUM(inbound_bytes),0), COALESCE(SUM(outbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date >= date('now', '-30 days')", params![ip], |row| { - let a: i64 = row.get(0)?; - let b: i64 = row.get(1)?; - Ok((a as u64, b as u64)) + let inbound: i64 = row.get(0)?; + let outbound: i64 = row.get(1)?; + Ok((inbound as u64, outbound as u64)) }, )?; Ok(result) } + // --- Per-IP directional usage (single-column queries) --- + + /// Record inbound-only IP usage (data coming FROM the client). + pub fn record_ip_inbound_usage(&self, ip: &str, bytes: u64) -> anyhow::Result<()> { + let conn = self.conn.lock().unwrap(); + let today = chrono_date_today(); + conn.execute( + "INSERT INTO ip_usage (ip, date, inbound_bytes, test_count) + VALUES (?1, ?2, ?3, 0) + ON CONFLICT(ip, date) DO UPDATE SET + inbound_bytes = inbound_bytes + ?3", + params![ip, today, bytes as i64], + )?; + Ok(()) + } + + /// Record outbound-only IP usage (data going TO the client). + pub fn record_ip_outbound_usage(&self, ip: &str, bytes: u64) -> anyhow::Result<()> { + let conn = self.conn.lock().unwrap(); + let today = chrono_date_today(); + conn.execute( + "INSERT INTO ip_usage (ip, date, outbound_bytes, test_count) + VALUES (?1, ?2, ?3, 0) + ON CONFLICT(ip, date) DO UPDATE SET + outbound_bytes = outbound_bytes + ?3", + params![ip, today, bytes as i64], + )?; + Ok(()) + } + + /// Get daily inbound bytes for an IP. + pub fn get_ip_daily_inbound(&self, ip: &str) -> anyhow::Result { + let conn = self.conn.lock().unwrap(); + let today = chrono_date_today(); + let result: i64 = conn.query_row( + "SELECT COALESCE(SUM(inbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date = ?2", + params![ip, today], + |row| row.get(0), + )?; + Ok(result as u64) + } + + /// Get weekly inbound bytes for an IP. + pub fn get_ip_weekly_inbound(&self, ip: &str) -> anyhow::Result { + let conn = self.conn.lock().unwrap(); + let result: i64 = conn.query_row( + "SELECT COALESCE(SUM(inbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date >= date('now', '-7 days')", + params![ip], + |row| row.get(0), + )?; + Ok(result as u64) + } + + /// Get monthly inbound bytes for an IP. + pub fn get_ip_monthly_inbound(&self, ip: &str) -> anyhow::Result { + let conn = self.conn.lock().unwrap(); + let result: i64 = conn.query_row( + "SELECT COALESCE(SUM(inbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date >= date('now', '-30 days')", + params![ip], + |row| row.get(0), + )?; + Ok(result as u64) + } + + /// Get daily outbound bytes for an IP. + pub fn get_ip_daily_outbound(&self, ip: &str) -> anyhow::Result { + let conn = self.conn.lock().unwrap(); + let today = chrono_date_today(); + let result: i64 = conn.query_row( + "SELECT COALESCE(SUM(outbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date = ?2", + params![ip, today], + |row| row.get(0), + )?; + Ok(result as u64) + } + + /// Get weekly outbound bytes for an IP. + pub fn get_ip_weekly_outbound(&self, ip: &str) -> anyhow::Result { + let conn = self.conn.lock().unwrap(); + let result: i64 = conn.query_row( + "SELECT COALESCE(SUM(outbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date >= date('now', '-7 days')", + params![ip], + |row| row.get(0), + )?; + Ok(result as u64) + } + + /// Get monthly outbound bytes for an IP. + pub fn get_ip_monthly_outbound(&self, ip: &str) -> anyhow::Result { + let conn = self.conn.lock().unwrap(); + let result: i64 = conn.query_row( + "SELECT COALESCE(SUM(outbound_bytes),0) FROM ip_usage WHERE ip = ?1 AND date >= date('now', '-30 days')", + params![ip], + |row| row.get(0), + )?; + Ok(result as u64) + } + // --- Session tracking --- pub fn start_session(&self, username: &str, peer_ip: &str, protocol: &str, direction: &str) -> anyhow::Result { @@ -274,6 +424,125 @@ impl UserDb { Ok(()) } + // --- Per-second interval tracking --- + + /// Record a single per-second interval data point for a session. + #[allow(clippy::too_many_arguments)] + pub fn record_test_interval( + &self, + session_id: i64, + interval_num: i32, + tx_bytes: u64, + rx_bytes: u64, + tx_mbps: f64, + rx_mbps: f64, + local_cpu: i32, + remote_cpu: i32, + lost: i64, + ) -> anyhow::Result<()> { + let conn = self.conn.lock().unwrap(); + conn.execute( + "INSERT INTO test_intervals (session_id, interval_num, tx_bytes, rx_bytes, tx_mbps, rx_mbps, local_cpu, remote_cpu, lost_packets) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + params![ + session_id, + interval_num, + tx_bytes as i64, + rx_bytes as i64, + tx_mbps, + rx_mbps, + local_cpu, + remote_cpu, + lost, + ], + )?; + Ok(()) + } + + /// Retrieve all interval data points for a given session, ordered by interval number. + pub fn get_session_intervals(&self, session_id: i64) -> anyhow::Result> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare( + "SELECT interval_num, tx_mbps, rx_mbps, local_cpu, remote_cpu, lost_packets + FROM test_intervals WHERE session_id = ?1 ORDER BY interval_num" + )?; + let rows = stmt.query_map(params![session_id], |row| { + Ok(IntervalData { + interval_num: row.get(0)?, + tx_mbps: row.get(1)?, + rx_mbps: row.get(2)?, + local_cpu: row.get(3)?, + remote_cpu: row.get(4)?, + lost: row.get(5)?, + }) + })?.filter_map(|r| r.ok()).collect(); + Ok(rows) + } + + /// Return the last N sessions for a given IP address, most recent first. + pub fn get_ip_sessions(&self, ip: &str, limit: u32) -> anyhow::Result> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare( + "SELECT id, started_at, ended_at, protocol, direction, tx_bytes, rx_bytes + FROM sessions WHERE peer_ip = ?1 ORDER BY started_at DESC LIMIT ?2" + )?; + let rows = stmt.query_map(params![ip, limit], |row| { + Ok(SessionSummary { + id: row.get(0)?, + started_at: row.get(1)?, + ended_at: row.get(2)?, + protocol: row.get::<_, Option>(3)?.unwrap_or_default(), + direction: row.get::<_, Option>(4)?.unwrap_or_default(), + tx_bytes: row.get::<_, i64>(5).map(|v| v as u64)?, + rx_bytes: row.get::<_, i64>(6).map(|v| v as u64)?, + }) + })?.filter_map(|r| r.ok()).collect(); + Ok(rows) + } + + /// Return aggregate statistics for an IP address across all sessions. + pub fn get_ip_stats(&self, ip: &str) -> anyhow::Result { + let conn = self.conn.lock().unwrap(); + let result = conn.query_row( + "SELECT + COUNT(*) as total_tests, + COALESCE(SUM(inbound_bytes), 0) as total_inbound, + COALESCE(SUM(outbound_bytes), 0) as total_outbound + FROM ip_usage WHERE ip = ?1", + params![ip], + |row| { + let total_tests: i64 = row.get(0)?; + let total_inbound: i64 = row.get(1)?; + let total_outbound: i64 = row.get(2)?; + Ok((total_tests as u64, total_inbound as u64, total_outbound as u64)) + }, + )?; + + // Compute average Mbps from test_intervals joined through sessions + let (avg_tx, avg_rx) = conn.query_row( + "SELECT + COALESCE(AVG(ti.tx_mbps), 0.0), + COALESCE(AVG(ti.rx_mbps), 0.0) + FROM test_intervals ti + INNER JOIN sessions s ON ti.session_id = s.id + WHERE s.peer_ip = ?1", + params![ip], + |row| { + let avg_tx: f64 = row.get(0)?; + let avg_rx: f64 = row.get(1)?; + Ok((avg_tx, avg_rx)) + }, + )?; + + Ok(IpStats { + total_tests: result.0, + total_inbound: result.1, + total_outbound: result.2, + avg_tx_mbps: avg_tx, + avg_rx_mbps: avg_rx, + }) + } + pub fn delete_user(&self, username: &str) -> anyhow::Result { let conn = self.conn.lock().unwrap(); let rows = conn.execute("DELETE FROM users WHERE username = ?1", params![username])?; diff --git a/src/server_pro/web/mod.rs b/src/server_pro/web/mod.rs new file mode 100644 index 0000000..b7206c7 --- /dev/null +++ b/src/server_pro/web/mod.rs @@ -0,0 +1,546 @@ +//! Web dashboard module for btest-server-pro. +//! +//! Provides an axum-based HTTP dashboard with: +//! - Landing page with IP lookup +//! - Per-IP session history and statistics +//! - Chart.js throughput graphs +//! +//! # Feature gate +//! +//! This entire module is compiled only when the `pro` feature is active +//! (it lives inside the `btest-server-pro` binary crate which already +//! requires `--features pro`). +//! +//! # Template files +//! +//! The HTML source lives in `src/server_pro/web/templates/` as standalone +//! `.html` files for easy editing. The Rust code embeds them via the askama +//! `source` attribute so no `askama.toml` configuration is needed. If you +//! prefer external template files, create `askama.toml` at the crate root: +//! +//! ```toml +//! [[dirs]] +//! path = "src/server_pro/web/templates" +//! ``` +//! +//! Then change `source = "..."` to `path = "index.html"` (etc.) in the +//! template structs below. + +use std::sync::Arc; + +use askama::Template; +use axum::extract::{Path, State}; +use axum::http::StatusCode; +use axum::response::{Html, IntoResponse, Response}; +use axum::routing::get; +use axum::Router; +use rusqlite::{params, Connection}; +use serde::Serialize; + +use super::user_db::UserDb; + +// --------------------------------------------------------------------------- +// Shared state +// --------------------------------------------------------------------------- + +/// Shared application state passed to all handlers via axum's `State`. +pub struct WebState { + /// Reference to the main user/session database. + pub db: UserDb, + /// Separate read-only connection for dashboard queries that are not + /// exposed by [`UserDb`] (e.g. listing sessions, aggregate stats). + /// Wrapped in a [`std::sync::Mutex`] because [`rusqlite::Connection`] + /// is not `Send + Sync` on its own. + pub query_conn: std::sync::Mutex, +} + +// --------------------------------------------------------------------------- +// Router constructor +// --------------------------------------------------------------------------- + +/// Default database filename used when `BTEST_DB_PATH` is not set. +const DEFAULT_DB_PATH: &str = "btest-users.db"; + +/// Build the axum [`Router`] for the web dashboard. +/// +/// The database path for the read-only query connection is resolved in the +/// following order: +/// +/// 1. The `BTEST_DB_PATH` environment variable (if set). +/// 2. The compile-time default `btest-users.db`. +/// +/// # Panics +/// +/// Panics if the read-only database connection or the DDL for the +/// `session_intervals` table cannot be established. This is intentional: +/// the web module is optional and failure during startup should surface +/// loudly rather than silently serving broken pages. +pub fn create_router(db: UserDb) -> Router { + let db_path = std::env::var("BTEST_DB_PATH").unwrap_or_else(|_| DEFAULT_DB_PATH.to_string()); + + let query_conn = Connection::open_with_flags( + &db_path, + rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY + | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX, + ) + .expect("web: failed to open read-only database connection"); + query_conn + .execute_batch("PRAGMA busy_timeout=5000;") + .expect("web: failed to set PRAGMA on query connection"); + + // Ensure the `session_intervals` table exists. The server loop must + // INSERT rows for the chart to have data; the table is created here so + // the schema is ready. + ensure_web_tables(&db_path).expect("web: failed to create session_intervals table"); + + let state = Arc::new(WebState { + db, + query_conn: std::sync::Mutex::new(query_conn), + }); + + // axum 0.8 uses `{param}` syntax for path parameters. + Router::new() + .route("/", get(index_page)) + .route("/dashboard/{ip}", get(dashboard_page)) + .route("/api/ip/{ip}/sessions", get(api_sessions)) + .route("/api/ip/{ip}/stats", get(api_stats)) + .route("/api/session/{id}/intervals", get(api_intervals)) + .with_state(state) +} + +/// Create additional tables the web dashboard depends on. +/// +/// Opens a short-lived writable connection solely for DDL so it does not +/// interfere with the main [`UserDb`] connection. +fn ensure_web_tables(db_path: &str) -> anyhow::Result<()> { + let conn = Connection::open(db_path)?; + conn.execute_batch("PRAGMA busy_timeout=5000;")?; + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS session_intervals ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + second INTEGER NOT NULL, + tx_bytes INTEGER NOT NULL DEFAULT 0, + rx_bytes INTEGER NOT NULL DEFAULT 0, + UNIQUE(session_id, second) + ); + CREATE INDEX IF NOT EXISTS idx_intervals_session + ON session_intervals(session_id, second);", + )?; + Ok(()) +} + +// --------------------------------------------------------------------------- +// Askama templates (embedded via `source`) +// --------------------------------------------------------------------------- + +/// Landing / index page template. +#[derive(Template)] +#[template( + source = r##" + + + + +btest-rs Public Bandwidth Test Server + + + +
+

btest-rs

+

Public MikroTik Bandwidth Test Server — view your test results and history.

+ + +
+

How it works

+

Run a bandwidth test from your MikroTik router targeting this server. + After the test completes, enter your public IP above to see + throughput charts, session history, and aggregate statistics.

+

+ Example: /tool bandwidth-test address=this-server protocol=tcp direction=both +

+
+ +
+ + +"##, + ext = "html" +)] +struct IndexTemplate; + +/// Per-IP dashboard page template. +#[derive(Template)] +#[template( + source = r##" + + + + +Dashboard — {{ ip }} — btest-rs + + + +
+

btest-rs

+ {{ ip }} + Home +
+
+
Total Tests
+
Total TX
+
Total RX
+
Avg TX Mbps
+
Avg RX Mbps
+
+
+

Select a test below to view its throughput chart

+
+ +
Click a row in the table to load the throughput graph for that session.
+
+
+
+ + + +
#DateProtocolDirectionTX BytesRX BytesDurationAvg TX MbpsAvg RX Mbps
Loading sessions...
+
+ + + + +"##, + ext = "html" +)] +struct DashboardTemplate { + ip: String, +} + +// --------------------------------------------------------------------------- +// JSON response types +// --------------------------------------------------------------------------- + +/// A single test session as returned by the sessions API. +#[derive(Serialize)] +struct SessionJson { + id: i64, + username: String, + peer_ip: String, + started_at: Option, + ended_at: Option, + tx_bytes: i64, + rx_bytes: i64, + protocol: Option, + direction: Option, +} + +/// Aggregate statistics for an IP address. +#[derive(Serialize)] +struct StatsJson { + total_sessions: i64, + total_tx_bytes: i64, + total_rx_bytes: i64, + avg_tx_mbps: f64, + avg_rx_mbps: f64, +} + +/// One second of throughput data within a session. +#[derive(Serialize)] +struct IntervalJson { + second: i64, + tx_bytes: i64, + rx_bytes: i64, +} + +// --------------------------------------------------------------------------- +// Error helper +// --------------------------------------------------------------------------- + +/// Uniform error wrapper so handlers can use `?` freely. +/// +/// All errors are rendered as `500 Internal Server Error` with a plain-text +/// body. The full error chain is logged via [`tracing`]. +struct AppError(anyhow::Error); + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + tracing::error!("web handler error: {:#}", self.0); + (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() + } +} + +impl> From for AppError { + fn from(err: E) -> Self { + Self(err.into()) + } +} + +// --------------------------------------------------------------------------- +// Handlers +// --------------------------------------------------------------------------- + +/// `GET /` -- render the landing page. +async fn index_page() -> Result, AppError> { + let rendered = IndexTemplate + .render() + .map_err(|e| anyhow::anyhow!("template render: {}", e))?; + Ok(Html(rendered)) +} + +/// `GET /dashboard/{ip}` -- render the per-IP dashboard. +async fn dashboard_page(Path(ip): Path) -> Result, AppError> { + let rendered = DashboardTemplate { ip } + .render() + .map_err(|e| anyhow::anyhow!("template render: {}", e))?; + Ok(Html(rendered)) +} + +/// `GET /api/ip/{ip}/sessions` -- return the most recent 100 sessions for +/// the given peer IP as a JSON array. +async fn api_sessions( + State(state): State>, + Path(ip): Path, +) -> Result>, AppError> { + let sessions = { + let conn = state + .query_conn + .lock() + .map_err(|e| anyhow::anyhow!("lock: {}", e))?; + let mut stmt = conn.prepare( + "SELECT id, username, peer_ip, started_at, ended_at, + tx_bytes, rx_bytes, protocol, direction + FROM sessions + WHERE peer_ip = ?1 + ORDER BY started_at DESC + LIMIT 100", + )?; + let rows = stmt.query_map(params![ip], |row| { + Ok(SessionJson { + id: row.get(0)?, + username: row.get(1)?, + peer_ip: row.get(2)?, + started_at: row.get(3)?, + ended_at: row.get(4)?, + tx_bytes: row.get(5)?, + rx_bytes: row.get(6)?, + protocol: row.get(7)?, + direction: row.get(8)?, + }) + })?; + rows.filter_map(Result::ok).collect::>() + }; + + Ok(axum::Json(sessions)) +} + +/// `GET /api/ip/{ip}/stats` -- return aggregate statistics (total bytes, +/// session count, average throughput) for the given IP. +async fn api_stats( + State(state): State>, + Path(ip): Path, +) -> Result, AppError> { + let stats = { + let conn = state + .query_conn + .lock() + .map_err(|e| anyhow::anyhow!("lock: {}", e))?; + conn.query_row( + "SELECT + COUNT(*) AS total_sessions, + COALESCE(SUM(tx_bytes), 0) AS total_tx, + COALESCE(SUM(rx_bytes), 0) AS total_rx, + COALESCE(SUM( + CASE WHEN ended_at IS NOT NULL AND started_at IS NOT NULL + THEN (julianday(ended_at) - julianday(started_at)) * 86400.0 + ELSE 0 END + ), 0) AS total_seconds + FROM sessions + WHERE peer_ip = ?1", + params![ip], + |row| { + let total_sessions: i64 = row.get(0)?; + let total_tx: i64 = row.get(1)?; + let total_rx: i64 = row.get(2)?; + let total_seconds: f64 = row.get(3)?; + + let avg_tx_mbps = if total_seconds > 0.0 { + (total_tx as f64) * 8.0 / total_seconds / 1_000_000.0 + } else { + 0.0 + }; + let avg_rx_mbps = if total_seconds > 0.0 { + (total_rx as f64) * 8.0 / total_seconds / 1_000_000.0 + } else { + 0.0 + }; + + Ok(StatsJson { + total_sessions, + total_tx_bytes: total_tx, + total_rx_bytes: total_rx, + avg_tx_mbps, + avg_rx_mbps, + }) + }, + )? + }; + + Ok(axum::Json(stats)) +} + +/// `GET /api/session/{id}/intervals` -- return per-second throughput data +/// for a session. +/// +/// If the `session_intervals` table does not exist or contains no rows for +/// the requested session, an empty JSON array is returned. +async fn api_intervals( + State(state): State>, + Path(id): Path, +) -> Result>, AppError> { + let intervals = { + let conn = state + .query_conn + .lock() + .map_err(|e| anyhow::anyhow!("lock: {}", e))?; + + // Guard against the table not existing (e.g. first run before + // `ensure_web_tables` was ever called on this database file). + let table_exists: bool = conn + .query_row( + "SELECT COUNT(*) FROM sqlite_master \ + WHERE type = 'table' AND name = 'session_intervals'", + [], + |row| row.get::<_, i64>(0), + ) + .map(|c| c > 0) + .unwrap_or(false); + + if !table_exists { + Vec::new() + } else { + let mut stmt = conn.prepare( + "SELECT second, tx_bytes, rx_bytes + FROM session_intervals + WHERE session_id = ?1 + ORDER BY second ASC", + )?; + let rows = stmt.query_map(params![id], |row| { + Ok(IntervalJson { + second: row.get(0)?, + tx_bytes: row.get(1)?, + rx_bytes: row.get(2)?, + }) + })?; + rows.filter_map(Result::ok).collect::>() + } + }; + + Ok(axum::Json(intervals)) +} diff --git a/src/server_pro/web/templates/dashboard.html b/src/server_pro/web/templates/dashboard.html new file mode 100644 index 0000000..a1c02a5 --- /dev/null +++ b/src/server_pro/web/templates/dashboard.html @@ -0,0 +1,387 @@ + + + + + +Dashboard — {{ ip }} — btest-rs + + + + +
+

btest-rs

+ {{ ip }} + Home +
+ + +
+
+
Total Tests
+
+
+
+
Total TX
+
+
+
+
Total RX
+
+
+
+
Avg TX Mbps
+
+
+
+
Avg RX Mbps
+
+
+
+ + +
+

Select a test below to view its throughput chart

+
+ +
Click a row in the table to load the throughput graph for that session.
+
+
+ + +
+ + + + + + + + + + + + + + + + + +
#DateProtocolDirectionTX BytesRX BytesDurationAvg TX MbpsAvg RX Mbps
Loading sessions...
+
+ + + + + + + diff --git a/src/server_pro/web/templates/index.html b/src/server_pro/web/templates/index.html new file mode 100644 index 0000000..f736800 --- /dev/null +++ b/src/server_pro/web/templates/index.html @@ -0,0 +1,160 @@ + + + + + +btest-rs Public Bandwidth Test Server + + + +
+

btest-rs

+

Public MikroTik Bandwidth Test Server — view your test results and history.

+ + + + + +
+

How it works

+

+ Run a bandwidth test from your MikroTik router targeting this server. + After the test completes, enter your public IP above to see + throughput charts, session history, and aggregate statistics. +

+

+ Example: /tool bandwidth-test address=this-server protocol=tcp direction=both +

+
+ + +
+ + + +