feat: add rust tcp
This commit is contained in:
568
__network/rust_tcp/src/tcp.rs
Normal file
568
__network/rust_tcp/src/tcp.rs
Normal file
@@ -0,0 +1,568 @@
|
||||
use bitflags::bitflags;
|
||||
use std::collections::{BTreeMap, VecDeque};
|
||||
use std::{io, time};
|
||||
|
||||
bitflags! {
|
||||
pub(crate) struct Available: u8 {
|
||||
const READ = 0b00000001;
|
||||
const WRITE = 0b00000010;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum State {
|
||||
//Listen,
|
||||
SynRcvd,
|
||||
Estab,
|
||||
FinWait1,
|
||||
FinWait2,
|
||||
TimeWait,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn is_synchronized(&self) -> bool {
|
||||
match *self {
|
||||
State::SynRcvd => false,
|
||||
State::Estab | State::FinWait1 | State::FinWait2 | State::TimeWait => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Connection {
|
||||
state: State,
|
||||
send: SendSequenceSpace,
|
||||
recv: RecvSequenceSpace,
|
||||
ip: etherparse::Ipv4Header,
|
||||
tcp: etherparse::TcpHeader,
|
||||
timers: Timers,
|
||||
|
||||
pub(crate) incoming: VecDeque<u8>,
|
||||
pub(crate) unacked: VecDeque<u8>,
|
||||
|
||||
pub(crate) closed: bool,
|
||||
closed_at: Option<u32>,
|
||||
}
|
||||
|
||||
struct Timers {
|
||||
send_times: BTreeMap<u32, time::Instant>,
|
||||
srtt: f64,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub(crate) fn is_rcv_closed(&self) -> bool {
|
||||
if let State::TimeWait = self.state {
|
||||
// TODO: any state after rcvd FIN, so also CLOSE-WAIT, LAST-ACK, CLOSED, CLOSING
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn availability(&self) -> Available {
|
||||
let mut a = Available::empty();
|
||||
if self.is_rcv_closed() || !self.incoming.is_empty() {
|
||||
a |= Available::READ;
|
||||
}
|
||||
// TODO: take into account self.state
|
||||
// TODO: set Available::WRITE
|
||||
a
|
||||
}
|
||||
}
|
||||
|
||||
/// State of the Send Sequence Space (RFC 793 S3.2 F4)
|
||||
///
|
||||
/// ```
|
||||
/// 1 2 3 4
|
||||
/// ----------|----------|----------|----------
|
||||
/// SND.UNA SND.NXT SND.UNA
|
||||
/// +SND.WND
|
||||
///
|
||||
/// 1 - old sequence numbers which have been acknowledged
|
||||
/// 2 - sequence numbers of unacknowledged data
|
||||
/// 3 - sequence numbers allowed for new data transmission
|
||||
/// 4 - future sequence numbers which are not yet allowed
|
||||
/// ```
|
||||
struct SendSequenceSpace {
|
||||
/// send unacknowledged
|
||||
una: u32,
|
||||
/// send next
|
||||
nxt: u32,
|
||||
/// send window
|
||||
wnd: u16,
|
||||
/// send urgent pointer
|
||||
up: bool,
|
||||
/// segment sequence number used for last window update
|
||||
wl1: usize,
|
||||
/// segment acknowledgment number used for last window update
|
||||
wl2: usize,
|
||||
/// initial send sequence number
|
||||
iss: u32,
|
||||
}
|
||||
|
||||
/// State of the Receive Sequence Space (RFC 793 S3.2 F5)
|
||||
///
|
||||
/// ```
|
||||
/// 1 2 3
|
||||
/// ----------|----------|----------
|
||||
/// RCV.NXT RCV.NXT
|
||||
/// +RCV.WND
|
||||
///
|
||||
/// 1 - old sequence numbers which have been acknowledged
|
||||
/// 2 - sequence numbers allowed for new reception
|
||||
/// 3 - future sequence numbers which are not yet allowed
|
||||
/// ```
|
||||
struct RecvSequenceSpace {
|
||||
/// receive next
|
||||
nxt: u32,
|
||||
/// receive window
|
||||
wnd: u16,
|
||||
/// receive urgent pointer
|
||||
up: bool,
|
||||
/// initial receive sequence number
|
||||
irs: u32,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn accept<'a>(
|
||||
nic: &mut tun_tap::Iface,
|
||||
iph: etherparse::Ipv4HeaderSlice<'a>,
|
||||
tcph: etherparse::TcpHeaderSlice<'a>,
|
||||
data: &'a [u8],
|
||||
) -> io::Result<Option<Self>> {
|
||||
let buf = [0u8; 1500];
|
||||
if !tcph.syn() {
|
||||
// only expected SYN packet
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let iss = 0;
|
||||
let wnd = 1024;
|
||||
let mut c = Connection {
|
||||
timers: Timers {
|
||||
send_times: Default::default(),
|
||||
srtt: time::Duration::from_secs(1 * 60).as_secs_f64(),
|
||||
},
|
||||
state: State::SynRcvd,
|
||||
send: SendSequenceSpace {
|
||||
iss,
|
||||
una: iss,
|
||||
nxt: iss,
|
||||
wnd: wnd,
|
||||
up: false,
|
||||
|
||||
wl1: 0,
|
||||
wl2: 0,
|
||||
},
|
||||
recv: RecvSequenceSpace {
|
||||
irs: tcph.sequence_number(),
|
||||
nxt: tcph.sequence_number() + 1,
|
||||
wnd: tcph.window_size(),
|
||||
up: false,
|
||||
},
|
||||
tcp: etherparse::TcpHeader::new(tcph.destination_port(), tcph.source_port(), iss, wnd),
|
||||
ip: etherparse::Ipv4Header::new(
|
||||
0,
|
||||
64,
|
||||
etherparse::IpTrafficClass::Tcp,
|
||||
[
|
||||
iph.destination()[0],
|
||||
iph.destination()[1],
|
||||
iph.destination()[2],
|
||||
iph.destination()[3],
|
||||
],
|
||||
[
|
||||
iph.source()[0],
|
||||
iph.source()[1],
|
||||
iph.source()[2],
|
||||
iph.source()[3],
|
||||
],
|
||||
),
|
||||
|
||||
incoming: Default::default(),
|
||||
unacked: Default::default(),
|
||||
|
||||
closed: false,
|
||||
closed_at: None,
|
||||
};
|
||||
|
||||
// need to start establishing a connection
|
||||
c.tcp.syn = true;
|
||||
c.tcp.ack = true;
|
||||
c.write(nic, c.send.nxt, 0)?;
|
||||
Ok(Some(c))
|
||||
}
|
||||
|
||||
fn write(&mut self, nic: &mut tun_tap::Iface, seq: u32, mut limit: usize) -> io::Result<usize> {
|
||||
let mut buf = [0u8; 1500];
|
||||
self.tcp.sequence_number = seq;
|
||||
self.tcp.acknowledgment_number = self.recv.nxt;
|
||||
|
||||
// TODO: return +1 for SYN/FIN
|
||||
println!(
|
||||
"write(ack: {}, seq: {}, limit: {}) syn {:?} fin {:?}",
|
||||
self.recv.nxt - self.recv.irs, seq, limit, self.tcp.syn, self.tcp.fin,
|
||||
);
|
||||
|
||||
let mut offset = seq.wrapping_sub(self.send.una) as usize;
|
||||
// we need to special-case the two "virtual" bytes SYN and FIN
|
||||
if let Some(closed_at) = self.closed_at {
|
||||
if seq == closed_at.wrapping_add(1) {
|
||||
// trying to write following FIN
|
||||
offset = 0;
|
||||
limit = 0;
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"using offset {} base {} in {:?}",
|
||||
offset,
|
||||
self.send.una,
|
||||
self.unacked.as_slices()
|
||||
);
|
||||
let (mut h, mut t) = self.unacked.as_slices();
|
||||
if h.len() >= offset {
|
||||
h = &h[offset..];
|
||||
} else {
|
||||
let skipped = h.len();
|
||||
h = &[];
|
||||
t = &t[(offset - skipped)..];
|
||||
}
|
||||
|
||||
let max_data = std::cmp::min(limit, h.len() + t.len());
|
||||
let size = std::cmp::min(
|
||||
buf.len(),
|
||||
self.tcp.header_len() as usize + self.ip.header_len() as usize + max_data,
|
||||
);
|
||||
self.ip
|
||||
.set_payload_len(size - self.ip.header_len() as usize);
|
||||
|
||||
// write out the headers and the payload
|
||||
use std::io::Write;
|
||||
let buf_len = buf.len();
|
||||
let mut unwritten = &mut buf[..];
|
||||
|
||||
self.ip.write(&mut unwritten);
|
||||
let ip_header_ends_at = buf_len - unwritten.len();
|
||||
|
||||
// postpone writing the tcp header because we need the payload as one contiguous slice to calculate the tcp checksum
|
||||
unwritten = &mut unwritten[self.tcp.header_len() as usize..];
|
||||
let tcp_header_ends_at = buf_len - unwritten.len();
|
||||
|
||||
// write out the payload
|
||||
let payload_bytes = {
|
||||
let mut written = 0;
|
||||
let mut limit = max_data;
|
||||
|
||||
// first, write as much as we can from h
|
||||
let p1l = std::cmp::min(limit, h.len());
|
||||
written += unwritten.write(&h[..p1l])?;
|
||||
limit -= written;
|
||||
|
||||
// then, write more (if we can) from t
|
||||
let p2l = std::cmp::min(limit, t.len());
|
||||
written += unwritten.write(&t[..p2l])?;
|
||||
written
|
||||
};
|
||||
let payload_ends_at = buf_len - unwritten.len();
|
||||
|
||||
// finally we can calculate the tcp checksum and write out the tcp header
|
||||
self.tcp.checksum = self
|
||||
.tcp
|
||||
.calc_checksum_ipv4(&self.ip, &buf[tcp_header_ends_at..payload_ends_at])
|
||||
.expect("failed to compute checksum");
|
||||
|
||||
let mut tcp_header_buf = &mut buf[ip_header_ends_at..tcp_header_ends_at];
|
||||
self.tcp.write(&mut tcp_header_buf);
|
||||
|
||||
let mut next_seq = seq.wrapping_add(payload_bytes as u32);
|
||||
if self.tcp.syn {
|
||||
next_seq = next_seq.wrapping_add(1);
|
||||
self.tcp.syn = false;
|
||||
}
|
||||
if self.tcp.fin {
|
||||
next_seq = next_seq.wrapping_add(1);
|
||||
self.tcp.fin = false;
|
||||
}
|
||||
if wrapping_lt(self.send.nxt, next_seq) {
|
||||
self.send.nxt = next_seq;
|
||||
}
|
||||
self.timers.send_times.insert(seq, time::Instant::now());
|
||||
|
||||
nic.send(&buf[..payload_ends_at])?;
|
||||
Ok(payload_bytes)
|
||||
}
|
||||
|
||||
fn send_rst(&mut self, nic: &mut tun_tap::Iface) -> io::Result<()> {
|
||||
self.tcp.rst = true;
|
||||
// TODO: fix sequence numbers here
|
||||
// If the incoming segment has an ACK field, the reset takes its
|
||||
// sequence number from the ACK field of the segment, otherwise the
|
||||
// reset has sequence number zero and the ACK field is set to the sum
|
||||
// of the sequence number and segment length of the incoming segment.
|
||||
// The connection remains in the same state.
|
||||
//
|
||||
// TODO: handle synchronized RST
|
||||
// 3. If the connection is in a synchronized state (ESTABLISHED,
|
||||
// FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT),
|
||||
// any unacceptable segment (out of window sequence number or
|
||||
// unacceptible acknowledgment number) must elicit only an empty
|
||||
// acknowledgment segment containing the current send-sequence number
|
||||
// and an acknowledgment indicating the next sequence number expected
|
||||
// to be received, and the connection remains in the same state.
|
||||
self.tcp.sequence_number = 0;
|
||||
self.tcp.acknowledgment_number = 0;
|
||||
self.write(nic, self.send.nxt, 0)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn on_tick(&mut self, nic: &mut tun_tap::Iface) -> io::Result<()> {
|
||||
if let State::FinWait2 | State::TimeWait = self.state {
|
||||
// we have shutdown our write side and the other side acked, no need to (re)transmit anything
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// eprintln!("ON TICK: state {:?} una {} nxt {} unacked {:?}",
|
||||
// self.state, self.send.una, self.send.nxt, self.unacked);
|
||||
|
||||
let nunacked_data = self.closed_at.unwrap_or(self.send.nxt).wrapping_sub(self.send.una);
|
||||
let nunsent_data = self.unacked.len() as u32 - nunacked_data;
|
||||
|
||||
let waited_for = self
|
||||
.timers
|
||||
.send_times
|
||||
.range(self.send.una..)
|
||||
.next()
|
||||
.map(|t| t.1.elapsed());
|
||||
|
||||
let should_retransmit = if let Some(waited_for) = waited_for {
|
||||
waited_for > time::Duration::from_secs(1)
|
||||
&& waited_for.as_secs_f64() > 1.5 * self.timers.srtt
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if should_retransmit {
|
||||
let resend = std::cmp::min(self.unacked.len() as u32, self.send.wnd as u32);
|
||||
if resend < self.send.wnd as u32 && self.closed {
|
||||
// can we include the FIN?
|
||||
self.tcp.fin = true;
|
||||
self.closed_at = Some(self.send.una.wrapping_add(self.unacked.len() as u32));
|
||||
}
|
||||
self.write(nic, self.send.una, resend as usize)?;
|
||||
} else {
|
||||
// we should send new data if we have new data and space in the window
|
||||
if nunsent_data == 0 && self.closed_at.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let allowed = self.send.wnd as u32 - nunacked_data;
|
||||
if allowed == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let send = std::cmp::min(nunsent_data, allowed);
|
||||
if send < allowed && self.closed && self.closed_at.is_none() {
|
||||
self.tcp.fin = true;
|
||||
self.closed_at = Some(self.send.una.wrapping_add(self.unacked.len() as u32));
|
||||
}
|
||||
|
||||
self.write(nic, self.send.nxt, send as usize)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn on_packet<'a>(
|
||||
&mut self,
|
||||
nic: &mut tun_tap::Iface,
|
||||
iph: etherparse::Ipv4HeaderSlice<'a>,
|
||||
tcph: etherparse::TcpHeaderSlice<'a>,
|
||||
data: &'a [u8],
|
||||
) -> io::Result<Available> {
|
||||
// first, check that sequence numbers are valid (RFC 793 S3.3)
|
||||
let seqn = tcph.sequence_number();
|
||||
let mut slen = data.len() as u32;
|
||||
if tcph.fin() {
|
||||
slen += 1;
|
||||
};
|
||||
if tcph.syn() {
|
||||
slen += 1;
|
||||
};
|
||||
let wend = self.recv.nxt.wrapping_add(self.recv.wnd as u32);
|
||||
let okay = if slen == 0 {
|
||||
// zero-length segment has separate rules for acceptance
|
||||
if self.recv.wnd == 0 {
|
||||
if seqn != self.recv.nxt {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else if !is_between_wrapped(self.recv.nxt.wrapping_sub(1), seqn, wend) {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
if self.recv.wnd == 0 {
|
||||
false
|
||||
} else if !is_between_wrapped(self.recv.nxt.wrapping_sub(1), seqn, wend)
|
||||
&& !is_between_wrapped(
|
||||
self.recv.nxt.wrapping_sub(1),
|
||||
seqn.wrapping_add(slen - 1),
|
||||
wend,
|
||||
)
|
||||
{
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
if !okay {
|
||||
eprintln!("NOT OKAY");
|
||||
self.write(nic, self.send.nxt, 0)?;
|
||||
return Ok(self.availability());
|
||||
}
|
||||
|
||||
if !tcph.ack() {
|
||||
if tcph.syn() {
|
||||
// got SYN part of initial handshake
|
||||
assert!(data.is_empty());
|
||||
self.recv.nxt = seqn.wrapping_add(1);
|
||||
}
|
||||
return Ok(self.availability());
|
||||
}
|
||||
|
||||
let ackn = tcph.acknowledgment_number();
|
||||
if let State::SynRcvd = self.state {
|
||||
if is_between_wrapped(
|
||||
self.send.una.wrapping_sub(1),
|
||||
ackn,
|
||||
self.send.nxt.wrapping_add(1),
|
||||
) {
|
||||
// must have ACKed our SYN, since we detected at least one acked byte,
|
||||
// and we have only sent one byte (the SYN).
|
||||
self.state = State::Estab;
|
||||
} else {
|
||||
// TODO: <SEQ=SEG.ACK><CTL=RST>
|
||||
}
|
||||
}
|
||||
|
||||
if let State::Estab | State::FinWait1 | State::FinWait2 = self.state {
|
||||
if is_between_wrapped(self.send.una, ackn, self.send.nxt.wrapping_add(1)) {
|
||||
println!(
|
||||
"ack for {} (last: {}); prune in {:?}",
|
||||
ackn, self.send.una, self.unacked
|
||||
);
|
||||
if !self.unacked.is_empty() {
|
||||
let data_start = if self.send.una == self.send.iss {
|
||||
// send.una hasn't been updated yet with ACK for our SYN, so data starts just beyond it
|
||||
self.send.una.wrapping_add(1)
|
||||
} else {
|
||||
self.send.una
|
||||
};
|
||||
let acked_data_end = std::cmp::min(ackn.wrapping_sub(data_start) as usize, self.unacked.len());
|
||||
self.unacked.drain(..acked_data_end);
|
||||
|
||||
let old = std::mem::replace(&mut self.timers.send_times, BTreeMap::new());
|
||||
|
||||
let una = self.send.una;
|
||||
let mut srtt = &mut self.timers.srtt;
|
||||
self.timers
|
||||
.send_times
|
||||
.extend(old.into_iter().filter_map(|(seq, sent)| {
|
||||
if is_between_wrapped(una, seq, ackn) {
|
||||
*srtt = 0.8 * *srtt + (1.0 - 0.8) * sent.elapsed().as_secs_f64();
|
||||
None
|
||||
} else {
|
||||
Some((seq, sent))
|
||||
}
|
||||
}));
|
||||
}
|
||||
self.send.una = ackn;
|
||||
}
|
||||
|
||||
// TODO: if unacked empty and waiting flush, notify
|
||||
// TODO: update window
|
||||
}
|
||||
|
||||
if let State::FinWait1 = self.state {
|
||||
if let Some(closed_at) = self.closed_at {
|
||||
if self.send.una == closed_at.wrapping_add(1) {
|
||||
// our FIN has been ACKed!
|
||||
self.state = State::FinWait2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !data.is_empty() {
|
||||
if let State::Estab | State::FinWait1 | State::FinWait2 = self.state {
|
||||
let mut unread_data_at = self.recv.nxt.wrapping_sub(seqn) as usize;
|
||||
if unread_data_at > data.len() {
|
||||
// we must have received a re-transmitted FIN that we have already seen
|
||||
// nxt points to beyond the fin, but the fin is not in data!
|
||||
assert_eq!(unread_data_at, data.len() + 1);
|
||||
unread_data_at = 0;
|
||||
}
|
||||
self.incoming.extend(&data[unread_data_at..]);
|
||||
|
||||
/*
|
||||
Once the TCP takes responsibility for the data it advances
|
||||
RCV.NXT over the data accepted, and adjusts RCV.WND as
|
||||
apporopriate to the current buffer availability. The total of
|
||||
RCV.NXT and RCV.WND should not be reduced.
|
||||
*/
|
||||
self.recv.nxt = seqn.wrapping_add(data.len() as u32);
|
||||
|
||||
// Send an acknowledgment of the form: <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
|
||||
// TODO: maybe just tick to piggyback ack on data?
|
||||
self.write(nic, self.send.nxt, 0)?;
|
||||
}
|
||||
}
|
||||
|
||||
if tcph.fin() {
|
||||
match self.state {
|
||||
State::FinWait2 => {
|
||||
// we're done with the connection!
|
||||
self.recv.nxt = self.recv.nxt.wrapping_add(1);
|
||||
self.write(nic, self.send.nxt, 0)?;
|
||||
self.state = State::TimeWait;
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self.availability())
|
||||
}
|
||||
|
||||
pub(crate) fn close(&mut self) -> io::Result<()> {
|
||||
self.closed = true;
|
||||
match self.state {
|
||||
State::SynRcvd | State::Estab => {
|
||||
self.state = State::FinWait1;
|
||||
}
|
||||
State::FinWait1 | State::FinWait2 => {}
|
||||
_ => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"already closing",
|
||||
))
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn wrapping_lt(lhs: u32, rhs: u32) -> bool {
|
||||
// From RFC1323:
|
||||
// TCP determines if a data segment is "old" or "new" by testing
|
||||
// whether its sequence number is within 2**31 bytes of the left edge
|
||||
// of the window, and if it is not, discarding the data as "old". To
|
||||
// insure that new data is never mistakenly considered old and vice-
|
||||
// versa, the left edge of the sender's window has to be at most
|
||||
// 2**31 away from the right edge of the receiver's window.
|
||||
lhs.wrapping_sub(rhs) > (1 << 31)
|
||||
}
|
||||
|
||||
fn is_between_wrapped(start: u32, x: u32, end: u32) -> bool {
|
||||
wrapping_lt(start, x) && wrapping_lt(x, end)
|
||||
}
|
||||
Reference in New Issue
Block a user