// Source: simple-rust-tests/__network/rust_tcp/src/tcp.rs (Rust, 569 lines, 19 KiB)

use bitflags::bitflags;
use std::collections::{BTreeMap, VecDeque};
use std::{io, time};
bitflags! {
// Readiness flags for a connection; `Connection::availability` builds a value of this type.
pub(crate) struct Available: u8 {
// Set when buffered incoming data exists or the receive side is considered closed.
const READ = 0b00000001;
// Declared for write readiness; nothing in the visible part of this file sets it yet.
const WRITE = 0b00000010;
}
}
/// Connection states modelled by this implementation.
///
/// Only a subset of the RFC 793 states is present; the TODOs elsewhere in this
/// file mention CLOSE-WAIT, LAST-ACK, CLOSING and CLOSED as not yet handled.
#[derive(Debug)]
enum State {
// Listen is not modelled: `Connection::accept` creates connections directly in `SynRcvd`.
//Listen,
/// Initial state set by `accept`, which sends a SYN+ACK in reply to the peer's SYN.
SynRcvd,
/// Entered from `SynRcvd` when an acceptable ACK arrives.
Estab,
/// Entered when `close` is called in `SynRcvd` or `Estab`.
FinWait1,
/// Entered from `FinWait1` once `send.una` has moved one past `closed_at`.
FinWait2,
/// Entered from `FinWait2` when the peer's FIN arrives.
TimeWait,
}
impl State {
    /// Returns `false` for `SynRcvd` and `true` for every other state.
    ///
    /// The match stays exhaustive on purpose so that adding a variant forces a
    /// decision here.
    fn is_synchronized(&self) -> bool {
        match self {
            State::Estab | State::FinWait1 | State::FinWait2 | State::TimeWait => true,
            State::SynRcvd => false,
        }
    }
}
/// Per-connection TCP state: protocol state machine, sequence spaces, header
/// templates and the byte queues for both directions.
pub struct Connection {
/// Current position in the connection state machine.
state: State,
/// Send-side sequence bookkeeping (SND.* variables).
send: SendSequenceSpace,
/// Receive-side sequence bookkeeping (RCV.* variables).
recv: RecvSequenceSpace,
/// IPv4 header reused for every outgoing segment; `write` updates its payload length.
ip: etherparse::Ipv4Header,
/// TCP header reused for every outgoing segment; `write` fills in the sequence
/// and acknowledgment numbers and the checksum before each send.
tcp: etherparse::TcpHeader,
/// Send timestamps and the smoothed round-trip estimate used by `on_tick`.
timers: Timers,
/// Payload bytes accepted from the peer by `on_packet`.
pub(crate) incoming: VecDeque<u8>,
/// Outgoing bytes not yet acknowledged; `on_tick` treats part of it as not yet
/// sent, and `on_packet` drains the acknowledged prefix.
pub(crate) unacked: VecDeque<u8>,
/// Set by `close`; tells `on_tick` that a FIN should follow the queued data.
pub(crate) closed: bool,
/// Sequence number assigned to our FIN (`send.una + unacked.len()` at the time
/// `on_tick` decides to send it); `None` until then.
closed_at: Option<u32>,
}
/// Retransmission timing state.
struct Timers {
/// Maps the starting sequence number passed to `write` to the instant that
/// segment was sent; entries are pruned in `on_packet` when an ACK covers them.
send_times: BTreeMap<u32, time::Instant>,
/// Smoothed round-trip time in seconds, updated in `on_packet` as
/// `0.8 * srtt + 0.2 * sample`.
srtt: f64,
}
impl Connection {
    /// Reports whether the receive side is considered closed.
    ///
    /// Currently this is true only in `TimeWait`.
    pub(crate) fn is_rcv_closed(&self) -> bool {
        // TODO: any state after rcvd FIN, so also CLOSE-WAIT, LAST-ACK, CLOSED, CLOSING
        matches!(self.state, State::TimeWait)
    }

