extern crate rand; #[macro_use] extern crate rust_util; use std::collections::HashMap; use std::env; use std::net::UdpSocket; use std::sync::Arc; use std::sync::atomic::{ AtomicBool, Ordering }; use std::sync::mpsc::channel; use std::thread; use std::time::Duration; use clap::{ Arg, App }; const LOCAL_ADDR: &str = "127.0.0.1"; const TIMEOUT: u64 = 3 * 60 * 100; //3 minutes fn main() { let matches = App::new("Simple Rust UDP Proxy") .version(env!("CARGO_PKG_VERSION")) .about(env!("CARGO_PKG_DESCRIPTION")) .arg(Arg::with_name("local_port").short("l").long("local-port").takes_value(true).required(true) .help("The local port to which udpproxy should bind to")) .arg(Arg::with_name("remote_port").short("r").long("remote-port").takes_value(true).required(true) .help("The remote port to which UDP packets should be forwarded")) .arg(Arg::with_name("host").short("h").long("host").takes_value(true).required(true) .help("The remote address to which packets will be forwarded")) .arg(Arg::with_name("bind").short("b").long("bind").takes_value(true) .help("The address on which to listen for incoming requests")) .arg(Arg::with_name("debug").short("d").long("debug").takes_value(true) .help("Enable debug mode")) .arg(Arg::with_name("allowed_list").short("A").long("allowed-list").takes_value(true).multiple(true).help("Allowed IP list, e.g. 127.0.0.1, 120.0.0.0/8")) .get_matches(); let local_port: u16 = matches.value_of("local_port").unwrap().parse().unwrap(); let remote_port: u16 = matches.value_of("remote_port").unwrap().parse().unwrap(); let remote_host = matches.value_of("host").unwrap(); let bind_addr = match matches.value_of("bind") { Some(addr) => addr.to_owned(), None => LOCAL_ADDR.to_owned(), }; forward(&bind_addr, local_port, &remote_host, remote_port); } fn forward(bind_addr: &str, local_port: u16, remote_host: &str, remote_port: u16) { let local_addr = format!("{}:{}", bind_addr, local_port); debugging!("Listen address and port: {}", local_addr); let local = UdpSocket::bind(&local_addr).unwrap_or_else(|_| panic!("Unable to bind to {}", &local_addr)); information!("Listening on {}", local.local_addr().unwrap()); let remote_addr = format!("{}:{}", remote_host, remote_port); let responder = local.try_clone().unwrap_or_else( |_| panic!("Failed to clone primary listening address socket {}", local.local_addr().unwrap()) ); let (main_sender, main_receiver) = channel::<(_, Vec)>(); thread::spawn(move || { debugging!("Started new thread to deal out responses to clients"); loop { let (dest, buf) = main_receiver.recv().unwrap(); let to_send = buf.as_slice(); responder.send_to(to_send, dest).unwrap_or_else( |_| panic!("Failed to forward response from upstream server to client {}", dest) ); } }); let mut client_map = HashMap::new(); let mut buf = [0; 64 * 1024]; loop { let (num_bytes, src_addr) = local.recv_from(&mut buf).expect("Didn't receive data"); // TODO check src_addr ... //we create a new thread for each unique client let mut remove_existing = false; loop { debugging!("Received packet from client {}", src_addr); let mut ignore_failure = true; let client_id = format!("{}", src_addr); if remove_existing { debugging!("Removing existing forwarder from map: {}", client_id); client_map.remove(&client_id); } let sender = client_map.entry(client_id.clone()).or_insert_with(|| { //we are creating a new listener now, so a failure to send shoud be treated as an error ignore_failure = false; let local_send_queue = main_sender.clone(); let (sender, receiver) = channel::>(); let remote_addr_copy = remote_addr.clone(); thread::spawn(move|| { //regardless of which port we are listening to, we don't know which interface or IP //address the remote server is reachable via, so we bind the outgoing //connection to 0.0.0.0 in all cases. let temp_outgoing_addr = format!("0.0.0.0:{}", 1024 + rand::random::()); debugging!("Establishing new forwarder for client {} on {}", src_addr, &temp_outgoing_addr); let upstream_send = UdpSocket::bind(&temp_outgoing_addr).unwrap_or_else( |_| panic!("Failed to bind to transient address {}", &temp_outgoing_addr) ); let upstream_recv = upstream_send.try_clone().unwrap_or_else( |_| panic!("Failed to clone client-specific connection to upstream!") ); let mut timeouts: u64 = 0; let timed_out = Arc::new(AtomicBool::new(false)); let local_timed_out = timed_out.clone(); thread::spawn(move|| { let mut from_upstream = [0; 64 * 1024]; upstream_recv.set_read_timeout(Some(Duration::from_millis(TIMEOUT + 100))).unwrap(); loop { match upstream_recv.recv_from(&mut from_upstream) { Ok((bytes_rcvd, _)) => { let to_send = from_upstream[..bytes_rcvd].to_vec(); local_send_queue.send((src_addr, to_send)).expect("Failed to queue response from upstream server for forwarding!"); }, Err(_) => { if local_timed_out.load(Ordering::Relaxed) { debugging!("Terminating forwarder thread for client {} due to timeout", src_addr); break; } }, }; } }); loop { match receiver.recv_timeout(Duration::from_millis(TIMEOUT)) { Ok(from_client) => { upstream_send.send_to(from_client.as_slice(), &remote_addr_copy).unwrap_or_else( |_| panic!("Failed to forward packet from client {} to upstream server!", src_addr) ); timeouts = 0; //reset timeout count }, Err(_) => { timeouts += 1; if timeouts >= 10 { debugging!("Disconnecting forwarder for client {} due to timeout", src_addr); timed_out.store(true, Ordering::Relaxed); break; } }, }; } }); sender }); let to_send = buf[..num_bytes].to_vec(); match sender.send(to_send) { Ok(_) => break, Err(_) => { if !ignore_failure { panic!("Failed to send message to datagram forwarder for client {}", client_id); } //client previously timed out debugging!("New connection received from previously timed-out client {}", client_id); remove_existing = true; continue; }, } } } }