use bitflags::bitflags; use std::collections::{BTreeMap, VecDeque}; use std::{io, time}; bitflags! { pub(crate) struct Available: u8 { const READ = 0b00000001; const WRITE = 0b00000010; } } #[derive(Debug)] enum State { //Listen, SynRcvd, Estab, FinWait1, FinWait2, TimeWait, } impl State { fn is_synchronized(&self) -> bool { match *self { State::SynRcvd => false, State::Estab | State::FinWait1 | State::FinWait2 | State::TimeWait => true, } } } pub struct Connection { state: State, send: SendSequenceSpace, recv: RecvSequenceSpace, ip: etherparse::Ipv4Header, tcp: etherparse::TcpHeader, timers: Timers, pub(crate) incoming: VecDeque, pub(crate) unacked: VecDeque, pub(crate) closed: bool, closed_at: Option, } struct Timers { send_times: BTreeMap, srtt: f64, } impl Connection { pub(crate) fn is_rcv_closed(&self) -> bool { if let State::TimeWait = self.state { // TODO: any state after rcvd FIN, so also CLOSE-WAIT, LAST-ACK, CLOSED, CLOSING true } else { false } } fn availability(&self) -> Available { let mut a = Available::empty(); if self.is_rcv_closed() || !self.incoming.is_empty() { a |= Available::READ; } // TODO: take into account self.state // TODO: set Available::WRITE a } } /// State of the Send Sequence Space (RFC 793 S3.2 F4) /// /// ``` /// 1 2 3 4 /// ----------|----------|----------|---------- /// SND.UNA SND.NXT SND.UNA /// +SND.WND /// /// 1 - old sequence numbers which have been acknowledged /// 2 - sequence numbers of unacknowledged data /// 3 - sequence numbers allowed for new data transmission /// 4 - future sequence numbers which are not yet allowed /// ``` struct SendSequenceSpace { /// send unacknowledged una: u32, /// send next nxt: u32, /// send window wnd: u16, /// send urgent pointer up: bool, /// segment sequence number used for last window update wl1: usize, /// segment acknowledgment number used for last window update wl2: usize, /// initial send sequence number iss: u32, } /// State of the Receive Sequence Space (RFC 793 S3.2 F5) /// /// ``` /// 1 2 3 /// ----------|----------|---------- /// RCV.NXT RCV.NXT /// +RCV.WND /// /// 1 - old sequence numbers which have been acknowledged /// 2 - sequence numbers allowed for new reception /// 3 - future sequence numbers which are not yet allowed /// ``` struct RecvSequenceSpace { /// receive next nxt: u32, /// receive window wnd: u16, /// receive urgent pointer up: bool, /// initial receive sequence number irs: u32, } impl Connection { pub fn accept<'a>( nic: &mut tun_tap::Iface, iph: etherparse::Ipv4HeaderSlice<'a>, tcph: etherparse::TcpHeaderSlice<'a>, data: &'a [u8], ) -> io::Result> { let buf = [0u8; 1500]; if !tcph.syn() { // only expected SYN packet return Ok(None); } let iss = 0; let wnd = 1024; let mut c = Connection { timers: Timers { send_times: Default::default(), srtt: time::Duration::from_secs(1 * 60).as_secs_f64(), }, state: State::SynRcvd, send: SendSequenceSpace { iss, una: iss, nxt: iss, wnd: wnd, up: false, wl1: 0, wl2: 0, }, recv: RecvSequenceSpace { irs: tcph.sequence_number(), nxt: tcph.sequence_number() + 1, wnd: tcph.window_size(), up: false, }, tcp: etherparse::TcpHeader::new(tcph.destination_port(), tcph.source_port(), iss, wnd), ip: etherparse::Ipv4Header::new( 0, 64, etherparse::IpTrafficClass::Tcp, [ iph.destination()[0], iph.destination()[1], iph.destination()[2], iph.destination()[3], ], [ iph.source()[0], iph.source()[1], iph.source()[2], iph.source()[3], ], ), incoming: Default::default(), unacked: Default::default(), closed: false, closed_at: None, }; // need to start establishing a connection c.tcp.syn = true; c.tcp.ack = true; c.write(nic, c.send.nxt, 0)?; Ok(Some(c)) } fn write(&mut self, nic: &mut tun_tap::Iface, seq: u32, mut limit: usize) -> io::Result { let mut buf = [0u8; 1500]; self.tcp.sequence_number = seq; self.tcp.acknowledgment_number = self.recv.nxt; // TODO: return +1 for SYN/FIN println!( "write(ack: {}, seq: {}, limit: {}) syn {:?} fin {:?}", self.recv.nxt - self.recv.irs, seq, limit, self.tcp.syn, self.tcp.fin, ); let mut offset = seq.wrapping_sub(self.send.una) as usize; // we need to special-case the two "virtual" bytes SYN and FIN if let Some(closed_at) = self.closed_at { if seq == closed_at.wrapping_add(1) { // trying to write following FIN offset = 0; limit = 0; } } println!( "using offset {} base {} in {:?}", offset, self.send.una, self.unacked.as_slices() ); let (mut h, mut t) = self.unacked.as_slices(); if h.len() >= offset { h = &h[offset..]; } else { let skipped = h.len(); h = &[]; t = &t[(offset - skipped)..]; } let max_data = std::cmp::min(limit, h.len() + t.len()); let size = std::cmp::min( buf.len(), self.tcp.header_len() as usize + self.ip.header_len() as usize + max_data, ); self.ip .set_payload_len(size - self.ip.header_len() as usize); // write out the headers and the payload use std::io::Write; let buf_len = buf.len(); let mut unwritten = &mut buf[..]; self.ip.write(&mut unwritten); let ip_header_ends_at = buf_len - unwritten.len(); // postpone writing the tcp header because we need the payload as one contiguous slice to calculate the tcp checksum unwritten = &mut unwritten[self.tcp.header_len() as usize..]; let tcp_header_ends_at = buf_len - unwritten.len(); // write out the payload let payload_bytes = { let mut written = 0; let mut limit = max_data; // first, write as much as we can from h let p1l = std::cmp::min(limit, h.len()); written += unwritten.write(&h[..p1l])?; limit -= written; // then, write more (if we can) from t let p2l = std::cmp::min(limit, t.len()); written += unwritten.write(&t[..p2l])?; written }; let payload_ends_at = buf_len - unwritten.len(); // finally we can calculate the tcp checksum and write out the tcp header self.tcp.checksum = self .tcp .calc_checksum_ipv4(&self.ip, &buf[tcp_header_ends_at..payload_ends_at]) .expect("failed to compute checksum"); let mut tcp_header_buf = &mut buf[ip_header_ends_at..tcp_header_ends_at]; self.tcp.write(&mut tcp_header_buf); let mut next_seq = seq.wrapping_add(payload_bytes as u32); if self.tcp.syn { next_seq = next_seq.wrapping_add(1); self.tcp.syn = false; } if self.tcp.fin { next_seq = next_seq.wrapping_add(1); self.tcp.fin = false; } if wrapping_lt(self.send.nxt, next_seq) { self.send.nxt = next_seq; } self.timers.send_times.insert(seq, time::Instant::now()); nic.send(&buf[..payload_ends_at])?; Ok(payload_bytes) } fn send_rst(&mut self, nic: &mut tun_tap::Iface) -> io::Result<()> { self.tcp.rst = true; // TODO: fix sequence numbers here // If the incoming segment has an ACK field, the reset takes its // sequence number from the ACK field of the segment, otherwise the // reset has sequence number zero and the ACK field is set to the sum // of the sequence number and segment length of the incoming segment. // The connection remains in the same state. // // TODO: handle synchronized RST // 3. If the connection is in a synchronized state (ESTABLISHED, // FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT), // any unacceptable segment (out of window sequence number or // unacceptible acknowledgment number) must elicit only an empty // acknowledgment segment containing the current send-sequence number // and an acknowledgment indicating the next sequence number expected // to be received, and the connection remains in the same state. self.tcp.sequence_number = 0; self.tcp.acknowledgment_number = 0; self.write(nic, self.send.nxt, 0)?; Ok(()) } pub(crate) fn on_tick(&mut self, nic: &mut tun_tap::Iface) -> io::Result<()> { if let State::FinWait2 | State::TimeWait = self.state { // we have shutdown our write side and the other side acked, no need to (re)transmit anything return Ok(()); } // eprintln!("ON TICK: state {:?} una {} nxt {} unacked {:?}", // self.state, self.send.una, self.send.nxt, self.unacked); let nunacked_data = self.closed_at.unwrap_or(self.send.nxt).wrapping_sub(self.send.una); let nunsent_data = self.unacked.len() as u32 - nunacked_data; let waited_for = self .timers .send_times .range(self.send.una..) .next() .map(|t| t.1.elapsed()); let should_retransmit = if let Some(waited_for) = waited_for { waited_for > time::Duration::from_secs(1) && waited_for.as_secs_f64() > 1.5 * self.timers.srtt } else { false }; if should_retransmit { let resend = std::cmp::min(self.unacked.len() as u32, self.send.wnd as u32); if resend < self.send.wnd as u32 && self.closed { // can we include the FIN? self.tcp.fin = true; self.closed_at = Some(self.send.una.wrapping_add(self.unacked.len() as u32)); } self.write(nic, self.send.una, resend as usize)?; } else { // we should send new data if we have new data and space in the window if nunsent_data == 0 && self.closed_at.is_some() { return Ok(()); } let allowed = self.send.wnd as u32 - nunacked_data; if allowed == 0 { return Ok(()); } let send = std::cmp::min(nunsent_data, allowed); if send < allowed && self.closed && self.closed_at.is_none() { self.tcp.fin = true; self.closed_at = Some(self.send.una.wrapping_add(self.unacked.len() as u32)); } self.write(nic, self.send.nxt, send as usize)?; } Ok(()) } pub(crate) fn on_packet<'a>( &mut self, nic: &mut tun_tap::Iface, iph: etherparse::Ipv4HeaderSlice<'a>, tcph: etherparse::TcpHeaderSlice<'a>, data: &'a [u8], ) -> io::Result { // first, check that sequence numbers are valid (RFC 793 S3.3) let seqn = tcph.sequence_number(); let mut slen = data.len() as u32; if tcph.fin() { slen += 1; }; if tcph.syn() { slen += 1; }; let wend = self.recv.nxt.wrapping_add(self.recv.wnd as u32); let okay = if slen == 0 { // zero-length segment has separate rules for acceptance if self.recv.wnd == 0 { if seqn != self.recv.nxt { false } else { true } } else if !is_between_wrapped(self.recv.nxt.wrapping_sub(1), seqn, wend) { false } else { true } } else { if self.recv.wnd == 0 { false } else if !is_between_wrapped(self.recv.nxt.wrapping_sub(1), seqn, wend) && !is_between_wrapped( self.recv.nxt.wrapping_sub(1), seqn.wrapping_add(slen - 1), wend, ) { false } else { true } }; if !okay { eprintln!("NOT OKAY"); self.write(nic, self.send.nxt, 0)?; return Ok(self.availability()); } if !tcph.ack() { if tcph.syn() { // got SYN part of initial handshake assert!(data.is_empty()); self.recv.nxt = seqn.wrapping_add(1); } return Ok(self.availability()); } let ackn = tcph.acknowledgment_number(); if let State::SynRcvd = self.state { if is_between_wrapped( self.send.una.wrapping_sub(1), ackn, self.send.nxt.wrapping_add(1), ) { // must have ACKed our SYN, since we detected at least one acked byte, // and we have only sent one byte (the SYN). self.state = State::Estab; } else { // TODO: } } if let State::Estab | State::FinWait1 | State::FinWait2 = self.state { if is_between_wrapped(self.send.una, ackn, self.send.nxt.wrapping_add(1)) { println!( "ack for {} (last: {}); prune in {:?}", ackn, self.send.una, self.unacked ); if !self.unacked.is_empty() { let data_start = if self.send.una == self.send.iss { // send.una hasn't been updated yet with ACK for our SYN, so data starts just beyond it self.send.una.wrapping_add(1) } else { self.send.una }; let acked_data_end = std::cmp::min(ackn.wrapping_sub(data_start) as usize, self.unacked.len()); self.unacked.drain(..acked_data_end); let old = std::mem::replace(&mut self.timers.send_times, BTreeMap::new()); let una = self.send.una; let mut srtt = &mut self.timers.srtt; self.timers .send_times .extend(old.into_iter().filter_map(|(seq, sent)| { if is_between_wrapped(una, seq, ackn) { *srtt = 0.8 * *srtt + (1.0 - 0.8) * sent.elapsed().as_secs_f64(); None } else { Some((seq, sent)) } })); } self.send.una = ackn; } // TODO: if unacked empty and waiting flush, notify // TODO: update window } if let State::FinWait1 = self.state { if let Some(closed_at) = self.closed_at { if self.send.una == closed_at.wrapping_add(1) { // our FIN has been ACKed! self.state = State::FinWait2; } } } if !data.is_empty() { if let State::Estab | State::FinWait1 | State::FinWait2 = self.state { let mut unread_data_at = self.recv.nxt.wrapping_sub(seqn) as usize; if unread_data_at > data.len() { // we must have received a re-transmitted FIN that we have already seen // nxt points to beyond the fin, but the fin is not in data! assert_eq!(unread_data_at, data.len() + 1); unread_data_at = 0; } self.incoming.extend(&data[unread_data_at..]); /* Once the TCP takes responsibility for the data it advances RCV.NXT over the data accepted, and adjusts RCV.WND as apporopriate to the current buffer availability. The total of RCV.NXT and RCV.WND should not be reduced. */ self.recv.nxt = seqn.wrapping_add(data.len() as u32); // Send an acknowledgment of the form: // TODO: maybe just tick to piggyback ack on data? self.write(nic, self.send.nxt, 0)?; } } if tcph.fin() { match self.state { State::FinWait2 => { // we're done with the connection! self.recv.nxt = self.recv.nxt.wrapping_add(1); self.write(nic, self.send.nxt, 0)?; self.state = State::TimeWait; } _ => unimplemented!(), } } Ok(self.availability()) } pub(crate) fn close(&mut self) -> io::Result<()> { self.closed = true; match self.state { State::SynRcvd | State::Estab => { self.state = State::FinWait1; } State::FinWait1 | State::FinWait2 => {} _ => { return Err(io::Error::new( io::ErrorKind::NotConnected, "already closing", )) } }; Ok(()) } } fn wrapping_lt(lhs: u32, rhs: u32) -> bool { // From RFC1323: // TCP determines if a data segment is "old" or "new" by testing // whether its sequence number is within 2**31 bytes of the left edge // of the window, and if it is not, discarding the data as "old". To // insure that new data is never mistakenly considered old and vice- // versa, the left edge of the sender's window has to be at most // 2**31 away from the right edge of the receiver's window. lhs.wrapping_sub(rhs) > (1 << 31) } fn is_between_wrapped(start: u32, x: u32, end: u32) -> bool { wrapping_lt(start, x) && wrapping_lt(x, end) }