    /// Computes the readiness flags for this connection.
    ///
    /// `READ` is reported when the receive side is closed or when `incoming`
    /// holds at least one byte; no other flag is ever set here.
    fn availability(&self) -> Available {
        // TODO: take into account self.state
        // TODO: set Available::WRITE
        let readable = self.is_rcv_closed() || !self.incoming.is_empty();
        if readable {
            Available::READ
        } else {
            Available::empty()
        }
    }
}
/// State of the Send Sequence Space (RFC 793 S3.2 F4)
///
/// ```text
///            1         2          3          4
///       ----------|----------|----------|----------
///              SND.UNA    SND.NXT    SND.UNA
///                                   +SND.WND
///
/// 1 - old sequence numbers which have been acknowledged
/// 2 - sequence numbers of unacknowledged data
/// 3 - sequence numbers allowed for new data transmission
/// 4 - future sequence numbers which are not yet allowed
/// ```
struct SendSequenceSpace {
/// send unacknowledged
una: u32,
/// send next
nxt: u32,
/// send window
wnd: u16,
/// send urgent pointer (stored as a flag; only initialised, never read, in the visible code)
up: bool,
/// segment sequence number used for last window update
/// (only initialised, never read, in the visible code)
wl1: usize,
/// segment acknowledgment number used for last window update
/// (only initialised, never read, in the visible code)
wl2: usize,
/// initial send sequence number
iss: u32,
}
/// State of the Receive Sequence Space (RFC 793 S3.2 F5)
///
/// ```text
///                1          2          3
///            ----------|----------|----------
///                   RCV.NXT    RCV.NXT
///                             +RCV.WND
///
/// 1 - old sequence numbers which have been acknowledged
/// 2 - sequence numbers allowed for new reception
/// 3 - future sequence numbers which are not yet allowed
/// ```
struct RecvSequenceSpace {
/// receive next
nxt: u32,
/// receive window
wnd: u16,
/// receive urgent pointer (stored as a flag; only initialised, never read, in the visible code)
up: bool,
/// initial receive sequence number
irs: u32,
}
impl Connection {
/// Handles the first segment of a new connection (passive open).
///
/// Returns `Ok(None)` when the segment does not carry SYN. Otherwise builds a
/// `Connection` in `SynRcvd`, with the header templates addressed back at the
/// sender, transmits a SYN+ACK through `nic`, and returns the connection.
///
/// `data` is accepted for signature symmetry with `on_packet` but is not used.
///
/// # Errors
///
/// Propagates any I/O error from sending the SYN+ACK.
pub fn accept<'a>(
    nic: &mut tun_tap::Iface,
    iph: etherparse::Ipv4HeaderSlice<'a>,
    tcph: etherparse::TcpHeaderSlice<'a>,
    data: &'a [u8],
) -> io::Result<Option<Self>> {
    if !tcph.syn() {
        // only expected SYN packet
        return Ok(None);
    }

    let iss = 0;
    let wnd = 1024;
    let irs = tcph.sequence_number();

    // NOTE(review): `send.wnd` is seeded with the window we advertise and
    // `recv.wnd` with the peer's advertised window. RFC 793 defines SND.WND as
    // the peer's window, so confirm whether this assignment is intentional.
    let mut c = Connection {
        timers: Timers {
            send_times: Default::default(),
            // initial round-trip estimate: one minute
            srtt: time::Duration::from_secs(60).as_secs_f64(),
        },
        state: State::SynRcvd,
        send: SendSequenceSpace {
            iss,
            una: iss,
            nxt: iss,
            wnd,
            up: false,
            wl1: 0,
            wl2: 0,
        },
        recv: RecvSequenceSpace {
            irs,
            // The SYN occupies one sequence number. Sequence arithmetic is
            // modulo 2^32, so an initial sequence number of u32::MAX must wrap
            // to 0 instead of overflowing.
            nxt: irs.wrapping_add(1),
            wnd: tcph.window_size(),
            up: false,
        },
        tcp: etherparse::TcpHeader::new(tcph.destination_port(), tcph.source_port(), iss, wnd),
        ip: etherparse::Ipv4Header::new(
            0,
            64,
            etherparse::IpTrafficClass::Tcp,
            [
                iph.destination()[0],
                iph.destination()[1],
                iph.destination()[2],
                iph.destination()[3],
            ],
            [
                iph.source()[0],
                iph.source()[1],
                iph.source()[2],
                iph.source()[3],
            ],
        ),
        incoming: Default::default(),
        unacked: Default::default(),
        closed: false,
        closed_at: None,
    };

    // need to start establishing a connection
    c.tcp.syn = true;
    c.tcp.ack = true;
    c.write(nic, c.send.nxt, 0)?;
    Ok(Some(c))
}
/// Builds one segment starting at sequence number `seq`, carrying at most
/// `limit` bytes taken from `unacked`, and sends it through `nic`.
///
/// The TCP header template's current flags are used as-is; `syn` and `fin`
/// are cleared after the segment has been built, and each advances the next
/// sequence number by one. `send.nxt` is only ever moved forward. The send
/// time is recorded under `seq` for retransmission and RTT bookkeeping.
///
/// Returns the number of payload bytes placed in the segment (SYN and FIN are
/// not counted).
///
/// # Errors
///
/// Propagates I/O errors from copying into the buffer or from `nic.send`.
///
/// # Panics
///
/// Panics if the TCP checksum cannot be computed.
fn write(&mut self, nic: &mut tun_tap::Iface, seq: u32, mut limit: usize) -> io::Result<usize> {
    let mut buf = [0u8; 1500];
    self.tcp.sequence_number = seq;
    self.tcp.acknowledgment_number = self.recv.nxt;

    // TODO: return +1 for SYN/FIN
    println!(
        "write(ack: {}, seq: {}, limit: {}) syn {:?} fin {:?}",
        // Relative ack for logging. `recv.nxt` advances with wrapping
        // arithmetic, so the difference must wrap too rather than underflow.
        self.recv.nxt.wrapping_sub(self.recv.irs),
        seq,
        limit,
        self.tcp.syn,
        self.tcp.fin,
    );

    let mut offset = seq.wrapping_sub(self.send.una) as usize;
    // we need to special-case the two "virtual" bytes SYN and FIN
    if let Some(closed_at) = self.closed_at {
        if seq == closed_at.wrapping_add(1) {
            // trying to write following FIN
            offset = 0;
            limit = 0;
        }
    }
    println!(
        "using offset {} base {} in {:?}",
        offset,
        self.send.una,
        self.unacked.as_slices()
    );

    // Skip `offset` bytes across the two halves of the ring buffer.
    let (mut h, mut t) = self.unacked.as_slices();
    if h.len() >= offset {
        h = &h[offset..];
    } else {
        let skipped = h.len();
        h = &[];
        // `offset` can point past the buffered data (e.g. `seq` is beyond a
        // still-unacknowledged SYN, which occupies a sequence number but no
        // buffer byte). Clamp so that case yields an empty payload instead of
        // an out-of-range slice panic.
        t = &t[std::cmp::min(offset - skipped, t.len())..];
    }

    let max_data = std::cmp::min(limit, h.len() + t.len());
    let size = std::cmp::min(
        buf.len(),
        self.tcp.header_len() as usize + self.ip.header_len() as usize + max_data,
    );
    self.ip
        .set_payload_len(size - self.ip.header_len() as usize);

    // write out the headers and the payload
    use std::io::Write;
    let buf_len = buf.len();
    let mut unwritten = &mut buf[..];
    self.ip.write(&mut unwritten);
    let ip_header_ends_at = buf_len - unwritten.len();

    // postpone writing the tcp header because we need the payload as one contiguous slice to calculate the tcp checksum
    unwritten = &mut unwritten[self.tcp.header_len() as usize..];
    let tcp_header_ends_at = buf_len - unwritten.len();

    // write out the payload
    let payload_bytes = {
        let mut written = 0;
        let mut limit = max_data;

        // first, write as much as we can from h
        let p1l = std::cmp::min(limit, h.len());
        written += unwritten.write(&h[..p1l])?;
        limit -= written;

        // then, write more (if we can) from t
        let p2l = std::cmp::min(limit, t.len());
        written += unwritten.write(&t[..p2l])?;
        written
    };
    let payload_ends_at = buf_len - unwritten.len();

    // finally we can calculate the tcp checksum and write out the tcp header
    self.tcp.checksum = self
        .tcp
        .calc_checksum_ipv4(&self.ip, &buf[tcp_header_ends_at..payload_ends_at])
        .expect("failed to compute checksum");
    let mut tcp_header_buf = &mut buf[ip_header_ends_at..tcp_header_ends_at];
    self.tcp.write(&mut tcp_header_buf);

    // SYN and FIN each consume one sequence number and are one-shot flags.
    let mut next_seq = seq.wrapping_add(payload_bytes as u32);
    if self.tcp.syn {
        next_seq = next_seq.wrapping_add(1);
        self.tcp.syn = false;
    }
    if self.tcp.fin {
        next_seq = next_seq.wrapping_add(1);
        self.tcp.fin = false;
    }
    // Retransmissions must not move SND.NXT backwards.
    if wrapping_lt(self.send.nxt, next_seq) {
        self.send.nxt = next_seq;
    }
    self.timers.send_times.insert(seq, time::Instant::now());

    nic.send(&buf[..payload_ends_at])?;
    Ok(payload_bytes)
}
/// Sends a single empty segment with the RST flag set.
///
/// The flag is cleared again once the segment has been handed to `write`,
/// whether or not the send succeeded, so later segments built from the same
/// header template do not carry RST.
///
/// # Errors
///
/// Propagates any I/O error from `write`.
fn send_rst(&mut self, nic: &mut tun_tap::Iface) -> io::Result<()> {
    // TODO: fix sequence numbers here
    // If the incoming segment has an ACK field, the reset takes its
    // sequence number from the ACK field of the segment, otherwise the
    // reset has sequence number zero and the ACK field is set to the sum
    // of the sequence number and segment length of the incoming segment.
    // The connection remains in the same state.
    //
    // TODO: handle synchronized RST
    // 3. If the connection is in a synchronized state (ESTABLISHED,
    // FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT),
    // any unacceptable segment (out of window sequence number or
    // unacceptable acknowledgment number) must elicit only an empty
    // acknowledgment segment containing the current send-sequence number
    // and an acknowledgment indicating the next sequence number expected
    // to be received, and the connection remains in the same state.
    //
    // `write` overwrites both the sequence and the acknowledgment number of
    // the header template, so setting them here would have no effect.
    self.tcp.rst = true;
    let sent = self.write(nic, self.send.nxt, 0);
    // RST is a one-shot flag: without this reset every subsequent segment
    // would also be sent as a reset.
    self.tcp.rst = false;
    sent.map(|_| ())
}
/// Periodic timer hook: retransmits overdue data or sends new data.
///
/// Does nothing in `FinWait2` and `TimeWait`. Otherwise:
///
/// * if the oldest recorded send at or after `send.una` has waited longer
///   than one second and longer than `1.5 * srtt`, everything in `unacked`
///   (capped by the send window) is resent from `send.una`;
/// * otherwise up to `window - in-flight` bytes are sent from `send.nxt`.
///
/// In both paths a FIN is attached when `close` has been requested and the
/// window has room left after the data.
///
/// # Errors
///
/// Propagates any I/O error from `write`.
pub(crate) fn on_tick(&mut self, nic: &mut tun_tap::Iface) -> io::Result<()> {
    if let State::FinWait2 | State::TimeWait = self.state {
        // we have shutdown our write side and the other side acked, no need to (re)transmit anything
        return Ok(());
    }

    // eprintln!("ON TICK: state {:?} una {} nxt {} unacked {:?}",
    // self.state, self.send.una, self.send.nxt, self.unacked);

    let nunacked_data = self
        .closed_at
        .unwrap_or(self.send.nxt)
        .wrapping_sub(self.send.una);
    // The in-flight count includes virtual bytes (e.g. an unacknowledged SYN)
    // that are not stored in `unacked`, so it can exceed the buffer length;
    // saturate instead of underflowing.
    let nunsent_data = (self.unacked.len() as u32).saturating_sub(nunacked_data);

    let waited_for = self
        .timers
        .send_times
        .range(self.send.una..)
        .next()
        .map(|t| t.1.elapsed());

    let should_retransmit = if let Some(waited_for) = waited_for {
        waited_for > time::Duration::from_secs(1)
            && waited_for.as_secs_f64() > 1.5 * self.timers.srtt
    } else {
        false
    };

    if should_retransmit {
        let resend = std::cmp::min(self.unacked.len() as u32, self.send.wnd as u32);
        if resend < self.send.wnd as u32 && self.closed {
            // can we include the FIN?
            self.tcp.fin = true;
            self.closed_at = Some(self.send.una.wrapping_add(self.unacked.len() as u32));
        }
        self.write(nic, self.send.una, resend as usize)?;
    } else {
        // we should send new data if we have new data and space in the window
        if nunsent_data == 0 && self.closed_at.is_some() {
            return Ok(());
        }

        // More may be in flight than the window allows; treat that as a full
        // window rather than underflowing.
        let allowed = (self.send.wnd as u32).saturating_sub(nunacked_data);
        if allowed == 0 {
            return Ok(());
        }

        let send = std::cmp::min(nunsent_data, allowed);
        if send < allowed && self.closed && self.closed_at.is_none() {
            self.tcp.fin = true;
            self.closed_at = Some(self.send.una.wrapping_add(self.unacked.len() as u32));
        }

        // NOTE(review): with nothing queued and no close requested this still
        // emits an empty segment on every tick; confirm that is intended.
        self.write(nic, self.send.nxt, send as usize)?;
    }

    Ok(())
}
/// Processes one incoming segment for this connection and returns the
/// resulting readiness flags.
///
/// Steps, in order:
/// 1. Acceptability test of the segment against the receive window; an
///    unacceptable segment is answered with an empty segment and dropped.
/// 2. A segment without ACK is dropped after recording `recv.nxt` if it
///    carries SYN.
/// 3. In `SynRcvd`, an acceptable ACK moves the connection to `Estab`.
/// 4. In `Estab`/`FinWait1`/`FinWait2`, an ACK for new data drains the
///    acknowledged prefix of `unacked`, prunes `send_times` (folding each
///    pruned sample into `srtt`) and advances `send.una`.
/// 5. In `FinWait1`, an ACK covering our FIN moves to `FinWait2`.
/// 6. Payload is appended to `incoming` and acknowledged.
/// 7. A FIN in `FinWait2` is acknowledged and moves to `TimeWait`.
///
/// `iph` is not used by this function.
///
/// # Errors
///
/// Propagates any I/O error from `write`.
///
/// # Panics
///
/// * if a SYN without ACK carries payload (`assert!`);
/// * if a payload segment lies more than one byte before `recv.nxt`
///   (`assert_eq!`);
/// * if a FIN arrives in any state other than `FinWait2` (`unimplemented!`).
pub(crate) fn on_packet<'a>(
&mut self,
nic: &mut tun_tap::Iface,
iph: etherparse::Ipv4HeaderSlice<'a>,
tcph: etherparse::TcpHeaderSlice<'a>,
data: &'a [u8],
) -> io::Result<Available> {
// first, check that sequence numbers are valid (RFC 793 S3.3)
let seqn = tcph.sequence_number();
// segment length in sequence space: payload plus one each for SYN and FIN
let mut slen = data.len() as u32;
if tcph.fin() {
slen += 1;
};
if tcph.syn() {
slen += 1;
};
// first sequence number beyond the receive window
let wend = self.recv.nxt.wrapping_add(self.recv.wnd as u32);
// `is_between_wrapped` is exclusive at both ends, so passing `recv.nxt - 1`
// as the lower bound tests `recv.nxt <= x < wend`.
let okay = if slen == 0 {
// zero-length segment has separate rules for acceptance
if self.recv.wnd == 0 {
if seqn != self.recv.nxt {
false
} else {
true
}
} else if !is_between_wrapped(self.recv.nxt.wrapping_sub(1), seqn, wend) {
false
} else {
true
}
} else {
if self.recv.wnd == 0 {
false
} else if !is_between_wrapped(self.recv.nxt.wrapping_sub(1), seqn, wend)
&& !is_between_wrapped(
self.recv.nxt.wrapping_sub(1),
seqn.wrapping_add(slen - 1),
wend,
)
{
// neither the first nor the last sequence number of the segment is in the window
false
} else {
true
}
};
if !okay {
eprintln!("NOT OKAY");
// answer with an empty segment carrying our current SND.NXT / RCV.NXT
self.write(nic, self.send.nxt, 0)?;
return Ok(self.availability());
}
if !tcph.ack() {
if tcph.syn() {
// got SYN part of initial handshake
assert!(data.is_empty());
self.recv.nxt = seqn.wrapping_add(1);
}
return Ok(self.availability());
}
let ackn = tcph.acknowledgment_number();
if let State::SynRcvd = self.state {
// accepts SND.UNA <= ackn <= SND.NXT (bounds widened by one because the test is exclusive)
if is_between_wrapped(
self.send.una.wrapping_sub(1),
ackn,
self.send.nxt.wrapping_add(1),
) {
// must have ACKed our SYN, since we detected at least one acked byte,
// and we have only sent one byte (the SYN).
self.state = State::Estab;
} else {
// TODO: <SEQ=SEG.ACK><CTL=RST>
}
}
if let State::Estab | State::FinWait1 | State::FinWait2 = self.state {
// accepts SND.UNA < ackn <= SND.NXT, i.e. the ACK must cover something new
if is_between_wrapped(self.send.una, ackn, self.send.nxt.wrapping_add(1)) {
println!(
"ack for {} (last: {}); prune in {:?}",
ackn, self.send.una, self.unacked
);
if !self.unacked.is_empty() {
let data_start = if self.send.una == self.send.iss {
// send.una hasn't been updated yet with ACK for our SYN, so data starts just beyond it
self.send.una.wrapping_add(1)
} else {
self.send.una
};
// number of buffered bytes covered by this ACK, capped at what is buffered
let acked_data_end = std::cmp::min(ackn.wrapping_sub(data_start) as usize, self.unacked.len());
self.unacked.drain(..acked_data_end);
// rebuild the timer map, dropping acknowledged entries and folding
// each dropped entry's elapsed time into the smoothed RTT
let old = std::mem::replace(&mut self.timers.send_times, BTreeMap::new());
let una = self.send.una;
let mut srtt = &mut self.timers.srtt;
self.timers
.send_times
.extend(old.into_iter().filter_map(|(seq, sent)| {
// NOTE(review): the test is exclusive at both ends, so an entry recorded
// at exactly `una` is kept rather than pruned; confirm this is intended.
if is_between_wrapped(una, seq, ackn) {
*srtt = 0.8 * *srtt + (1.0 - 0.8) * sent.elapsed().as_secs_f64();
None
} else {
Some((seq, sent))
}
}));
}
self.send.una = ackn;
}
// TODO: if unacked empty and waiting flush, notify
// TODO: update window
}
if let State::FinWait1 = self.state {
if let Some(closed_at) = self.closed_at {
if self.send.una == closed_at.wrapping_add(1) {
// our FIN has been ACKed!
self.state = State::FinWait2;
}
}
}
if !data.is_empty() {
if let State::Estab | State::FinWait1 | State::FinWait2 = self.state {
// offset of the first byte in `data` that we have not seen yet
let mut unread_data_at = self.recv.nxt.wrapping_sub(seqn) as usize;
if unread_data_at > data.len() {
// we must have received a re-transmitted FIN that we have already seen
// nxt points to beyond the fin, but the fin is not in data!
assert_eq!(unread_data_at, data.len() + 1);
unread_data_at = 0;
}
self.incoming.extend(&data[unread_data_at..]);
/*
Once the TCP takes responsibility for the data it advances
RCV.NXT over the data accepted, and adjusts RCV.WND as
appropriate to the current buffer availability. The total of
RCV.NXT and RCV.WND should not be reduced.
*/
self.recv.nxt = seqn.wrapping_add(data.len() as u32);
// Send an acknowledgment of the form: <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
// TODO: maybe just tick to piggyback ack on data?
self.write(nic, self.send.nxt, 0)?;
}
}
if tcph.fin() {
match self.state {
State::FinWait2 => {
// we're done with the connection!
// the FIN occupies one sequence number, so step RCV.NXT past it before ACKing
self.recv.nxt = self.recv.nxt.wrapping_add(1);
self.write(nic, self.send.nxt, 0)?;
self.state = State::TimeWait;
}
// NOTE(review): a peer FIN in any other state panics here.
_ => unimplemented!(),
}
}
Ok(self.availability())
}
pub(crate) fn close(&mut self) -> io::Result<()> {
self.closed = true;
match self.state {
State::SynRcvd | State::Estab => {
self.state = State::FinWait1;
}
State::FinWait1 | State::FinWait2 => {}
_ => {
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"already closing",
))
}
};
Ok(())
}
}
/// Sequence-number "less than" under modulo-2^32 arithmetic.
///
/// Returns `true` when `rhs` lies strictly ahead of `lhs` by fewer than 2^31
/// positions. Equal values, and values exactly 2^31 apart, compare as `false`
/// in both directions.
fn wrapping_lt(lhs: u32, rhs: u32) -> bool {
    // From RFC1323:
    // TCP determines if a data segment is "old" or "new" by testing
    // whether its sequence number is within 2**31 bytes of the left edge
    // of the window, and if it is not, discarding the data as "old". To
    // insure that new data is never mistakenly considered old and vice-
    // versa, the left edge of the sender's window has to be at most
    // 2**31 away from the right edge of the receiver's window.
    let ahead = rhs.wrapping_sub(lhs);
    ahead != 0 && ahead < (1u32 << 31)
}
/// Returns `true` when `x` lies strictly between `start` and `end` in
/// wrapping sequence-number order; both bounds are exclusive.
fn is_between_wrapped(start: u32, x: u32, end: u32) -> bool {
    if !wrapping_lt(start, x) {
        return false;
    }
    wrapping_lt(x, end)
}