feat: add rust tcp
This commit is contained in:
347
__network/rust_tcp/src/lib.rs
Normal file
347
__network/rust_tcp/src/lib.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::io;
|
||||
use std::io::prelude::*;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::{Arc, Condvar, Mutex};
|
||||
use std::thread;
|
||||
|
||||
mod tcp;
|
||||
|
||||
/// Upper bound on the number of bytes `TcpStream::write` lets accumulate in a
/// connection's `unacked` queue before it reports `WouldBlock`.
const SENDQUEUE_SIZE: usize = 1024;
|
||||
|
||||
/// Key identifying one TCP connection by its two `(address, port)` endpoints.
///
/// The packet loop builds it from a received segment: `src` is the segment's
/// source address/port and `dst` its destination address/port.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct Quad {
    /// Source IPv4 address and TCP port of incoming segments.
    src: (Ipv4Addr, u16),
    /// Destination IPv4 address and TCP port of incoming segments.
    dst: (Ipv4Addr, u16),
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Foobar {
|
||||
manager: Mutex<ConnectionManager>,
|
||||
pending_var: Condvar,
|
||||
rcv_var: Condvar,
|
||||
}
|
||||
|
||||
/// Reference-counted handle to the shared interface state (`Foobar`).
type InterfaceHandle = Arc<Foobar>;
|
||||
|
||||
pub struct Interface {
|
||||
ih: Option<InterfaceHandle>,
|
||||
jh: Option<thread::JoinHandle<io::Result<()>>>,
|
||||
}
|
||||
|
||||
impl Drop for Interface {
|
||||
fn drop(&mut self) {
|
||||
self.ih.as_mut().unwrap().manager.lock().unwrap().terminate = true;
|
||||
|
||||
drop(self.ih.take());
|
||||
self.jh
|
||||
.take()
|
||||
.expect("interface dropped more than once")
|
||||
.join()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ConnectionManager {
|
||||
terminate: bool,
|
||||
connections: HashMap<Quad, tcp::Connection>,
|
||||
pending: HashMap<u16, VecDeque<Quad>>,
|
||||
}
|
||||
|
||||
fn packet_loop(mut nic: tun_tap::Iface, ih: InterfaceHandle) -> io::Result<()> {
|
||||
let mut buf = [0u8; 1504];
|
||||
|
||||
loop {
|
||||
// we want to read from nic, but we want to make sure that we'll wake up when the next
|
||||
// timer has to be triggered!
|
||||
use std::os::unix::io::AsRawFd;
|
||||
let mut pfd = [nix::poll::PollFd::new(
|
||||
nic.as_raw_fd(),
|
||||
nix::poll::EventFlags::POLLIN,
|
||||
)];
|
||||
let n = nix::poll::poll(&mut pfd[..], 10).map_err(|e| e.as_errno().unwrap())?;
|
||||
assert_ne!(n, -1);
|
||||
if n == 0 {
|
||||
let mut cmg = ih.manager.lock().unwrap();
|
||||
for connection in cmg.connections.values_mut() {
|
||||
// XXX: don't die on errors?
|
||||
connection.on_tick(&mut nic)?;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
assert_eq!(n, 1);
|
||||
let nbytes = nic.recv(&mut buf[..])?;
|
||||
|
||||
// TODO: if self.terminate && Arc::get_strong_refs(ih) == 1; then tear down all connections and return.
|
||||
|
||||
// if s/without_packet_info/new/:
|
||||
//
|
||||
// let _eth_flags = u16::from_be_bytes([buf[0], buf[1]]);
|
||||
// let eth_proto = u16::from_be_bytes([buf[2], buf[3]]);
|
||||
// if eth_proto != 0x0800 {
|
||||
// // not ipv4
|
||||
// continue;
|
||||
// }
|
||||
//
|
||||
// and also include on send
|
||||
|
||||
match etherparse::Ipv4HeaderSlice::from_slice(&buf[..nbytes]) {
|
||||
Ok(iph) => {
|
||||
let src = iph.source_addr();
|
||||
let dst = iph.destination_addr();
|
||||
if iph.protocol() != 0x06 {
|
||||
eprintln!("BAD PROTOCOL");
|
||||
// not tcp
|
||||
continue;
|
||||
}
|
||||
|
||||
match etherparse::TcpHeaderSlice::from_slice(&buf[iph.slice().len()..nbytes]) {
|
||||
Ok(tcph) => {
|
||||
use std::collections::hash_map::Entry;
|
||||
let datai = iph.slice().len() + tcph.slice().len();
|
||||
let mut cmg = ih.manager.lock().unwrap();
|
||||
let cm = &mut *cmg;
|
||||
let q = Quad {
|
||||
src: (src, tcph.source_port()),
|
||||
dst: (dst, tcph.destination_port()),
|
||||
};
|
||||
|
||||
match cm.connections.entry(q) {
|
||||
Entry::Occupied(mut c) => {
|
||||
eprintln!("got packet for known quad {:?}", q);
|
||||
let a = c.get_mut().on_packet(
|
||||
&mut nic,
|
||||
iph,
|
||||
tcph,
|
||||
&buf[datai..nbytes],
|
||||
)?;
|
||||
|
||||
// TODO: compare before/after
|
||||
drop(cmg);
|
||||
if a.contains(tcp::Available::READ) {
|
||||
ih.rcv_var.notify_all()
|
||||
}
|
||||
if a.contains(tcp::Available::WRITE) {
|
||||
// TODO: ih.snd_var.notify_all()
|
||||
}
|
||||
}
|
||||
Entry::Vacant(e) => {
|
||||
eprintln!("got packet for unknown quad {:?}", q);
|
||||
if let Some(pending) = cm.pending.get_mut(&tcph.destination_port())
|
||||
{
|
||||
eprintln!("listening, so accepting");
|
||||
if let Some(c) = tcp::Connection::accept(
|
||||
&mut nic,
|
||||
iph,
|
||||
tcph,
|
||||
&buf[datai..nbytes],
|
||||
)? {
|
||||
e.insert(c);
|
||||
pending.push_back(q);
|
||||
drop(cmg);
|
||||
ih.pending_var.notify_all()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("ignoring weird tcp packet {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// eprintln!("ignoring weird packet {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Interface {
|
||||
pub fn new() -> io::Result<Self> {
|
||||
let nic = tun_tap::Iface::without_packet_info("tun0", tun_tap::Mode::Tun)?;
|
||||
|
||||
let ih: InterfaceHandle = Arc::default();
|
||||
|
||||
let jh = {
|
||||
let ih = ih.clone();
|
||||
thread::spawn(move || packet_loop(nic, ih))
|
||||
};
|
||||
|
||||
Ok(Interface {
|
||||
ih: Some(ih),
|
||||
jh: Some(jh),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn bind(&mut self, port: u16) -> io::Result<TcpListener> {
|
||||
use std::collections::hash_map::Entry;
|
||||
let mut cm = self.ih.as_mut().unwrap().manager.lock().unwrap();
|
||||
match cm.pending.entry(port) {
|
||||
Entry::Vacant(v) => {
|
||||
v.insert(VecDeque::new());
|
||||
}
|
||||
Entry::Occupied(_) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
"port already bound",
|
||||
));
|
||||
}
|
||||
};
|
||||
drop(cm);
|
||||
Ok(TcpListener {
|
||||
port,
|
||||
h: self.ih.as_mut().unwrap().clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TcpListener {
|
||||
port: u16,
|
||||
h: InterfaceHandle,
|
||||
}
|
||||
|
||||
impl Drop for TcpListener {
|
||||
fn drop(&mut self) {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
|
||||
let pending = cm
|
||||
.pending
|
||||
.remove(&self.port)
|
||||
.expect("port closed while listener still active");
|
||||
|
||||
for quad in pending {
|
||||
// TODO: terminate cm.connections[quad]
|
||||
unimplemented!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpListener {
|
||||
pub fn accept(&mut self) -> io::Result<TcpStream> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
loop {
|
||||
if let Some(quad) = cm
|
||||
.pending
|
||||
.get_mut(&self.port)
|
||||
.expect("port closed while listener still active")
|
||||
.pop_front()
|
||||
{
|
||||
return Ok(TcpStream {
|
||||
quad,
|
||||
h: self.h.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
cm = self.h.pending_var.wait(cm).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TcpStream {
|
||||
quad: Quad,
|
||||
h: InterfaceHandle,
|
||||
}
|
||||
|
||||
impl Drop for TcpStream {
|
||||
fn drop(&mut self) {
|
||||
let cm = self.h.manager.lock().unwrap();
|
||||
// TODO: send FIN on cm.connections[quad]
|
||||
// TODO: _eventually_ remove self.quad from cm.connections
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for TcpStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
loop {
|
||||
let c = cm.connections.get_mut(&self.quad).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"stream was terminated unexpectedly",
|
||||
)
|
||||
})?;
|
||||
|
||||
if c.is_rcv_closed() && c.incoming.is_empty() {
|
||||
// no more data to read, and no need to block, because there won't be any more
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
if !c.incoming.is_empty() {
|
||||
let mut nread = 0;
|
||||
let (head, tail) = c.incoming.as_slices();
|
||||
let hread = std::cmp::min(buf.len(), head.len());
|
||||
buf[..hread].copy_from_slice(&head[..hread]);
|
||||
nread += hread;
|
||||
let tread = std::cmp::min(buf.len() - nread, tail.len());
|
||||
buf[hread..(hread + tread)].copy_from_slice(&tail[..tread]);
|
||||
nread += tread;
|
||||
drop(c.incoming.drain(..nread));
|
||||
return Ok(nread);
|
||||
}
|
||||
|
||||
cm = self.h.rcv_var.wait(cm).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for TcpStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
let c = cm.connections.get_mut(&self.quad).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"stream was terminated unexpectedly",
|
||||
)
|
||||
})?;
|
||||
|
||||
if c.unacked.len() >= SENDQUEUE_SIZE {
|
||||
// TODO: block
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::WouldBlock,
|
||||
"too many bytes buffered",
|
||||
));
|
||||
}
|
||||
|
||||
let nwrite = std::cmp::min(buf.len(), SENDQUEUE_SIZE - c.unacked.len());
|
||||
c.unacked.extend(buf[..nwrite].iter());
|
||||
|
||||
Ok(nwrite)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
let c = cm.connections.get_mut(&self.quad).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"stream was terminated unexpectedly",
|
||||
)
|
||||
})?;
|
||||
|
||||
if c.unacked.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
// TODO: block
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::WouldBlock,
|
||||
"too many bytes buffered",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpStream {
|
||||
pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
let c = cm.connections.get_mut(&self.quad).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"stream was terminated unexpectedly",
|
||||
)
|
||||
})?;
|
||||
|
||||
c.close()
|
||||
}
|
||||
}
|
||||
27
__network/rust_tcp/src/main.rs
Normal file
27
__network/rust_tcp/src/main.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use std::io::prelude::*;
|
||||
use std::{io, thread};
|
||||
|
||||
fn main() -> io::Result<()> {
|
||||
let mut i = trust::Interface::new()?;
|
||||
eprintln!("created interface");
|
||||
let mut listener = i.bind(8000)?;
|
||||
while let Ok(mut stream) = listener.accept() {
|
||||
eprintln!("got connection!");
|
||||
thread::spawn(move || {
|
||||
stream.write(b"hello from rust-tcp!\n").unwrap();
|
||||
stream.shutdown(std::net::Shutdown::Write).unwrap();
|
||||
loop {
|
||||
let mut buf = [0; 512];
|
||||
let n = stream.read(&mut buf[..]).unwrap();
|
||||
eprintln!("read {}b of data", n);
|
||||
if n == 0 {
|
||||
eprintln!("no more data!");
|
||||
break;
|
||||
} else {
|
||||
println!("{}", std::str::from_utf8(&buf[..n]).unwrap());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
568
__network/rust_tcp/src/tcp.rs
Normal file
568
__network/rust_tcp/src/tcp.rs
Normal file
@@ -0,0 +1,568 @@
|
||||
use bitflags::bitflags;
|
||||
use std::collections::{BTreeMap, VecDeque};
|
||||
use std::{io, time};
|
||||
|
||||
bitflags! {
    /// Which stream operations can currently make progress on a connection.
    pub(crate) struct Available: u8 {
        /// Buffered data is ready, or the receive side is closed.
        const READ = 1 << 0;
        /// Reserved for "send queue has room"; nothing sets it yet.
        const WRITE = 1 << 1;
    }
}
|
||||
|
||||
/// The subset of TCP connection states this implementation tracks.
#[derive(Debug)]
enum State {
    // Listen is not modelled: connections are only created on an incoming SYN.
    /// SYN received and SYN-ACK sent; waiting for the ACK of our SYN.
    SynRcvd,
    /// Handshake finished.
    Estab,
    /// We asked to close; our FIN has not been acknowledged yet.
    FinWait1,
    /// Our FIN has been acknowledged; waiting for the peer's FIN.
    FinWait2,
    /// The peer's FIN arrived after ours was acknowledged.
    TimeWait,
}
|
||||
|
||||
impl State {
|
||||
fn is_synchronized(&self) -> bool {
|
||||
match *self {
|
||||
State::SynRcvd => false,
|
||||
State::Estab | State::FinWait1 | State::FinWait2 | State::TimeWait => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Connection {
|
||||
state: State,
|
||||
send: SendSequenceSpace,
|
||||
recv: RecvSequenceSpace,
|
||||
ip: etherparse::Ipv4Header,
|
||||
tcp: etherparse::TcpHeader,
|
||||
timers: Timers,
|
||||
|
||||
pub(crate) incoming: VecDeque<u8>,
|
||||
pub(crate) unacked: VecDeque<u8>,
|
||||
|
||||
pub(crate) closed: bool,
|
||||
closed_at: Option<u32>,
|
||||
}
|
||||
|
||||
/// Timing information used to decide when to retransmit.
struct Timers {
    /// For each sequence number a segment was written at, the time of the most
    /// recent write at that sequence number.
    send_times: BTreeMap<u32, time::Instant>,
    /// Smoothed round-trip time estimate, in seconds.
    srtt: f64,
}
|
||||
|
||||
impl Connection {
|
||||
pub(crate) fn is_rcv_closed(&self) -> bool {
|
||||
if let State::TimeWait = self.state {
|
||||
// TODO: any state after rcvd FIN, so also CLOSE-WAIT, LAST-ACK, CLOSED, CLOSING
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn availability(&self) -> Available {
|
||||
let mut a = Available::empty();
|
||||
if self.is_rcv_closed() || !self.incoming.is_empty() {
|
||||
a |= Available::READ;
|
||||
}
|
||||
// TODO: take into account self.state
|
||||
// TODO: set Available::WRITE
|
||||
a
|
||||
}
|
||||
}
|
||||
|
||||
/// State of the Send Sequence Space (RFC 793 S3.2 F4)
///
/// ```text
///            1         2          3          4
///       ----------|----------|----------|----------
///              SND.UNA    SND.NXT    SND.UNA
///                                   +SND.WND
///
/// 1 - old sequence numbers which have been acknowledged
/// 2 - sequence numbers of unacknowledged data
/// 3 - sequence numbers allowed for new data transmission
/// 4 - future sequence numbers which are not yet allowed
/// ```
struct SendSequenceSpace {
    /// send unacknowledged
    una: u32,
    /// send next
    nxt: u32,
    /// send window
    wnd: u16,
    /// send urgent pointer
    up: bool,
    /// segment sequence number used for last window update
    wl1: usize,
    /// segment acknowledgment number used for last window update
    wl2: usize,
    /// initial send sequence number
    iss: u32,
}
|
||||
|
||||
/// State of the Receive Sequence Space (RFC 793 S3.2 F5)
///
/// ```text
///                1          2          3
///            ----------|----------|----------
///                   RCV.NXT    RCV.NXT
///                             +RCV.WND
///
/// 1 - old sequence numbers which have been acknowledged
/// 2 - sequence numbers allowed for new reception
/// 3 - future sequence numbers which are not yet allowed
/// ```
struct RecvSequenceSpace {
    /// receive next
    nxt: u32,
    /// receive window
    wnd: u16,
    /// receive urgent pointer
    up: bool,
    /// initial receive sequence number
    irs: u32,
}
|
||||
|
||||
impl Connection {
|
||||
pub fn accept<'a>(
|
||||
nic: &mut tun_tap::Iface,
|
||||
iph: etherparse::Ipv4HeaderSlice<'a>,
|
||||
tcph: etherparse::TcpHeaderSlice<'a>,
|
||||
data: &'a [u8],
|
||||
) -> io::Result<Option<Self>> {
|
||||
let buf = [0u8; 1500];
|
||||
if !tcph.syn() {
|
||||
// only expected SYN packet
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let iss = 0;
|
||||
let wnd = 1024;
|
||||
let mut c = Connection {
|
||||
timers: Timers {
|
||||
send_times: Default::default(),
|
||||
srtt: time::Duration::from_secs(1 * 60).as_secs_f64(),
|
||||
},
|
||||
state: State::SynRcvd,
|
||||
send: SendSequenceSpace {
|
||||
iss,
|
||||
una: iss,
|
||||
nxt: iss,
|
||||
wnd: wnd,
|
||||
up: false,
|
||||
|
||||
wl1: 0,
|
||||
wl2: 0,
|
||||
},
|
||||
recv: RecvSequenceSpace {
|
||||
irs: tcph.sequence_number(),
|
||||
nxt: tcph.sequence_number() + 1,
|
||||
wnd: tcph.window_size(),
|
||||
up: false,
|
||||
},
|
||||
tcp: etherparse::TcpHeader::new(tcph.destination_port(), tcph.source_port(), iss, wnd),
|
||||
ip: etherparse::Ipv4Header::new(
|
||||
0,
|
||||
64,
|
||||
etherparse::IpTrafficClass::Tcp,
|
||||
[
|
||||
iph.destination()[0],
|
||||
iph.destination()[1],
|
||||
iph.destination()[2],
|
||||
iph.destination()[3],
|
||||
],
|
||||
[
|
||||
iph.source()[0],
|
||||
iph.source()[1],
|
||||
iph.source()[2],
|
||||
iph.source()[3],
|
||||
],
|
||||
),
|
||||
|
||||
incoming: Default::default(),
|
||||
unacked: Default::default(),
|
||||
|
||||
closed: false,
|
||||
closed_at: None,
|
||||
};
|
||||
|
||||
// need to start establishing a connection
|
||||
c.tcp.syn = true;
|
||||
c.tcp.ack = true;
|
||||
c.write(nic, c.send.nxt, 0)?;
|
||||
Ok(Some(c))
|
||||
}
|
||||
|
||||
fn write(&mut self, nic: &mut tun_tap::Iface, seq: u32, mut limit: usize) -> io::Result<usize> {
|
||||
let mut buf = [0u8; 1500];
|
||||
self.tcp.sequence_number = seq;
|
||||
self.tcp.acknowledgment_number = self.recv.nxt;
|
||||
|
||||
// TODO: return +1 for SYN/FIN
|
||||
println!(
|
||||
"write(ack: {}, seq: {}, limit: {}) syn {:?} fin {:?}",
|
||||
self.recv.nxt - self.recv.irs, seq, limit, self.tcp.syn, self.tcp.fin,
|
||||
);
|
||||
|
||||
let mut offset = seq.wrapping_sub(self.send.una) as usize;
|
||||
// we need to special-case the two "virtual" bytes SYN and FIN
|
||||
if let Some(closed_at) = self.closed_at {
|
||||
if seq == closed_at.wrapping_add(1) {
|
||||
// trying to write following FIN
|
||||
offset = 0;
|
||||
limit = 0;
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"using offset {} base {} in {:?}",
|
||||
offset,
|
||||
self.send.una,
|
||||
self.unacked.as_slices()
|
||||
);
|
||||
let (mut h, mut t) = self.unacked.as_slices();
|
||||
if h.len() >= offset {
|
||||
h = &h[offset..];
|
||||
} else {
|
||||
let skipped = h.len();
|
||||
h = &[];
|
||||
t = &t[(offset - skipped)..];
|
||||
}
|
||||
|
||||
let max_data = std::cmp::min(limit, h.len() + t.len());
|
||||
let size = std::cmp::min(
|
||||
buf.len(),
|
||||
self.tcp.header_len() as usize + self.ip.header_len() as usize + max_data,
|
||||
);
|
||||
self.ip
|
||||
.set_payload_len(size - self.ip.header_len() as usize);
|
||||
|
||||
// write out the headers and the payload
|
||||
use std::io::Write;
|
||||
let buf_len = buf.len();
|
||||
let mut unwritten = &mut buf[..];
|
||||
|
||||
self.ip.write(&mut unwritten);
|
||||
let ip_header_ends_at = buf_len - unwritten.len();
|
||||
|
||||
// postpone writing the tcp header because we need the payload as one contiguous slice to calculate the tcp checksum
|
||||
unwritten = &mut unwritten[self.tcp.header_len() as usize..];
|
||||
let tcp_header_ends_at = buf_len - unwritten.len();
|
||||
|
||||
// write out the payload
|
||||
let payload_bytes = {
|
||||
let mut written = 0;
|
||||
let mut limit = max_data;
|
||||
|
||||
// first, write as much as we can from h
|
||||
let p1l = std::cmp::min(limit, h.len());
|
||||
written += unwritten.write(&h[..p1l])?;
|
||||
limit -= written;
|
||||
|
||||
// then, write more (if we can) from t
|
||||
let p2l = std::cmp::min(limit, t.len());
|
||||
written += unwritten.write(&t[..p2l])?;
|
||||
written
|
||||
};
|
||||
let payload_ends_at = buf_len - unwritten.len();
|
||||
|
||||
// finally we can calculate the tcp checksum and write out the tcp header
|
||||
self.tcp.checksum = self
|
||||
.tcp
|
||||
.calc_checksum_ipv4(&self.ip, &buf[tcp_header_ends_at..payload_ends_at])
|
||||
.expect("failed to compute checksum");
|
||||
|
||||
let mut tcp_header_buf = &mut buf[ip_header_ends_at..tcp_header_ends_at];
|
||||
self.tcp.write(&mut tcp_header_buf);
|
||||
|
||||
let mut next_seq = seq.wrapping_add(payload_bytes as u32);
|
||||
if self.tcp.syn {
|
||||
next_seq = next_seq.wrapping_add(1);
|
||||
self.tcp.syn = false;
|
||||
}
|
||||
if self.tcp.fin {
|
||||
next_seq = next_seq.wrapping_add(1);
|
||||
self.tcp.fin = false;
|
||||
}
|
||||
if wrapping_lt(self.send.nxt, next_seq) {
|
||||
self.send.nxt = next_seq;
|
||||
}
|
||||
self.timers.send_times.insert(seq, time::Instant::now());
|
||||
|
||||
nic.send(&buf[..payload_ends_at])?;
|
||||
Ok(payload_bytes)
|
||||
}
|
||||
|
||||
/// Sends a segment with the RST flag set.
///
/// The flag is cleared again afterwards (also when sending fails); it used to
/// stay set, which would have put RST on every later segment of the connection.
///
/// # Errors
///
/// Returns the error from `write`.
fn send_rst(&mut self, nic: &mut tun_tap::Iface) -> io::Result<()> {
    // TODO: fix sequence numbers here
    // If the incoming segment has an ACK field, the reset takes its
    // sequence number from the ACK field of the segment, otherwise the
    // reset has sequence number zero and the ACK field is set to the sum
    // of the sequence number and segment length of the incoming segment.
    // The connection remains in the same state.
    //
    // TODO: handle synchronized RST
    // 3. If the connection is in a synchronized state (ESTABLISHED,
    // FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT),
    // any unacceptable segment (out of window sequence number or
    // unacceptable acknowledgment number) must elicit only an empty
    // acknowledgment segment containing the current send-sequence number
    // and an acknowledgment indicating the next sequence number expected
    // to be received, and the connection remains in the same state.
    //
    // NOTE: `write` overwrites the header's sequence and acknowledgment
    // numbers, so assigning them here (as the code used to) has no effect.
    self.tcp.rst = true;
    let sent = self.write(nic, self.send.nxt, 0);
    self.tcp.rst = false;
    sent.map(|_| ())
}
|
||||
|
||||
pub(crate) fn on_tick(&mut self, nic: &mut tun_tap::Iface) -> io::Result<()> {
|
||||
if let State::FinWait2 | State::TimeWait = self.state {
|
||||
// we have shutdown our write side and the other side acked, no need to (re)transmit anything
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// eprintln!("ON TICK: state {:?} una {} nxt {} unacked {:?}",
|
||||
// self.state, self.send.una, self.send.nxt, self.unacked);
|
||||
|
||||
let nunacked_data = self.closed_at.unwrap_or(self.send.nxt).wrapping_sub(self.send.una);
|
||||
let nunsent_data = self.unacked.len() as u32 - nunacked_data;
|
||||
|
||||
let waited_for = self
|
||||
.timers
|
||||
.send_times
|
||||
.range(self.send.una..)
|
||||
.next()
|
||||
.map(|t| t.1.elapsed());
|
||||
|
||||
let should_retransmit = if let Some(waited_for) = waited_for {
|
||||
waited_for > time::Duration::from_secs(1)
|
||||
&& waited_for.as_secs_f64() > 1.5 * self.timers.srtt
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if should_retransmit {
|
||||
let resend = std::cmp::min(self.unacked.len() as u32, self.send.wnd as u32);
|
||||
if resend < self.send.wnd as u32 && self.closed {
|
||||
// can we include the FIN?
|
||||
self.tcp.fin = true;
|
||||
self.closed_at = Some(self.send.una.wrapping_add(self.unacked.len() as u32));
|
||||
}
|
||||
self.write(nic, self.send.una, resend as usize)?;
|
||||
} else {
|
||||
// we should send new data if we have new data and space in the window
|
||||
if nunsent_data == 0 && self.closed_at.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let allowed = self.send.wnd as u32 - nunacked_data;
|
||||
if allowed == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let send = std::cmp::min(nunsent_data, allowed);
|
||||
if send < allowed && self.closed && self.closed_at.is_none() {
|
||||
self.tcp.fin = true;
|
||||
self.closed_at = Some(self.send.una.wrapping_add(self.unacked.len() as u32));
|
||||
}
|
||||
|
||||
self.write(nic, self.send.nxt, send as usize)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn on_packet<'a>(
|
||||
&mut self,
|
||||
nic: &mut tun_tap::Iface,
|
||||
iph: etherparse::Ipv4HeaderSlice<'a>,
|
||||
tcph: etherparse::TcpHeaderSlice<'a>,
|
||||
data: &'a [u8],
|
||||
) -> io::Result<Available> {
|
||||
// first, check that sequence numbers are valid (RFC 793 S3.3)
|
||||
let seqn = tcph.sequence_number();
|
||||
let mut slen = data.len() as u32;
|
||||
if tcph.fin() {
|
||||
slen += 1;
|
||||
};
|
||||
if tcph.syn() {
|
||||
slen += 1;
|
||||
};
|
||||
let wend = self.recv.nxt.wrapping_add(self.recv.wnd as u32);
|
||||
let okay = if slen == 0 {
|
||||
// zero-length segment has separate rules for acceptance
|
||||
if self.recv.wnd == 0 {
|
||||
if seqn != self.recv.nxt {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else if !is_between_wrapped(self.recv.nxt.wrapping_sub(1), seqn, wend) {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
if self.recv.wnd == 0 {
|
||||
false
|
||||
} else if !is_between_wrapped(self.recv.nxt.wrapping_sub(1), seqn, wend)
|
||||
&& !is_between_wrapped(
|
||||
self.recv.nxt.wrapping_sub(1),
|
||||
seqn.wrapping_add(slen - 1),
|
||||
wend,
|
||||
)
|
||||
{
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
if !okay {
|
||||
eprintln!("NOT OKAY");
|
||||
self.write(nic, self.send.nxt, 0)?;
|
||||
return Ok(self.availability());
|
||||
}
|
||||
|
||||
if !tcph.ack() {
|
||||
if tcph.syn() {
|
||||
// got SYN part of initial handshake
|
||||
assert!(data.is_empty());
|
||||
self.recv.nxt = seqn.wrapping_add(1);
|
||||
}
|
||||
return Ok(self.availability());
|
||||
}
|
||||
|
||||
let ackn = tcph.acknowledgment_number();
|
||||
if let State::SynRcvd = self.state {
|
||||
if is_between_wrapped(
|
||||
self.send.una.wrapping_sub(1),
|
||||
ackn,
|
||||
self.send.nxt.wrapping_add(1),
|
||||
) {
|
||||
// must have ACKed our SYN, since we detected at least one acked byte,
|
||||
// and we have only sent one byte (the SYN).
|
||||
self.state = State::Estab;
|
||||
} else {
|
||||
// TODO: <SEQ=SEG.ACK><CTL=RST>
|
||||
}
|
||||
}
|
||||
|
||||
if let State::Estab | State::FinWait1 | State::FinWait2 = self.state {
|
||||
if is_between_wrapped(self.send.una, ackn, self.send.nxt.wrapping_add(1)) {
|
||||
println!(
|
||||
"ack for {} (last: {}); prune in {:?}",
|
||||
ackn, self.send.una, self.unacked
|
||||
);
|
||||
if !self.unacked.is_empty() {
|
||||
let data_start = if self.send.una == self.send.iss {
|
||||
// send.una hasn't been updated yet with ACK for our SYN, so data starts just beyond it
|
||||
self.send.una.wrapping_add(1)
|
||||
} else {
|
||||
self.send.una
|
||||
};
|
||||
let acked_data_end = std::cmp::min(ackn.wrapping_sub(data_start) as usize, self.unacked.len());
|
||||
self.unacked.drain(..acked_data_end);
|
||||
|
||||
let old = std::mem::replace(&mut self.timers.send_times, BTreeMap::new());
|
||||
|
||||
let una = self.send.una;
|
||||
let mut srtt = &mut self.timers.srtt;
|
||||
self.timers
|
||||
.send_times
|
||||
.extend(old.into_iter().filter_map(|(seq, sent)| {
|
||||
if is_between_wrapped(una, seq, ackn) {
|
||||
*srtt = 0.8 * *srtt + (1.0 - 0.8) * sent.elapsed().as_secs_f64();
|
||||
None
|
||||
} else {
|
||||
Some((seq, sent))
|
||||
}
|
||||
}));
|
||||
}
|
||||
self.send.una = ackn;
|
||||
}
|
||||
|
||||
// TODO: if unacked empty and waiting flush, notify
|
||||
// TODO: update window
|
||||
}
|
||||
|
||||
if let State::FinWait1 = self.state {
|
||||
if let Some(closed_at) = self.closed_at {
|
||||
if self.send.una == closed_at.wrapping_add(1) {
|
||||
// our FIN has been ACKed!
|
||||
self.state = State::FinWait2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !data.is_empty() {
|
||||
if let State::Estab | State::FinWait1 | State::FinWait2 = self.state {
|
||||
let mut unread_data_at = self.recv.nxt.wrapping_sub(seqn) as usize;
|
||||
if unread_data_at > data.len() {
|
||||
// we must have received a re-transmitted FIN that we have already seen
|
||||
// nxt points to beyond the fin, but the fin is not in data!
|
||||
assert_eq!(unread_data_at, data.len() + 1);
|
||||
unread_data_at = 0;
|
||||
}
|
||||
self.incoming.extend(&data[unread_data_at..]);
|
||||
|
||||
/*
|
||||
Once the TCP takes responsibility for the data it advances
|
||||
RCV.NXT over the data accepted, and adjusts RCV.WND as
|
||||
apporopriate to the current buffer availability. The total of
|
||||
RCV.NXT and RCV.WND should not be reduced.
|
||||
*/
|
||||
self.recv.nxt = seqn.wrapping_add(data.len() as u32);
|
||||
|
||||
// Send an acknowledgment of the form: <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
|
||||
// TODO: maybe just tick to piggyback ack on data?
|
||||
self.write(nic, self.send.nxt, 0)?;
|
||||
}
|
||||
}
|
||||
|
||||
if tcph.fin() {
|
||||
match self.state {
|
||||
State::FinWait2 => {
|
||||
// we're done with the connection!
|
||||
self.recv.nxt = self.recv.nxt.wrapping_add(1);
|
||||
self.write(nic, self.send.nxt, 0)?;
|
||||
self.state = State::TimeWait;
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self.availability())
|
||||
}
|
||||
|
||||
pub(crate) fn close(&mut self) -> io::Result<()> {
|
||||
self.closed = true;
|
||||
match self.state {
|
||||
State::SynRcvd | State::Estab => {
|
||||
self.state = State::FinWait1;
|
||||
}
|
||||
State::FinWait1 | State::FinWait2 => {}
|
||||
_ => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"already closing",
|
||||
))
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when `lhs` comes strictly before `rhs` in wrapping 32-bit
/// sequence-number order.
///
/// From RFC 1323: TCP determines if a data segment is "old" or "new" by
/// testing whether its sequence number is within 2**31 bytes of the left edge
/// of the window, and if it is not, discarding the data as "old". To ensure
/// that new data is never mistakenly considered old and vice versa, the left
/// edge of the sender's window has to be at most 2**31 away from the right
/// edge of the receiver's window.
fn wrapping_lt(lhs: u32, rhs: u32) -> bool {
    const HALF_SPACE: u32 = 1 << 31;
    // `lhs` is "before" `rhs` when stepping back from `lhs` by `rhs` lands in
    // the upper half of the number space.
    let distance = lhs.wrapping_sub(rhs);
    distance > HALF_SPACE
}
|
||||
|
||||
fn is_between_wrapped(start: u32, x: u32, end: u32) -> bool {
|
||||
wrapping_lt(start, x) && wrapping_lt(x, end)
|
||||
}
|
||||
Reference in New Issue
Block a user