#[macro_use] extern crate lazy_static; mod config; mod lib; mod watch; use ace::App; use config::{Config, Hosts, Invalid, ParseConfig}; use dirs; use lib::*; use regex::Regex; use std::env::current_exe; use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::process::Command; use std::time::Duration; use tokio::io::{Error, ErrorKind, Result}; use tokio::net::UdpSocket; use tokio::prelude::*; use tokio::timer::Timeout; use watch::Watch; const CONFIG_FILE: [&'static str; 2] = [".updns", "config"]; const DEFAULT_BIND: &'static str = "0.0.0.0:53"; const DEFAULT_PROXY: [&'static str; 2] = ["8.8.8.8:53", "1.1.1.1:53"]; const DEFAULT_TIMEOUT: u64 = 2000; static mut PROXY: Vec = Vec::new(); static mut HOSTS: Option = None; static mut TIMEOUT: u64 = DEFAULT_TIMEOUT; const WATCH_INTERVAL: u64 = 5000; macro_rules! exit { ($($arg:tt)*) => { { eprintln!($($arg)*); std::process::exit(1) } }; } macro_rules! error { ($($arg:tt)*) => { eprint!("{} ERROR ", time::now().strftime("[%Y-%m-%d %H:%M:%S]").unwrap()); eprintln!($($arg)*); }; } macro_rules! info { ($($arg:tt)*) => { print!("{} INFO ", time::now().strftime("[%Y-%m-%d %H:%M:%S]").unwrap()); println!($($arg)*); }; } macro_rules! warn { ($($arg:tt)*) => { print!("{} WARN ", time::now().strftime("[%Y-%m-%d %H:%M:%S]").unwrap()); println!($($arg)*); }; } #[tokio::main] async fn main() { let app = App::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")) .cmd("add", "Add a DNS record") .cmd("ls", "Print all configured DNS records") .cmd("config", "Call 'vim' to edit the configuration file") .cmd("path", "Print related directories") .cmd("help", "Print help information") .cmd("version", "Print version information") .opt("-c", "Specify a config file") .opt("-w", "Check the interval of the configuration file (ms)"); let config_path = match app.value("-c") { Some(values) => { if values.is_empty() { exit!("'-c' value: [CONFIG]"); } PathBuf::from(values[0]) } None => match dirs::home_dir() { Some(p) => p.join(CONFIG_FILE[0]).join(CONFIG_FILE[1]), None => exit!("Can't get home directory"), }, }; // Check profile interval let watch_interval = match app.value("-w") { Some(values) => { if values.is_empty() { exit!("'-w' value: [ms]"); } match &values[0].parse::() { Ok(t) => *t, Err(_) => exit!("Cannot resolve '{}' to number", &values[0]), } } None => WATCH_INTERVAL, }; if let Some(cmd) = app.command() { match cmd.as_str() { "add" => { let values = app.value("add").unwrap_or(vec![]); if values.len() != 2 { exit!("'add' value: [DOMAIN] [IP]"); } if let Err(err) = Regex::new(values[0]) { exit!( "Cannot resolve '{}' to regular expression\n{:?}", values[0], err ); } if let Err(_) = values[1].parse::() { exit!("Cannot resolve '{}' to ip address", values[1]); } let mut config = match Config::new(&config_path).await { Ok(c) => c, Err(err) => exit!("Failed to read config file {:?}\n{:?}", &config_path, err), }; if let Err(err) = config.add(&values[0], &values[1]).await { exit!("Add record failed\n{:?}", err); } } "ls" => { let mut config = config_parse(&config_path).await; let mut n = 0; for (reg, _) in config.hosts.iter() { if reg.as_str().len() > n { n = reg.as_str().len(); } } for (domain, ip) in config.hosts.iter() { println!("{:domain$} {}", domain.as_str(), ip, domain = n); } } "config" => { let cmd = Command::new("vim").arg(&config_path).status(); match cmd { Ok(status) => { if status.success() { config_parse(&config_path).await; } else { println!("'vim' exits with a non-zero status code: {:?}", status); } } Err(err) => exit!("Call 'vim' command failed\n{:?}", err), } } "path" => { let binary = match current_exe() { Ok(p) => p.display().to_string(), Err(err) => exit!("Failed to get directory\n{:?}", err), }; println!( "Binary: {}\nConfig: {}", binary, config_path.to_string_lossy() ); } "help" => app.help(), "version" => app.version(), _ => app.error_try("help"), } return; } let mut parse = config_parse(&config_path).await; if parse.bind.is_empty() { warn!("Will bind the default address '{}'", DEFAULT_BIND); parse.bind.push(DEFAULT_BIND.parse().unwrap()); } if parse.proxy.is_empty() { warn!( "Will use the default proxy address '{}'", DEFAULT_PROXY.join(", ") ); } update_config(parse.proxy, parse.hosts, parse.timeout); // Run server for addr in parse.bind { tokio::spawn(run_server(addr.clone())); } // watch config watch_config(config_path, watch_interval).await; } fn update_config(mut proxy: Vec, hosts: Hosts, timeout: Option) { if proxy.is_empty() { proxy = DEFAULT_PROXY .iter() .map(|p| p.parse().unwrap()) .collect::>(); } unsafe { PROXY = proxy; HOSTS = Some(hosts); TIMEOUT = match timeout { Some(t) => t, None => DEFAULT_TIMEOUT, }; }; } async fn config_parse(file: &PathBuf) -> ParseConfig { let config = match Config::new(file).await { Ok(c) => c, Err(err) => exit!("Failed to read config file {:?}\n{:?}", file, err), }; let parse: ParseConfig = match config.parse().await { Ok(d) => d, Err(err) => exit!("Parsing config file failed\n{:?}", err), }; output_invalid(&parse.invalid); parse } fn output_invalid(errors: &Vec) { for invalid in errors { error!( "[line:{}] {} `{}`", invalid.line, invalid.kind.text(), invalid.source ); } } async fn watch_config(p: PathBuf, t: u64) { let mut watch = Watch::new(&p, t).await; while let Some(_) = watch.next().await { info!("Reload the configuration file: {:?}", &p); if let Ok(config) = Config::new(&p).await { if let Ok(parse) = config.parse().await { update_config(parse.proxy, parse.hosts, parse.timeout); output_invalid(&parse.invalid); } } } } async fn run_server(addr: SocketAddr) { let mut socket = match UdpSocket::bind(&addr).await { Ok(socket) => { info!("Start listening to '{}'", addr); socket } Err(err) => exit!("Binding '{}' failed\n{:?}", addr, err), }; loop { let mut req = BytePacketBuffer::new(); let (len, src) = match socket.recv_from(&mut req.buf).await { Ok(r) => r, Err(err) => { error!("Failed to receive message {:?}", err); continue; } }; let res = match handle(req, len).await { Ok(data) => data, Err(err) => { error!("Processing request failed {:?}", err); continue; } }; if let Err(err) = socket.send_to(&res, &src).await { error!("Replying to '{}' failed {:?}", &src, err); } } } async fn proxy(buf: &[u8]) -> Result> { let proxy = unsafe { &PROXY }; for addr in proxy.iter() { let mut socket = UdpSocket::bind(("0.0.0.0", 0)).await?; let data: Result> = Timeout::new( async { socket.send_to(&buf, addr).await?; let mut res = [0; 512]; let len = socket.recv(&mut res).await?; Ok(res[..len].to_vec()) }, Duration::from_millis(unsafe { TIMEOUT }), ) .await?; match data { Ok(data) => { return Ok(data); } Err(err) => { error!("Agent request to {} {:?}", addr, err); } } } Err(Error::new( ErrorKind::Other, "Proxy server failed to proxy request", )) } fn get_answer(domain: &str, query: QueryType) -> Option { let hosts = unsafe { HOSTS.as_ref().unwrap() }; if let Some(ip) = hosts.get(domain) { match query { QueryType::A => { if let IpAddr::V4(addr) = ip { return Some(DnsRecord::A { domain: domain.to_string(), addr: addr.clone(), ttl: 3600, }); } } QueryType::AAAA => { if let IpAddr::V6(addr) = ip { return Some(DnsRecord::AAAA { domain: domain.to_string(), addr: addr.clone(), ttl: 3600, }); } } _ => {} } } None } async fn handle(mut req: BytePacketBuffer, len: usize) -> Result> { let mut request = DnsPacket::from_buffer(&mut req)?; let query = match request.questions.get(0) { Some(q) => q, None => return proxy(&req.buf[..len]).await, }; info!("{} {:?}", query.name, query.qtype); // Whether to proxy let answer = match get_answer(&query.name, query.qtype) { Some(a) => a, None => return proxy(&req.buf[..len]).await, }; request.header.recursion_desired = true; request.header.recursion_available = true; request.header.response = true; request.answers.push(answer); let mut res_buffer = BytePacketBuffer::new(); request.write(&mut res_buffer)?; let len = res_buffer.pos(); let data = res_buffer.get_range(0, len)?; Ok(data.to_vec()) }