From 0c8c84386ae0b8359b3bb6a1e3f91c6ce5f1f317 Mon Sep 17 00:00:00 2001 From: wyhaya Date: Sat, 24 Aug 2019 16:36:11 +0800 Subject: [PATCH] optimize code structure --- src/config.rs | 88 ++++++++++++++++------------- src/lib.rs | 4 +- src/main.rs | 150 +++++++++++++++++++++++++++----------------------- 3 files changed, 132 insertions(+), 110 deletions(-) diff --git a/src/config.rs b/src/config.rs index 481dd77..72a7f73 100644 --- a/src/config.rs +++ b/src/config.rs @@ -17,46 +17,53 @@ lazy_static! { } fn cap_socket_addr(reg: &Regex, text: &str) -> Option> { - if let Some(cap) = reg.captures(text) { - return match cap.name("val") { - Some(m) => match m.as_str().parse() { - Ok(addr) => Some(Ok(addr)), - Err(_) => Some(Err(InvalidType::SocketAddr)), - }, - None => Some(Err(InvalidType::SocketAddr)), - }; + let cap = match reg.captures(text) { + Some(cap) => cap, + None => return None, + }; + + match cap.name("val") { + Some(m) => match m.as_str().parse() { + Ok(addr) => Some(Ok(addr)), + Err(_) => Some(Err(InvalidType::SocketAddr)), + }, + None => Some(Err(InvalidType::SocketAddr)), } - None } -fn cap_ip_addr(reg: &Regex, text: &str) -> Option> { - if let Some(cap) = reg.captures(text) { - if let (Some(val1), Some(val2)) = (cap.name("val1"), cap.name("val2")) { - let (val1, val2) = (val1.as_str(), val2.as_str()); +fn cap_ip_addr(text: &str) -> Option> { + let cap = match (®_DOMAIN_IP as &Regex).captures(text) { + Some(cap) => cap, + None => return None, + }; - if let Ok(ip) = val1.parse() { - return match Regex::new(val2) { - Ok(reg) => Some(Ok((reg, ip))), - Err(_) => Some(Err(InvalidType::Regex)), - }; - } else { - let ip = match val2.parse() { - Ok(ip) => ip, - Err(_) => return Some(Err(InvalidType::IpAddr)), - }; - - let reg = match Regex::new(val1) { - Ok(reg) => reg, - Err(_) => return Some(Err(InvalidType::Regex)), - }; - - return Some(Ok((reg, ip))); - } + let (val1, val2) = match (cap.name("val1"), cap.name("val2")) { + (Some(val1), Some(val2)) => (val1.as_str(), val2.as_str()), + _ => { + return Some(Err(InvalidType::Other)); } + }; - return Some(Err(InvalidType::Other)); + // ip domain + if let Ok(ip) = val1.parse() { + return match Regex::new(val2) { + Ok(reg) => Some(Ok((reg, ip))), + Err(_) => Some(Err(InvalidType::Regex)), + }; } - None + + // domain ip + let ip = match val2.parse() { + Ok(ip) => ip, + Err(_) => return Some(Err(InvalidType::IpAddr)), + }; + + let reg = match Regex::new(val1) { + Ok(reg) => reg, + Err(_) => return Some(Err(InvalidType::Regex)), + }; + + return Some(Ok((reg, ip))); } #[derive(Debug)] @@ -71,6 +78,7 @@ pub struct Invalid { pub source: String, pub err: InvalidType, } + #[derive(Debug)] pub enum InvalidType { Regex, @@ -102,8 +110,10 @@ impl Config { } pub fn parse(&mut self) -> io::Result<(Vec, Vec, Hosts, Vec)> { - let (mut hosts, mut binds, mut proxys, mut errors) = - (Hosts::new(), Vec::new(), Vec::new(), Vec::new()); + let mut hosts = Hosts::new(); + let mut binds = Vec::new(); + let mut proxy = Vec::new(); + let mut errors = Vec::new(); for (n, line) in self.content.lines().enumerate() { // ignore @@ -129,7 +139,7 @@ impl Config { // proxy if let Some(addr) = cap_socket_addr(®_PROXY, &line) { match addr { - Ok(addr) => proxys.push(addr), + Ok(addr) => proxy.push(addr), Err(err) => { errors.push(Invalid { line: n + 1, @@ -146,7 +156,7 @@ impl Config { if let Some(m) = cap.name("val") { let (b, p, h, e) = Config::new(m.as_str())?.parse()?; binds.extend(b); - proxys.extend(p); + proxy.extend(p); hosts.extend(h); errors.extend(e); } else { @@ -156,7 +166,7 @@ impl Config { } // host - if let Some(d) = cap_ip_addr(®_DOMAIN_IP, &line) { + if let Some(d) = cap_ip_addr(&line) { match d { Ok((domain, ip)) => hosts.push(domain, ip), Err(err) => { @@ -177,7 +187,7 @@ impl Config { }); } - Ok((binds, proxys, hosts, errors)) + Ok((binds, proxy, hosts, errors)) } } diff --git a/src/lib.rs b/src/lib.rs index 81f5b27..11a1201 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,9 +3,7 @@ #![allow(dead_code)] -use std::io::{Error, ErrorKind}; -use std::io::Result; - +use std::io::{Error, ErrorKind, Result}; use std::net::{Ipv4Addr, Ipv6Addr}; pub struct BytePacketBuffer { diff --git a/src/main.rs b/src/main.rs index 141f98e..0a5c8c0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,6 +14,7 @@ use async_std::task; use config::{Config, Hosts, Invalid, InvalidType}; use dirs; use lib::*; +use regex::Regex; use std::env; use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; @@ -90,6 +91,18 @@ fn main() { if values.len() != 2 { exit!("'add' value: [DOMAIN] [IP]"); } + + if let Err(err) = Regex::new(values[0]) { + exit!( + "Cannot resolve '{}' to regular expression\n{:?}", + values[0], + err + ); + } + if let Err(_) = values[1].parse::() { + exit!("Cannot resolve '{}' to ip address", values[1]); + } + let mut config = match Config::new(&config_path) { Ok(c) => c, Err(err) => exit!("Failed to read config file: {:?}\n{:?}", &config_path, err), @@ -103,6 +116,7 @@ fn main() { if value.is_empty() { exit!("'rm' value: [DOMAIN | IP]"); } + log!("todo"); } } "ls" => { @@ -114,7 +128,7 @@ fn main() { } } for (domain, ip) in hosts.iter() { - println!("{:domain$} {}", domain.as_str(), ip, domain = n); + log!("{:domain$} {}", domain.as_str(), ip, domain = n); } } "config" => { @@ -135,7 +149,7 @@ fn main() { Ok(p) => p.display().to_string(), Err(err) => exit!("Failed to get directory\n{:?}", err), }; - println!("Binary: {}\nConfig: {:?}", binary, config_path); + log!("Binary: {}\nConfig: {:?}", binary, config_path); } "help" => { app.help(); @@ -150,21 +164,39 @@ fn main() { return; } - let (_, mut binds, proxys, hosts) = config_parse(&config_path); + let (_, mut binds, proxy, hosts) = config_parse(&config_path); if binds.is_empty() { warn!("Will bind the default address '{}'", DEFAULT_BIND); binds.push(DEFAULT_BIND.parse().unwrap()); } - if proxys.is_empty() { + if proxy.is_empty() { warn!( "Will use the default proxy address '{}'", DEFAULT_PROXY.join(", ") ); } - update_config(proxys, hosts); - task::spawn(watch_config(config_path)); - task::block_on(run_server(binds)); + update_config(proxy, hosts); + + // Run server + for addr in binds { + task::spawn(run_server(addr.clone())); + } + // watch config + task::block_on(watch_config(config_path)); +} + +fn update_config(mut proxy: Vec, hosts: Hosts) { + if proxy.is_empty() { + proxy = DEFAULT_PROXY + .iter() + .map(|p| p.parse().unwrap()) + .collect::>(); + } + unsafe { + PROXY = proxy; + HOSTS = Some(hosts); + }; } fn config_parse(file: &PathBuf) -> (Config, Vec, Vec, Hosts) { @@ -173,13 +205,13 @@ fn config_parse(file: &PathBuf) -> (Config, Vec, Vec, Ho Err(err) => exit!("Failed to read config file: {:?}\n{:?}", file, err), }; - let (binds, proxys, hosts, errors) = match config.parse() { + let (binds, proxy, hosts, errors) = match config.parse() { Ok(d) => d, Err(err) => exit!("Parsing config file failed\n{:?}", err), }; output_invalid(errors); - (config, binds, proxys, hosts) + (config, binds, proxy, hosts) } fn output_invalid(errors: Vec) { @@ -212,55 +244,34 @@ async fn watch_config(p: PathBuf) { .await; } -fn update_config(mut proxy: Vec, hosts: Hosts) { - if proxy.is_empty() { - proxy = DEFAULT_PROXY - .iter() - .map(|p| p.parse().unwrap()) - .collect::>(); - } - unsafe { - PROXY = proxy; - HOSTS = Some(hosts); +async fn run_server(addr: SocketAddr) { + let socket = match UdpSocket::bind(&addr).await { + Ok(socket) => { + log!("Start listening to '{}'", addr); + socket + } + Err(err) => exit!("Binding '{}' failed\n{:?}", addr, err), }; -} -async fn run_server(binds: Vec) { - let mut tasks = vec![]; - for addr in binds { - let task = task::spawn(async move { - let socket = match UdpSocket::bind(&addr).await { - Ok(socket) => { - log!("Start listening to '{}'", addr); - socket - } - Err(err) => exit!("Binding '{}' failed\n{:?}", addr, err), - }; - loop { - let mut req = BytePacketBuffer::new(); - match socket.recv_from(&mut req.buf).await { - Ok((len, src)) => { - let res = match handle(req, len).await { - Ok(data) => data, - Err(err) => { - error!("Processing request failed\n{:?}", err); - continue; - } - }; - if let Err(err) = socket.send_to(&res, &src).await { - error!("Replying to '{}' failed\n{:?}", &src, err); - } - } - Err(err) => { - error!("Failed to receive message\n{:?}", err); - } - } + loop { + let mut req = BytePacketBuffer::new(); + let (len, src) = match socket.recv_from(&mut req.buf).await { + Ok(r) => r, + Err(err) => { + error!("Failed to receive message\n{:?}", err); + continue; } - }); - tasks.push(task); - } - for task in tasks { - task.await; + }; + let res = match handle(req, len).await { + Ok(data) => data, + Err(err) => { + error!("Processing request failed\n{:?}", err); + continue; + } + }; + if let Err(err) = socket.send_to(&res, &src).await { + error!("Replying to '{}' failed\n{:?}", &src, err); + } } } @@ -332,17 +343,20 @@ async fn handle(mut req: BytePacketBuffer, len: usize) -> io::Result> { log!("Query: {} Type: {:?}", query.name, query.qtype); - if let Some(answer) = get_answer(&query.name, query.qtype) { - request.header.recursion_desired = true; - request.header.recursion_available = true; - request.header.response = true; - request.answers.push(answer); - let mut res_buffer = BytePacketBuffer::new(); - request.write(&mut res_buffer)?; - let len = res_buffer.pos(); - let data = res_buffer.get_range(0, len)?; - Ok(data.to_vec()) - } else { - proxy(&req.buf[..len]).await - } + // Whether to proxy + let answer = match get_answer(&query.name, query.qtype) { + Some(a) => a, + None => return proxy(&req.buf[..len]).await, + }; + + request.header.recursion_desired = true; + request.header.recursion_available = true; + request.header.response = true; + request.answers.push(answer); + let mut res_buffer = BytePacketBuffer::new(); + request.write(&mut res_buffer)?; + + let len = res_buffer.pos(); + let data = res_buffer.get_range(0, len)?; + Ok(data.to_vec()) }