use std::collections::{HashMap, VecDeque}; use std::io; use std::io::prelude::*; use std::net::Ipv4Addr; use std::sync::{Arc, Condvar, Mutex}; use std::thread; mod tcp; const SENDQUEUE_SIZE: usize = 1024; #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] struct Quad { src: (Ipv4Addr, u16), dst: (Ipv4Addr, u16), } #[derive(Default)] struct Foobar { manager: Mutex, pending_var: Condvar, rcv_var: Condvar, } type InterfaceHandle = Arc; pub struct Interface { ih: Option, jh: Option>>, } impl Drop for Interface { fn drop(&mut self) { self.ih.as_mut().unwrap().manager.lock().unwrap().terminate = true; drop(self.ih.take()); self.jh .take() .expect("interface dropped more than once") .join() .unwrap() .unwrap(); } } #[derive(Default)] struct ConnectionManager { terminate: bool, connections: HashMap, pending: HashMap>, } fn packet_loop(mut nic: tun_tap::Iface, ih: InterfaceHandle) -> io::Result<()> { let mut buf = [0u8; 1504]; loop { // we want to read from nic, but we want to make sure that we'll wake up when the next // timer has to be triggered! use std::os::unix::io::AsRawFd; let mut pfd = [nix::poll::PollFd::new( nic.as_raw_fd(), nix::poll::EventFlags::POLLIN, )]; let n = nix::poll::poll(&mut pfd[..], 10).map_err(|e| e.as_errno().unwrap())?; assert_ne!(n, -1); if n == 0 { let mut cmg = ih.manager.lock().unwrap(); for connection in cmg.connections.values_mut() { // XXX: don't die on errors? connection.on_tick(&mut nic)?; } continue; } assert_eq!(n, 1); let nbytes = nic.recv(&mut buf[..])?; // TODO: if self.terminate && Arc::get_strong_refs(ih) == 1; then tear down all connections and return. // if s/without_packet_info/new/: // // let _eth_flags = u16::from_be_bytes([buf[0], buf[1]]); // let eth_proto = u16::from_be_bytes([buf[2], buf[3]]); // if eth_proto != 0x0800 { // // not ipv4 // continue; // } // // and also include on send match etherparse::Ipv4HeaderSlice::from_slice(&buf[..nbytes]) { Ok(iph) => { let src = iph.source_addr(); let dst = iph.destination_addr(); if iph.protocol() != 0x06 { eprintln!("BAD PROTOCOL"); // not tcp continue; } match etherparse::TcpHeaderSlice::from_slice(&buf[iph.slice().len()..nbytes]) { Ok(tcph) => { use std::collections::hash_map::Entry; let datai = iph.slice().len() + tcph.slice().len(); let mut cmg = ih.manager.lock().unwrap(); let cm = &mut *cmg; let q = Quad { src: (src, tcph.source_port()), dst: (dst, tcph.destination_port()), }; match cm.connections.entry(q) { Entry::Occupied(mut c) => { eprintln!("got packet for known quad {:?}", q); let a = c.get_mut().on_packet( &mut nic, iph, tcph, &buf[datai..nbytes], )?; // TODO: compare before/after drop(cmg); if a.contains(tcp::Available::READ) { ih.rcv_var.notify_all() } if a.contains(tcp::Available::WRITE) { // TODO: ih.snd_var.notify_all() } } Entry::Vacant(e) => { eprintln!("got packet for unknown quad {:?}", q); if let Some(pending) = cm.pending.get_mut(&tcph.destination_port()) { eprintln!("listening, so accepting"); if let Some(c) = tcp::Connection::accept( &mut nic, iph, tcph, &buf[datai..nbytes], )? { e.insert(c); pending.push_back(q); drop(cmg); ih.pending_var.notify_all() } } } } } Err(e) => { eprintln!("ignoring weird tcp packet {:?}", e); } } } Err(e) => { // eprintln!("ignoring weird packet {:?}", e); } } } } impl Interface { pub fn new() -> io::Result { let nic = tun_tap::Iface::without_packet_info("tun0", tun_tap::Mode::Tun)?; let ih: InterfaceHandle = Arc::default(); let jh = { let ih = ih.clone(); thread::spawn(move || packet_loop(nic, ih)) }; Ok(Interface { ih: Some(ih), jh: Some(jh), }) } pub fn bind(&mut self, port: u16) -> io::Result { use std::collections::hash_map::Entry; let mut cm = self.ih.as_mut().unwrap().manager.lock().unwrap(); match cm.pending.entry(port) { Entry::Vacant(v) => { v.insert(VecDeque::new()); } Entry::Occupied(_) => { return Err(io::Error::new( io::ErrorKind::AddrInUse, "port already bound", )); } }; drop(cm); Ok(TcpListener { port, h: self.ih.as_mut().unwrap().clone(), }) } } pub struct TcpListener { port: u16, h: InterfaceHandle, } impl Drop for TcpListener { fn drop(&mut self) { let mut cm = self.h.manager.lock().unwrap(); let pending = cm .pending .remove(&self.port) .expect("port closed while listener still active"); for quad in pending { // TODO: terminate cm.connections[quad] unimplemented!(); } } } impl TcpListener { pub fn accept(&mut self) -> io::Result { let mut cm = self.h.manager.lock().unwrap(); loop { if let Some(quad) = cm .pending .get_mut(&self.port) .expect("port closed while listener still active") .pop_front() { return Ok(TcpStream { quad, h: self.h.clone(), }); } cm = self.h.pending_var.wait(cm).unwrap(); } } } pub struct TcpStream { quad: Quad, h: InterfaceHandle, } impl Drop for TcpStream { fn drop(&mut self) { let cm = self.h.manager.lock().unwrap(); // TODO: send FIN on cm.connections[quad] // TODO: _eventually_ remove self.quad from cm.connections } } impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { let mut cm = self.h.manager.lock().unwrap(); loop { let c = cm.connections.get_mut(&self.quad).ok_or_else(|| { io::Error::new( io::ErrorKind::ConnectionAborted, "stream was terminated unexpectedly", ) })?; if c.is_rcv_closed() && c.incoming.is_empty() { // no more data to read, and no need to block, because there won't be any more return Ok(0); } if !c.incoming.is_empty() { let mut nread = 0; let (head, tail) = c.incoming.as_slices(); let hread = std::cmp::min(buf.len(), head.len()); buf[..hread].copy_from_slice(&head[..hread]); nread += hread; let tread = std::cmp::min(buf.len() - nread, tail.len()); buf[hread..(hread + tread)].copy_from_slice(&tail[..tread]); nread += tread; drop(c.incoming.drain(..nread)); return Ok(nread); } cm = self.h.rcv_var.wait(cm).unwrap(); } } } impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { let mut cm = self.h.manager.lock().unwrap(); let c = cm.connections.get_mut(&self.quad).ok_or_else(|| { io::Error::new( io::ErrorKind::ConnectionAborted, "stream was terminated unexpectedly", ) })?; if c.unacked.len() >= SENDQUEUE_SIZE { // TODO: block return Err(io::Error::new( io::ErrorKind::WouldBlock, "too many bytes buffered", )); } let nwrite = std::cmp::min(buf.len(), SENDQUEUE_SIZE - c.unacked.len()); c.unacked.extend(buf[..nwrite].iter()); Ok(nwrite) } fn flush(&mut self) -> io::Result<()> { let mut cm = self.h.manager.lock().unwrap(); let c = cm.connections.get_mut(&self.quad).ok_or_else(|| { io::Error::new( io::ErrorKind::ConnectionAborted, "stream was terminated unexpectedly", ) })?; if c.unacked.is_empty() { Ok(()) } else { // TODO: block Err(io::Error::new( io::ErrorKind::WouldBlock, "too many bytes buffered", )) } } } impl TcpStream { pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> { let mut cm = self.h.manager.lock().unwrap(); let c = cm.connections.get_mut(&self.quad).ok_or_else(|| { io::Error::new( io::ErrorKind::ConnectionAborted, "stream was terminated unexpectedly", ) })?; c.close() } }