Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});

let mut ipstack_config = ipstack::IpStackConfig::default();
ipstack_config.mtu(MTU);
ipstack_config.mtu(MTU)?;
let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun::create_as_async(&config)?);

while let Ok(stream) = ip_stack.accept().await {
Expand Down
2 changes: 1 addition & 1 deletion examples/tun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});

let mut ipstack_config = ipstack::IpStackConfig::default();
ipstack_config.mtu(MTU);
ipstack_config.mtu(MTU)?;
let mut tcp_config = ipstack::TcpConfig::default();
tcp_config.timeout = std::time::Duration::from_secs(args.tcp_timeout);
tcp_config.options = Some(vec![ipstack::TcpOptions::MaximumSegmentSize(1460)]);
Expand Down
2 changes: 1 addition & 1 deletion examples/tun_wintun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// });

let mut ipstack_config = ipstack::IpStackConfig::default();
ipstack_config.mtu(MTU);
ipstack_config.mtu(MTU)?;
// ipstack_config.packet_information(cfg!(target_family = "unix"));

#[cfg(not(target_os = "windows"))]
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ pub enum IpStackError {
/// Error sending data through a channel.
#[error("Send Error {0}")]
SendError(#[from] Box<tokio::sync::mpsc::error::SendError<crate::stream::IpStackStream>>),

/// Invalid MTU size. The minimum MTU is 1280 bytes to comply with IPv6 standards.
#[error("Invalid MTU size: {0} (bytes). Minimum MTU is 1280 bytes.")]
InvalidMtuSize(u16),
}

impl From<tokio::sync::mpsc::error::SendError<crate::stream::IpStackStream>> for IpStackError {
Expand Down
22 changes: 12 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ const MIN_MTU: u16 = 1280;
/// use std::time::Duration;
///
/// let mut config = IpStackConfig::default();
/// config.mtu(1500)
/// config.mtu(1500).expect("Failed to set MTU")
/// .udp_timeout(Duration::from_secs(60))
/// .packet_information(false);
/// ```
Expand Down Expand Up @@ -140,9 +140,18 @@ impl IpStackConfig {
/// use ipstack::IpStackConfig;
///
/// let mut config = IpStackConfig::default();
/// config.mtu(1500);
/// config.mtu(1500).expect("Failed to set MTU");
/// ```
pub fn mtu(&mut self, mtu: u16) -> &mut Self {
pub fn mtu(&mut self, mtu: u16) -> Result<&mut Self, IpStackError> {
if mtu < MIN_MTU {
return Err(IpStackError::InvalidMtuSize(mtu));
}
self.mtu = mtu;
Ok(self)
}

/// Set the Maximum Transmission Unit (MTU) size without validation.
pub fn mtu_unchecked(&mut self, mtu: u16) -> &mut Self {
self.mtu = mtu;
self
}
Expand Down Expand Up @@ -307,13 +316,6 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
let mut buffer = vec![0_u8; config.mtu as usize + offset];
let (up_pkt_sender, mut up_pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

if config.mtu < MIN_MTU {
log::warn!(
"the MTU in the configuration ({}) below the MIN_MTU (1280) can cause problems.",
config.mtu
);
}

tokio::spawn(async move {
loop {
select! {
Expand Down
88 changes: 69 additions & 19 deletions src/stream/tcb.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use super::seqnum::SeqNum;
use etherparse::TcpHeader;
use std::collections::BTreeMap;
use std::{collections::BTreeMap, time::Duration};

const MAX_UNACK: u32 = 1024 * 16; // 16KB
const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB
const MAX_COUNT_FOR_DUP_ACK: usize = 3; // Maximum number of duplicate ACKs before retransmission
pub(super) const MAX_UNACK: u32 = 1024 * 16; // 16KB
pub(super) const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB
pub(super) const MAX_COUNT_FOR_DUP_ACK: usize = 3; // Maximum number of duplicate ACKs before retransmission

/// Retransmission timeout
const RTO: std::time::Duration = std::time::Duration::from_secs(1);
pub(super) const RTO: std::time::Duration = std::time::Duration::from_secs(1);

/// Maximum count of retransmissions before dropping the packet
pub(crate) const MAX_RETRANSMIT_COUNT: usize = 3;
pub(super) const MAX_RETRANSMIT_COUNT: usize = 3;

#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) enum TcpState {
Expand Down Expand Up @@ -55,10 +55,23 @@ pub(crate) struct Tcb {
unordered_packets: BTreeMap<SeqNum, Vec<u8>>,
duplicate_ack_count: usize,
duplicate_ack_count_helper: SeqNum,
max_unacked_bytes: u32,
read_buffer_size: usize,
max_count_for_dup_ack: usize,
rto: std::time::Duration,
max_retransmit_count: usize,
}

impl Tcb {
pub(super) fn new(ack: SeqNum, mtu: u16) -> Tcb {
pub(super) fn new(
ack: SeqNum,
mtu: u16,
max_unacked_bytes: u32,
read_buffer_size: usize,
max_count_for_dup_ack: usize,
rto: std::time::Duration,
max_retransmit_count: usize,
) -> Tcb {
#[cfg(debug_assertions)]
let seq = 100;
#[cfg(not(debug_assertions))]
Expand All @@ -74,6 +87,11 @@ impl Tcb {
unordered_packets: BTreeMap::new(),
duplicate_ack_count: 0,
duplicate_ack_count_helper: seq.into(),
max_unacked_bytes,
read_buffer_size,
max_count_for_dup_ack,
rto,
max_retransmit_count,
}
}

Expand All @@ -94,7 +112,7 @@ impl Tcb {
}

pub fn is_duplicate_ack_count_exceeded(&self) -> bool {
self.duplicate_ack_count >= MAX_COUNT_FOR_DUP_ACK
self.duplicate_ack_count >= self.max_count_for_dup_ack
}

pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Vec<u8>) {
Expand All @@ -106,7 +124,7 @@ impl Tcb {
self.unordered_packets.insert(seq, buf);
}
pub(super) fn get_available_read_buffer_size(&self) -> usize {
READ_BUFFER_SIZE.saturating_sub(self.get_unordered_packets_total_len())
self.read_buffer_size.saturating_sub(self.get_unordered_packets_total_len())
}
#[inline]
pub(crate) fn get_unordered_packets_total_len(&self) -> usize {
Expand Down Expand Up @@ -234,7 +252,7 @@ impl Tcb {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Empty payload"));
}
let buf_len = buf.len() as u32;
self.inflight_packets.insert(self.seq, InflightPacket::new(self.seq, buf));
self.inflight_packets.insert(self.seq, InflightPacket::new(self.seq, buf, self.rto));
self.seq += buf_len;
Ok(())
}
Expand Down Expand Up @@ -275,7 +293,7 @@ impl Tcb {
let mut retransmit_list = Vec::new();

self.inflight_packets.retain(|_, packet| {
if packet.retransmit_count >= MAX_RETRANSMIT_COUNT {
if packet.retransmit_count >= self.max_retransmit_count {
log::warn!("Packet with seq {:?} reached max retransmit count, dropping packet", packet.seq);
return false; // remove this packet
}
Expand All @@ -302,7 +320,7 @@ impl Tcb {
pub fn is_send_buffer_full(&self) -> bool {
// To respect the receiver's window (remote_window) size and avoid sending too many unacknowledged packets, which may cause packet loss
// Simplified version: min(cwnd, rwnd)
self.seq.distance(self.get_last_received_ack()) >= MAX_UNACK.min(self.get_send_window() as u32)
self.seq.distance(self.get_last_received_ack()) >= self.max_unacked_bytes.min(self.get_send_window() as u32)
}
}

Expand All @@ -316,13 +334,13 @@ pub struct InflightPacket {
}

impl InflightPacket {
fn new(seq: SeqNum, payload: Vec<u8>) -> Self {
fn new(seq: SeqNum, payload: Vec<u8>, rto: Duration) -> Self {
Self {
seq,
payload,
send_time: std::time::Instant::now(),
retransmit_count: 0,
retransmit_timeout: RTO,
retransmit_timeout: rto,
}
}
pub(crate) fn contains_seq_num(&self, seq: SeqNum) -> bool {
Expand All @@ -339,7 +357,7 @@ mod tests {

#[test]
fn test_in_flight_packet() {
let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50]);
let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50], RTO);

assert!(p.contains_seq_num((u32::MAX - 1).into()));
assert!(p.contains_seq_num(u32::MAX.into()));
Expand All @@ -352,7 +370,15 @@ mod tests {

#[test]
fn test_get_unordered_packets_with_max_bytes() {
let mut tcb = Tcb::new(SeqNum(1000), 1500);
let mut tcb = Tcb::new(
SeqNum(1000),
1500,
MAX_UNACK,
READ_BUFFER_SIZE,
MAX_COUNT_FOR_DUP_ACK,
RTO,
MAX_RETRANSMIT_COUNT,
);

// insert 3 consecutive packets
tcb.add_unordered_packet(SeqNum(1000), vec![1; 500]); // seq=1000, len=500
Expand Down Expand Up @@ -384,7 +410,15 @@ mod tests {

#[test]
fn test_update_inflight_packet_queue() {
let mut tcb = Tcb::new(SeqNum(1000), 1500);
let mut tcb = Tcb::new(
SeqNum(1000),
1500,
MAX_UNACK,
READ_BUFFER_SIZE,
MAX_COUNT_FOR_DUP_ACK,
RTO,
MAX_RETRANSMIT_COUNT,
);
tcb.seq = SeqNum(100); // setting the initial seq

// insert 3 consecutive packets
Expand All @@ -408,7 +442,15 @@ mod tests {

#[test]
fn test_update_inflight_packet_queue_cumulative_ack() {
let mut tcb = Tcb::new(SeqNum(1000), 1500);
let mut tcb = Tcb::new(
SeqNum(1000),
1500,
MAX_UNACK,
READ_BUFFER_SIZE,
MAX_COUNT_FOR_DUP_ACK,
RTO,
MAX_RETRANSMIT_COUNT,
);
tcb.seq = SeqNum(1000);

// Insert 3 consecutive packets
Expand All @@ -423,7 +465,15 @@ mod tests {

#[test]
fn test_retransmit_with_exponential_backoff() {
let mut tcb = Tcb::new(SeqNum(1000), 1500);
let mut tcb = Tcb::new(
SeqNum(1000),
1500,
MAX_UNACK,
READ_BUFFER_SIZE,
MAX_COUNT_FOR_DUP_ACK,
RTO,
MAX_RETRANSMIT_COUNT,
);

tcb.add_inflight_packet(vec![1; 500]).unwrap();

Expand Down
27 changes: 25 additions & 2 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
tcp_flags::{ACK, FIN, PSH, RST, SYN},
tcp_header_flags, tcp_header_fmt,
},
stream::tcb::{PacketType, Tcb, TcpState},
stream::tcb::{MAX_COUNT_FOR_DUP_ACK, MAX_RETRANSMIT_COUNT, MAX_UNACK, PacketType, READ_BUFFER_SIZE, RTO, Tcb, TcpState},
};
use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader, TcpOptionElement};
use std::{
Expand Down Expand Up @@ -43,6 +43,16 @@ pub struct TcpConfig {
pub timeout: Duration,
/// Timeout for the TIME_WAIT state. Default is 2 seconds.
pub two_msl: Duration,
/// Maximum number of unacknowledged bytes allowed in the send buffer.
pub max_unacked_bytes: u32,
/// Size of the read buffer for incoming data.
pub read_buffer_size: usize,
/// Maximum number of duplicate ACKs before triggering fast retransmission.
pub max_count_for_dup_ack: usize,
/// Retransmission timeout duration.
pub rto: std::time::Duration,
/// Maximum number of retransmissions before giving up.
pub max_retransmit_count: usize,
/// TCP options
pub options: Option<Vec<TcpOptions>>,
}
Expand All @@ -62,6 +72,11 @@ impl Default for TcpConfig {
close_wait_timeout: CLOSE_WAIT_TIMEOUT,
timeout: TIMEOUT,
two_msl: TWO_MSL,
max_unacked_bytes: MAX_UNACK,
read_buffer_size: READ_BUFFER_SIZE,
max_count_for_dup_ack: MAX_COUNT_FOR_DUP_ACK,
rto: RTO,
max_retransmit_count: MAX_RETRANSMIT_COUNT,
options: Default::default(),
}
}
Expand Down Expand Up @@ -169,7 +184,15 @@ impl IpStackTcpStream {
destroy_messenger: Option<::tokio::sync::oneshot::Sender<()>>,
config: Arc<TcpConfig>,
) -> Result<IpStackTcpStream, IpStackError> {
let tcb = Tcb::new(SeqNum(tcp.sequence_number), mtu);
let tcb = Tcb::new(
SeqNum(tcp.sequence_number),
mtu,
config.max_unacked_bytes,
config.read_buffer_size,
config.max_count_for_dup_ack,
config.rto,
config.max_retransmit_count,
);
let tuple = NetworkTuple::new(src_addr, dst_addr, true);
if !tcp.syn {
if !tcp.rst
Expand Down