diff --git a/Cargo.lock b/Cargo.lock index aa18ac2..bbee0f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,9 +232,9 @@ dependencies = [ [[package]] name = "rust_util" -version = "0.6.12" +version = "0.6.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c261320f663e65d0869f77036c454f627e82eecbd574946e6eb55a97d3dc490c" +checksum = "190763556e2faed2ba3120c04e3051001d84dbe556d397f030def64c1ae27fd2" dependencies = [ "lazy_static", "libc", diff --git a/src/main.rs b/src/main.rs index 0b6721c..1bd7046 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ use std::sync::mpsc::channel; use std::thread; use std::time::Duration; use clap::{ Arg, App }; +use rust_util::util_net::IpAddressMaskGroup; const LOCAL_ADDR: &str = "127.0.0.1"; const TIMEOUT: u64 = 3 * 60 * 100; //3 minutes @@ -39,11 +40,20 @@ fn main() { Some(addr) => addr.to_owned(), None => LOCAL_ADDR.to_owned(), }; + let allowed_ip_address_mask_list = IpAddressMaskGroup::parse( + &matches.values_of("allowed_list") + .map(|l| l.map(|i| i.to_owned()).collect::>()) + .unwrap_or_else(|| vec![]) + ); + debugging!("Allowed ip address mask list count: {}", allowed_ip_address_mask_list.ip_address_mask_group.len()); + allowed_ip_address_mask_list.ip_address_mask_group.iter().for_each(|ip| { + debugging!("- {}", ip); + }); - forward(&bind_addr, local_port, &remote_host, remote_port); + forward(&bind_addr, local_port, &remote_host, remote_port, &allowed_ip_address_mask_list); } -fn forward(bind_addr: &str, local_port: u16, remote_host: &str, remote_port: u16) { +fn forward(bind_addr: &str, local_port: u16, remote_host: &str, remote_port: u16, allowed_ip_address_mask_list: &IpAddressMaskGroup) { let local_addr = format!("{}:{}", bind_addr, local_port); debugging!("Listen address and port: {}", local_addr); let local = UdpSocket::bind(&local_addr).unwrap_or_else(|_| panic!("Unable to bind to {}", &local_addr)); @@ -70,7 +80,10 @@ fn forward(bind_addr: &str, local_port: u16, remote_host: &str, remote_port: u16 let mut buf = [0; 64 * 1024]; loop { let (num_bytes, src_addr) = local.recv_from(&mut buf).expect("Didn't receive data"); - // TODO check src_addr ... + if !allowed_ip_address_mask_list.is_empty_or_matches(&src_addr) { + information!("Banned source address: {}", src_addr); + continue; + } //we create a new thread for each unique client let mut remove_existing = false;