diff --git a/README.md b/README.md index 9df8018..364463b 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ async fn main() -> Result<(), Box> { }); let mut ipstack_config = ipstack::IpStackConfig::default(); - ipstack_config.mtu(MTU); + ipstack_config.mtu(MTU)?; let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun::create_as_async(&config)?); while let Ok(stream) = ip_stack.accept().await { diff --git a/examples/tun.rs b/examples/tun.rs index 5b2fad3..3ba8093 100644 --- a/examples/tun.rs +++ b/examples/tun.rs @@ -92,7 +92,7 @@ async fn main() -> Result<(), Box> { }); let mut ipstack_config = ipstack::IpStackConfig::default(); - ipstack_config.mtu(MTU); + ipstack_config.mtu(MTU)?; let mut tcp_config = ipstack::TcpConfig::default(); tcp_config.timeout = std::time::Duration::from_secs(args.tcp_timeout); tcp_config.options = Some(vec![ipstack::TcpOptions::MaximumSegmentSize(1460)]); diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index 9073f20..4d6649a 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -40,7 +40,7 @@ async fn main() -> Result<(), Box> { // }); let mut ipstack_config = ipstack::IpStackConfig::default(); - ipstack_config.mtu(MTU); + ipstack_config.mtu(MTU)?; // ipstack_config.packet_information(cfg!(target_family = "unix")); #[cfg(not(target_os = "windows"))] diff --git a/src/error.rs b/src/error.rs index 839b8b0..643f637 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,6 +34,10 @@ pub enum IpStackError { /// Error sending data through a channel. #[error("Send Error {0}")] SendError(#[from] Box>), + + /// Invalid MTU size. The minimum MTU is 1280 bytes to comply with IPv6 standards. + #[error("Invalid MTU size: {0} (bytes). Minimum MTU is 1280 bytes.")] + InvalidMtuSize(u16), } impl From> for IpStackError { diff --git a/src/lib.rs b/src/lib.rs index a1d6e77..e67072a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,7 +58,7 @@ const MIN_MTU: u16 = 1280; /// use std::time::Duration; /// /// let mut config = IpStackConfig::default(); -/// config.mtu(1500) +/// config.mtu(1500).expect("Failed to set MTU") /// .udp_timeout(Duration::from_secs(60)) /// .packet_information(false); /// ``` @@ -140,9 +140,18 @@ impl IpStackConfig { /// use ipstack::IpStackConfig; /// /// let mut config = IpStackConfig::default(); - /// config.mtu(1500); + /// config.mtu(1500).expect("Failed to set MTU"); /// ``` - pub fn mtu(&mut self, mtu: u16) -> &mut Self { + pub fn mtu(&mut self, mtu: u16) -> Result<&mut Self, IpStackError> { + if mtu < MIN_MTU { + return Err(IpStackError::InvalidMtuSize(mtu)); + } + self.mtu = mtu; + Ok(self) + } + + /// Set the Maximum Transmission Unit (MTU) size without validation. + pub fn mtu_unchecked(&mut self, mtu: u16) -> &mut Self { self.mtu = mtu; self } @@ -307,13 +316,6 @@ fn run( let mut buffer = vec![0_u8; config.mtu as usize + offset]; let (up_pkt_sender, mut up_pkt_receiver) = mpsc::unbounded_channel::(); - if config.mtu < MIN_MTU { - log::warn!( - "the MTU in the configuration ({}) below the MIN_MTU (1280) can cause problems.", - config.mtu - ); - } - tokio::spawn(async move { loop { select! { diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 93a484d..8917fbb 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -1,16 +1,16 @@ use super::seqnum::SeqNum; use etherparse::TcpHeader; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, time::Duration}; -const MAX_UNACK: u32 = 1024 * 16; // 16KB -const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB -const MAX_COUNT_FOR_DUP_ACK: usize = 3; // Maximum number of duplicate ACKs before retransmission +pub(super) const MAX_UNACK: u32 = 1024 * 16; // 16KB +pub(super) const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB +pub(super) const MAX_COUNT_FOR_DUP_ACK: usize = 3; // Maximum number of duplicate ACKs before retransmission /// Retransmission timeout -const RTO: std::time::Duration = std::time::Duration::from_secs(1); +pub(super) const RTO: std::time::Duration = std::time::Duration::from_secs(1); /// Maximum count of retransmissions before dropping the packet -pub(crate) const MAX_RETRANSMIT_COUNT: usize = 3; +pub(super) const MAX_RETRANSMIT_COUNT: usize = 3; #[derive(Debug, PartialEq, Clone, Copy)] pub(crate) enum TcpState { @@ -55,10 +55,23 @@ pub(crate) struct Tcb { unordered_packets: BTreeMap>, duplicate_ack_count: usize, duplicate_ack_count_helper: SeqNum, + max_unacked_bytes: u32, + read_buffer_size: usize, + max_count_for_dup_ack: usize, + rto: std::time::Duration, + max_retransmit_count: usize, } impl Tcb { - pub(super) fn new(ack: SeqNum, mtu: u16) -> Tcb { + pub(super) fn new( + ack: SeqNum, + mtu: u16, + max_unacked_bytes: u32, + read_buffer_size: usize, + max_count_for_dup_ack: usize, + rto: std::time::Duration, + max_retransmit_count: usize, + ) -> Tcb { #[cfg(debug_assertions)] let seq = 100; #[cfg(not(debug_assertions))] @@ -74,6 +87,11 @@ impl Tcb { unordered_packets: BTreeMap::new(), duplicate_ack_count: 0, duplicate_ack_count_helper: seq.into(), + max_unacked_bytes, + read_buffer_size, + max_count_for_dup_ack, + rto, + max_retransmit_count, } } @@ -94,7 +112,7 @@ impl Tcb { } pub fn is_duplicate_ack_count_exceeded(&self) -> bool { - self.duplicate_ack_count >= MAX_COUNT_FOR_DUP_ACK + self.duplicate_ack_count >= self.max_count_for_dup_ack } pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Vec) { @@ -106,7 +124,7 @@ impl Tcb { self.unordered_packets.insert(seq, buf); } pub(super) fn get_available_read_buffer_size(&self) -> usize { - READ_BUFFER_SIZE.saturating_sub(self.get_unordered_packets_total_len()) + self.read_buffer_size.saturating_sub(self.get_unordered_packets_total_len()) } #[inline] pub(crate) fn get_unordered_packets_total_len(&self) -> usize { @@ -234,7 +252,7 @@ impl Tcb { return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Empty payload")); } let buf_len = buf.len() as u32; - self.inflight_packets.insert(self.seq, InflightPacket::new(self.seq, buf)); + self.inflight_packets.insert(self.seq, InflightPacket::new(self.seq, buf, self.rto)); self.seq += buf_len; Ok(()) } @@ -275,7 +293,7 @@ impl Tcb { let mut retransmit_list = Vec::new(); self.inflight_packets.retain(|_, packet| { - if packet.retransmit_count >= MAX_RETRANSMIT_COUNT { + if packet.retransmit_count >= self.max_retransmit_count { log::warn!("Packet with seq {:?} reached max retransmit count, dropping packet", packet.seq); return false; // remove this packet } @@ -302,7 +320,7 @@ impl Tcb { pub fn is_send_buffer_full(&self) -> bool { // To respect the receiver's window (remote_window) size and avoid sending too many unacknowledged packets, which may cause packet loss // Simplified version: min(cwnd, rwnd) - self.seq.distance(self.get_last_received_ack()) >= MAX_UNACK.min(self.get_send_window() as u32) + self.seq.distance(self.get_last_received_ack()) >= self.max_unacked_bytes.min(self.get_send_window() as u32) } } @@ -316,13 +334,13 @@ pub struct InflightPacket { } impl InflightPacket { - fn new(seq: SeqNum, payload: Vec) -> Self { + fn new(seq: SeqNum, payload: Vec, rto: Duration) -> Self { Self { seq, payload, send_time: std::time::Instant::now(), retransmit_count: 0, - retransmit_timeout: RTO, + retransmit_timeout: rto, } } pub(crate) fn contains_seq_num(&self, seq: SeqNum) -> bool { @@ -339,7 +357,7 @@ mod tests { #[test] fn test_in_flight_packet() { - let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50]); + let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50], RTO); assert!(p.contains_seq_num((u32::MAX - 1).into())); assert!(p.contains_seq_num(u32::MAX.into())); @@ -352,7 +370,15 @@ mod tests { #[test] fn test_get_unordered_packets_with_max_bytes() { - let mut tcb = Tcb::new(SeqNum(1000), 1500); + let mut tcb = Tcb::new( + SeqNum(1000), + 1500, + MAX_UNACK, + READ_BUFFER_SIZE, + MAX_COUNT_FOR_DUP_ACK, + RTO, + MAX_RETRANSMIT_COUNT, + ); // insert 3 consecutive packets tcb.add_unordered_packet(SeqNum(1000), vec![1; 500]); // seq=1000, len=500 @@ -384,7 +410,15 @@ mod tests { #[test] fn test_update_inflight_packet_queue() { - let mut tcb = Tcb::new(SeqNum(1000), 1500); + let mut tcb = Tcb::new( + SeqNum(1000), + 1500, + MAX_UNACK, + READ_BUFFER_SIZE, + MAX_COUNT_FOR_DUP_ACK, + RTO, + MAX_RETRANSMIT_COUNT, + ); tcb.seq = SeqNum(100); // setting the initial seq // insert 3 consecutive packets @@ -408,7 +442,15 @@ mod tests { #[test] fn test_update_inflight_packet_queue_cumulative_ack() { - let mut tcb = Tcb::new(SeqNum(1000), 1500); + let mut tcb = Tcb::new( + SeqNum(1000), + 1500, + MAX_UNACK, + READ_BUFFER_SIZE, + MAX_COUNT_FOR_DUP_ACK, + RTO, + MAX_RETRANSMIT_COUNT, + ); tcb.seq = SeqNum(1000); // Insert 3 consecutive packets @@ -423,7 +465,15 @@ mod tests { #[test] fn test_retransmit_with_exponential_backoff() { - let mut tcb = Tcb::new(SeqNum(1000), 1500); + let mut tcb = Tcb::new( + SeqNum(1000), + 1500, + MAX_UNACK, + READ_BUFFER_SIZE, + MAX_COUNT_FOR_DUP_ACK, + RTO, + MAX_RETRANSMIT_COUNT, + ); tcb.add_inflight_packet(vec![1; 500]).unwrap(); diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 267978f..60be851 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -7,7 +7,7 @@ use crate::{ tcp_flags::{ACK, FIN, PSH, RST, SYN}, tcp_header_flags, tcp_header_fmt, }, - stream::tcb::{PacketType, Tcb, TcpState}, + stream::tcb::{MAX_COUNT_FOR_DUP_ACK, MAX_RETRANSMIT_COUNT, MAX_UNACK, PacketType, READ_BUFFER_SIZE, RTO, Tcb, TcpState}, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader, TcpOptionElement}; use std::{ @@ -43,6 +43,16 @@ pub struct TcpConfig { pub timeout: Duration, /// Timeout for the TIME_WAIT state. Default is 2 seconds. pub two_msl: Duration, + /// Maximum number of unacknowledged bytes allowed in the send buffer. + pub max_unacked_bytes: u32, + /// Size of the read buffer for incoming data. + pub read_buffer_size: usize, + /// Maximum number of duplicate ACKs before triggering fast retransmission. + pub max_count_for_dup_ack: usize, + /// Retransmission timeout duration. + pub rto: std::time::Duration, + /// Maximum number of retransmissions before giving up. + pub max_retransmit_count: usize, /// TCP options pub options: Option>, } @@ -62,6 +72,11 @@ impl Default for TcpConfig { close_wait_timeout: CLOSE_WAIT_TIMEOUT, timeout: TIMEOUT, two_msl: TWO_MSL, + max_unacked_bytes: MAX_UNACK, + read_buffer_size: READ_BUFFER_SIZE, + max_count_for_dup_ack: MAX_COUNT_FOR_DUP_ACK, + rto: RTO, + max_retransmit_count: MAX_RETRANSMIT_COUNT, options: Default::default(), } } @@ -169,7 +184,15 @@ impl IpStackTcpStream { destroy_messenger: Option<::tokio::sync::oneshot::Sender<()>>, config: Arc, ) -> Result { - let tcb = Tcb::new(SeqNum(tcp.sequence_number), mtu); + let tcb = Tcb::new( + SeqNum(tcp.sequence_number), + mtu, + config.max_unacked_bytes, + config.read_buffer_size, + config.max_count_for_dup_ack, + config.rto, + config.max_retransmit_count, + ); let tuple = NetworkTuple::new(src_addr, dst_addr, true); if !tcp.syn { if !tcp.rst