feat: add rust tcp
This commit is contained in:
347
__network/rust_tcp/src/lib.rs
Normal file
347
__network/rust_tcp/src/lib.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::io;
|
||||
use std::io::prelude::*;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::{Arc, Condvar, Mutex};
|
||||
use std::thread;
|
||||
|
||||
mod tcp;
|
||||
|
||||
const SENDQUEUE_SIZE: usize = 1024;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
|
||||
struct Quad {
|
||||
src: (Ipv4Addr, u16),
|
||||
dst: (Ipv4Addr, u16),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Foobar {
|
||||
manager: Mutex<ConnectionManager>,
|
||||
pending_var: Condvar,
|
||||
rcv_var: Condvar,
|
||||
}
|
||||
|
||||
type InterfaceHandle = Arc<Foobar>;
|
||||
|
||||
pub struct Interface {
|
||||
ih: Option<InterfaceHandle>,
|
||||
jh: Option<thread::JoinHandle<io::Result<()>>>,
|
||||
}
|
||||
|
||||
impl Drop for Interface {
|
||||
fn drop(&mut self) {
|
||||
self.ih.as_mut().unwrap().manager.lock().unwrap().terminate = true;
|
||||
|
||||
drop(self.ih.take());
|
||||
self.jh
|
||||
.take()
|
||||
.expect("interface dropped more than once")
|
||||
.join()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ConnectionManager {
|
||||
terminate: bool,
|
||||
connections: HashMap<Quad, tcp::Connection>,
|
||||
pending: HashMap<u16, VecDeque<Quad>>,
|
||||
}
|
||||
|
||||
fn packet_loop(mut nic: tun_tap::Iface, ih: InterfaceHandle) -> io::Result<()> {
|
||||
let mut buf = [0u8; 1504];
|
||||
|
||||
loop {
|
||||
// we want to read from nic, but we want to make sure that we'll wake up when the next
|
||||
// timer has to be triggered!
|
||||
use std::os::unix::io::AsRawFd;
|
||||
let mut pfd = [nix::poll::PollFd::new(
|
||||
nic.as_raw_fd(),
|
||||
nix::poll::EventFlags::POLLIN,
|
||||
)];
|
||||
let n = nix::poll::poll(&mut pfd[..], 10).map_err(|e| e.as_errno().unwrap())?;
|
||||
assert_ne!(n, -1);
|
||||
if n == 0 {
|
||||
let mut cmg = ih.manager.lock().unwrap();
|
||||
for connection in cmg.connections.values_mut() {
|
||||
// XXX: don't die on errors?
|
||||
connection.on_tick(&mut nic)?;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
assert_eq!(n, 1);
|
||||
let nbytes = nic.recv(&mut buf[..])?;
|
||||
|
||||
// TODO: if self.terminate && Arc::get_strong_refs(ih) == 1; then tear down all connections and return.
|
||||
|
||||
// if s/without_packet_info/new/:
|
||||
//
|
||||
// let _eth_flags = u16::from_be_bytes([buf[0], buf[1]]);
|
||||
// let eth_proto = u16::from_be_bytes([buf[2], buf[3]]);
|
||||
// if eth_proto != 0x0800 {
|
||||
// // not ipv4
|
||||
// continue;
|
||||
// }
|
||||
//
|
||||
// and also include on send
|
||||
|
||||
match etherparse::Ipv4HeaderSlice::from_slice(&buf[..nbytes]) {
|
||||
Ok(iph) => {
|
||||
let src = iph.source_addr();
|
||||
let dst = iph.destination_addr();
|
||||
if iph.protocol() != 0x06 {
|
||||
eprintln!("BAD PROTOCOL");
|
||||
// not tcp
|
||||
continue;
|
||||
}
|
||||
|
||||
match etherparse::TcpHeaderSlice::from_slice(&buf[iph.slice().len()..nbytes]) {
|
||||
Ok(tcph) => {
|
||||
use std::collections::hash_map::Entry;
|
||||
let datai = iph.slice().len() + tcph.slice().len();
|
||||
let mut cmg = ih.manager.lock().unwrap();
|
||||
let cm = &mut *cmg;
|
||||
let q = Quad {
|
||||
src: (src, tcph.source_port()),
|
||||
dst: (dst, tcph.destination_port()),
|
||||
};
|
||||
|
||||
match cm.connections.entry(q) {
|
||||
Entry::Occupied(mut c) => {
|
||||
eprintln!("got packet for known quad {:?}", q);
|
||||
let a = c.get_mut().on_packet(
|
||||
&mut nic,
|
||||
iph,
|
||||
tcph,
|
||||
&buf[datai..nbytes],
|
||||
)?;
|
||||
|
||||
// TODO: compare before/after
|
||||
drop(cmg);
|
||||
if a.contains(tcp::Available::READ) {
|
||||
ih.rcv_var.notify_all()
|
||||
}
|
||||
if a.contains(tcp::Available::WRITE) {
|
||||
// TODO: ih.snd_var.notify_all()
|
||||
}
|
||||
}
|
||||
Entry::Vacant(e) => {
|
||||
eprintln!("got packet for unknown quad {:?}", q);
|
||||
if let Some(pending) = cm.pending.get_mut(&tcph.destination_port())
|
||||
{
|
||||
eprintln!("listening, so accepting");
|
||||
if let Some(c) = tcp::Connection::accept(
|
||||
&mut nic,
|
||||
iph,
|
||||
tcph,
|
||||
&buf[datai..nbytes],
|
||||
)? {
|
||||
e.insert(c);
|
||||
pending.push_back(q);
|
||||
drop(cmg);
|
||||
ih.pending_var.notify_all()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("ignoring weird tcp packet {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// eprintln!("ignoring weird packet {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Interface {
|
||||
pub fn new() -> io::Result<Self> {
|
||||
let nic = tun_tap::Iface::without_packet_info("tun0", tun_tap::Mode::Tun)?;
|
||||
|
||||
let ih: InterfaceHandle = Arc::default();
|
||||
|
||||
let jh = {
|
||||
let ih = ih.clone();
|
||||
thread::spawn(move || packet_loop(nic, ih))
|
||||
};
|
||||
|
||||
Ok(Interface {
|
||||
ih: Some(ih),
|
||||
jh: Some(jh),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn bind(&mut self, port: u16) -> io::Result<TcpListener> {
|
||||
use std::collections::hash_map::Entry;
|
||||
let mut cm = self.ih.as_mut().unwrap().manager.lock().unwrap();
|
||||
match cm.pending.entry(port) {
|
||||
Entry::Vacant(v) => {
|
||||
v.insert(VecDeque::new());
|
||||
}
|
||||
Entry::Occupied(_) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
"port already bound",
|
||||
));
|
||||
}
|
||||
};
|
||||
drop(cm);
|
||||
Ok(TcpListener {
|
||||
port,
|
||||
h: self.ih.as_mut().unwrap().clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TcpListener {
|
||||
port: u16,
|
||||
h: InterfaceHandle,
|
||||
}
|
||||
|
||||
impl Drop for TcpListener {
|
||||
fn drop(&mut self) {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
|
||||
let pending = cm
|
||||
.pending
|
||||
.remove(&self.port)
|
||||
.expect("port closed while listener still active");
|
||||
|
||||
for quad in pending {
|
||||
// TODO: terminate cm.connections[quad]
|
||||
unimplemented!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpListener {
|
||||
pub fn accept(&mut self) -> io::Result<TcpStream> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
loop {
|
||||
if let Some(quad) = cm
|
||||
.pending
|
||||
.get_mut(&self.port)
|
||||
.expect("port closed while listener still active")
|
||||
.pop_front()
|
||||
{
|
||||
return Ok(TcpStream {
|
||||
quad,
|
||||
h: self.h.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
cm = self.h.pending_var.wait(cm).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TcpStream {
|
||||
quad: Quad,
|
||||
h: InterfaceHandle,
|
||||
}
|
||||
|
||||
impl Drop for TcpStream {
|
||||
fn drop(&mut self) {
|
||||
let cm = self.h.manager.lock().unwrap();
|
||||
// TODO: send FIN on cm.connections[quad]
|
||||
// TODO: _eventually_ remove self.quad from cm.connections
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for TcpStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
loop {
|
||||
let c = cm.connections.get_mut(&self.quad).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"stream was terminated unexpectedly",
|
||||
)
|
||||
})?;
|
||||
|
||||
if c.is_rcv_closed() && c.incoming.is_empty() {
|
||||
// no more data to read, and no need to block, because there won't be any more
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
if !c.incoming.is_empty() {
|
||||
let mut nread = 0;
|
||||
let (head, tail) = c.incoming.as_slices();
|
||||
let hread = std::cmp::min(buf.len(), head.len());
|
||||
buf[..hread].copy_from_slice(&head[..hread]);
|
||||
nread += hread;
|
||||
let tread = std::cmp::min(buf.len() - nread, tail.len());
|
||||
buf[hread..(hread + tread)].copy_from_slice(&tail[..tread]);
|
||||
nread += tread;
|
||||
drop(c.incoming.drain(..nread));
|
||||
return Ok(nread);
|
||||
}
|
||||
|
||||
cm = self.h.rcv_var.wait(cm).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for TcpStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
let c = cm.connections.get_mut(&self.quad).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"stream was terminated unexpectedly",
|
||||
)
|
||||
})?;
|
||||
|
||||
if c.unacked.len() >= SENDQUEUE_SIZE {
|
||||
// TODO: block
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::WouldBlock,
|
||||
"too many bytes buffered",
|
||||
));
|
||||
}
|
||||
|
||||
let nwrite = std::cmp::min(buf.len(), SENDQUEUE_SIZE - c.unacked.len());
|
||||
c.unacked.extend(buf[..nwrite].iter());
|
||||
|
||||
Ok(nwrite)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
let c = cm.connections.get_mut(&self.quad).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"stream was terminated unexpectedly",
|
||||
)
|
||||
})?;
|
||||
|
||||
if c.unacked.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
// TODO: block
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::WouldBlock,
|
||||
"too many bytes buffered",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpStream {
|
||||
pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
|
||||
let mut cm = self.h.manager.lock().unwrap();
|
||||
let c = cm.connections.get_mut(&self.quad).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"stream was terminated unexpectedly",
|
||||
)
|
||||
})?;
|
||||
|
||||
c.close()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user