182 lines
8.5 KiB
Rust
182 lines
8.5 KiB
Rust
extern crate rand;
|
|
#[macro_use]
|
|
extern crate rust_util;
|
|
|
|
use std::collections::HashMap;
|
|
use std::env;
|
|
use std::net::UdpSocket;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{ AtomicBool, Ordering };
|
|
use std::sync::mpsc::channel;
|
|
use std::thread;
|
|
use std::time::Duration;
|
|
use clap::{ Arg, App };
|
|
use rust_util::util_net::IpAddressMaskGroup;
|
|
|
|
/// Default listen address used when `--bind` is not given. This is the loopback
/// address, so by default the proxy only accepts datagrams from the local host.
const LOCAL_ADDR: &str = "127.0.0.1";

/// Length of one idle wait, in milliseconds (18 000 ms = 18 s, not 3 minutes).
/// A client's forwarder is torn down after 10 consecutive waits with no datagram
/// from that client, i.e. after roughly 3 minutes of client inactivity.
const TIMEOUT: u64 = 18_000;
|
|
|
|
fn main() {
|
|
let matches = App::new("Simple Rust UDP Proxy")
|
|
.version(env!("CARGO_PKG_VERSION"))
|
|
.about(env!("CARGO_PKG_DESCRIPTION"))
|
|
.arg(Arg::with_name("local_port").short("l").long("local-port").takes_value(true).required(true)
|
|
.help("The local port to which udpproxy should bind to"))
|
|
.arg(Arg::with_name("remote_port").short("r").long("remote-port").takes_value(true).required(true)
|
|
.help("The remote port to which UDP packets should be forwarded"))
|
|
.arg(Arg::with_name("host").short("h").long("host").takes_value(true).required(true)
|
|
.help("The remote address to which packets will be forwarded"))
|
|
.arg(Arg::with_name("bind").short("b").long("bind").takes_value(true)
|
|
.help("The address on which to listen for incoming requests"))
|
|
.arg(Arg::with_name("debug").short("d").long("debug").takes_value(true)
|
|
.help("Enable debug mode"))
|
|
.arg(Arg::with_name("allowed_list").short("A").long("allowed-list").takes_value(true).multiple(true).help("Allowed IP list, e.g. 127.0.0.1, 120.0.0.0/8"))
|
|
.get_matches();
|
|
|
|
let local_port: u16 = matches.value_of("local_port").unwrap().parse().unwrap();
|
|
let remote_port: u16 = matches.value_of("remote_port").unwrap().parse().unwrap();
|
|
let remote_host = matches.value_of("host").unwrap();
|
|
let bind_addr = match matches.value_of("bind") {
|
|
Some(addr) => addr.to_owned(),
|
|
None => LOCAL_ADDR.to_owned(),
|
|
};
|
|
let allowed_ip_address_mask_list = IpAddressMaskGroup::parse(
|
|
&matches.values_of("allowed_list")
|
|
.map(|l| l.map(|i| i.to_owned()).collect::<Vec<_>>())
|
|
.unwrap_or_else(|| vec![])
|
|
);
|
|
debugging!("Allowed ip address mask list count: {}", allowed_ip_address_mask_list.ip_address_mask_group.len());
|
|
allowed_ip_address_mask_list.ip_address_mask_group.iter().for_each(|ip| {
|
|
debugging!("- {}", ip);
|
|
});
|
|
|
|
forward(&bind_addr, local_port, &remote_host, remote_port, &allowed_ip_address_mask_list);
|
|
}
|
|
|
|
fn forward(bind_addr: &str, local_port: u16, remote_host: &str, remote_port: u16, allowed_ip_address_mask_list: &IpAddressMaskGroup) {
|
|
let local_addr = format!("{}:{}", bind_addr, local_port);
|
|
debugging!("Listen address and port: {}", local_addr);
|
|
let local = UdpSocket::bind(&local_addr).unwrap_or_else(|_| panic!("Unable to bind to {}", &local_addr));
|
|
information!("Listening on {}", local.local_addr().unwrap());
|
|
|
|
let remote_addr = format!("{}:{}", remote_host, remote_port);
|
|
|
|
let responder = local.try_clone().unwrap_or_else(
|
|
|_| panic!("Failed to clone primary listening address socket {}", local.local_addr().unwrap())
|
|
);
|
|
let (main_sender, main_receiver) = channel::<(_, Vec<u8>)>();
|
|
thread::spawn(move || {
|
|
debugging!("Started new thread to deal out responses to clients");
|
|
loop {
|
|
let (dest, buf) = main_receiver.recv().unwrap();
|
|
let to_send = buf.as_slice();
|
|
responder.send_to(to_send, dest).unwrap_or_else(
|
|
|_| panic!("Failed to forward response from upstream server to client {}", dest)
|
|
);
|
|
}
|
|
});
|
|
|
|
let mut client_map = HashMap::new();
|
|
let mut buf = [0; 64 * 1024];
|
|
loop {
|
|
let (num_bytes, src_addr) = local.recv_from(&mut buf).expect("Didn't receive data");
|
|
if !allowed_ip_address_mask_list.is_empty_or_matches(&src_addr) {
|
|
information!("Banned source address: {}", src_addr);
|
|
continue;
|
|
}
|
|
|
|
//we create a new thread for each unique client
|
|
let mut remove_existing = false;
|
|
loop {
|
|
debugging!("Received packet from client {}", src_addr);
|
|
|
|
let mut ignore_failure = true;
|
|
let client_id = format!("{}", src_addr);
|
|
|
|
if remove_existing {
|
|
debugging!("Removing existing forwarder from map: {}", client_id);
|
|
client_map.remove(&client_id);
|
|
}
|
|
|
|
let sender = client_map.entry(client_id.clone()).or_insert_with(|| {
|
|
//we are creating a new listener now, so a failure to send shoud be treated as an error
|
|
ignore_failure = false;
|
|
|
|
let local_send_queue = main_sender.clone();
|
|
let (sender, receiver) = channel::<Vec<u8>>();
|
|
let remote_addr_copy = remote_addr.clone();
|
|
thread::spawn(move|| {
|
|
//regardless of which port we are listening to, we don't know which interface or IP
|
|
//address the remote server is reachable via, so we bind the outgoing
|
|
//connection to 0.0.0.0 in all cases.
|
|
let temp_outgoing_addr = format!("0.0.0.0:{}", 1024 + rand::random::<u16>());
|
|
debugging!("Establishing new forwarder for client {} on {}", src_addr, &temp_outgoing_addr);
|
|
let upstream_send = UdpSocket::bind(&temp_outgoing_addr).unwrap_or_else(
|
|
|_| panic!("Failed to bind to transient address {}", &temp_outgoing_addr)
|
|
);
|
|
let upstream_recv = upstream_send.try_clone().unwrap_or_else(
|
|
|_| panic!("Failed to clone client-specific connection to upstream!")
|
|
);
|
|
|
|
let mut timeouts: u64 = 0;
|
|
let timed_out = Arc::new(AtomicBool::new(false));
|
|
|
|
let local_timed_out = timed_out.clone();
|
|
thread::spawn(move|| {
|
|
let mut from_upstream = [0; 64 * 1024];
|
|
upstream_recv.set_read_timeout(Some(Duration::from_millis(TIMEOUT + 100))).unwrap();
|
|
loop {
|
|
match upstream_recv.recv_from(&mut from_upstream) {
|
|
Ok((bytes_rcvd, _)) => {
|
|
let to_send = from_upstream[..bytes_rcvd].to_vec();
|
|
local_send_queue.send((src_addr, to_send)).expect("Failed to queue response from upstream server for forwarding!");
|
|
},
|
|
Err(_) => {
|
|
if local_timed_out.load(Ordering::Relaxed) {
|
|
debugging!("Terminating forwarder thread for client {} due to timeout", src_addr);
|
|
break;
|
|
}
|
|
},
|
|
};
|
|
}
|
|
});
|
|
|
|
loop {
|
|
match receiver.recv_timeout(Duration::from_millis(TIMEOUT)) {
|
|
Ok(from_client) => {
|
|
upstream_send.send_to(from_client.as_slice(), &remote_addr_copy).unwrap_or_else(
|
|
|_| panic!("Failed to forward packet from client {} to upstream server!", src_addr)
|
|
);
|
|
timeouts = 0; //reset timeout count
|
|
},
|
|
Err(_) => {
|
|
timeouts += 1;
|
|
if timeouts >= 10 {
|
|
debugging!("Disconnecting forwarder for client {} due to timeout", src_addr);
|
|
timed_out.store(true, Ordering::Relaxed);
|
|
break;
|
|
}
|
|
},
|
|
};
|
|
}
|
|
});
|
|
sender
|
|
});
|
|
|
|
let to_send = buf[..num_bytes].to_vec();
|
|
match sender.send(to_send) {
|
|
Ok(_) => break,
|
|
Err(_) => {
|
|
if !ignore_failure {
|
|
panic!("Failed to send message to datagram forwarder for client {}", client_id);
|
|
}
|
|
//client previously timed out
|
|
debugging!("New connection received from previously timed-out client {}", client_id);
|
|
remove_existing = true;
|
|
continue;
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|