optimize code structure

This commit is contained in:
wyhaya
2019-08-24 16:36:11 +08:00
parent 7848511160
commit 0c8c84386a
3 changed files with 132 additions and 110 deletions

View File

@@ -17,29 +17,42 @@ lazy_static! {
} }
fn cap_socket_addr(reg: &Regex, text: &str) -> Option<Result<SocketAddr, InvalidType>> { fn cap_socket_addr(reg: &Regex, text: &str) -> Option<Result<SocketAddr, InvalidType>> {
if let Some(cap) = reg.captures(text) { let cap = match reg.captures(text) {
return match cap.name("val") { Some(cap) => cap,
None => return None,
};
match cap.name("val") {
Some(m) => match m.as_str().parse() { Some(m) => match m.as_str().parse() {
Ok(addr) => Some(Ok(addr)), Ok(addr) => Some(Ok(addr)),
Err(_) => Some(Err(InvalidType::SocketAddr)), Err(_) => Some(Err(InvalidType::SocketAddr)),
}, },
None => Some(Err(InvalidType::SocketAddr)), None => Some(Err(InvalidType::SocketAddr)),
}
}
fn cap_ip_addr(text: &str) -> Option<Result<(Regex, IpAddr), InvalidType>> {
let cap = match (&REG_DOMAIN_IP as &Regex).captures(text) {
Some(cap) => cap,
None => return None,
}; };
}
None
}
fn cap_ip_addr(reg: &Regex, text: &str) -> Option<Result<(Regex, IpAddr), InvalidType>> { let (val1, val2) = match (cap.name("val1"), cap.name("val2")) {
if let Some(cap) = reg.captures(text) { (Some(val1), Some(val2)) => (val1.as_str(), val2.as_str()),
if let (Some(val1), Some(val2)) = (cap.name("val1"), cap.name("val2")) { _ => {
let (val1, val2) = (val1.as_str(), val2.as_str()); return Some(Err(InvalidType::Other));
}
};
// ip domain
if let Ok(ip) = val1.parse() { if let Ok(ip) = val1.parse() {
return match Regex::new(val2) { return match Regex::new(val2) {
Ok(reg) => Some(Ok((reg, ip))), Ok(reg) => Some(Ok((reg, ip))),
Err(_) => Some(Err(InvalidType::Regex)), Err(_) => Some(Err(InvalidType::Regex)),
}; };
} else { }
// domain ip
let ip = match val2.parse() { let ip = match val2.parse() {
Ok(ip) => ip, Ok(ip) => ip,
Err(_) => return Some(Err(InvalidType::IpAddr)), Err(_) => return Some(Err(InvalidType::IpAddr)),
@@ -52,12 +65,6 @@ fn cap_ip_addr(reg: &Regex, text: &str) -> Option<Result<(Regex, IpAddr), Invali
return Some(Ok((reg, ip))); return Some(Ok((reg, ip)));
} }
}
return Some(Err(InvalidType::Other));
}
None
}
#[derive(Debug)] #[derive(Debug)]
pub struct Config { pub struct Config {
@@ -71,6 +78,7 @@ pub struct Invalid {
pub source: String, pub source: String,
pub err: InvalidType, pub err: InvalidType,
} }
#[derive(Debug)] #[derive(Debug)]
pub enum InvalidType { pub enum InvalidType {
Regex, Regex,
@@ -102,8 +110,10 @@ impl Config {
} }
pub fn parse(&mut self) -> io::Result<(Vec<SocketAddr>, Vec<SocketAddr>, Hosts, Vec<Invalid>)> { pub fn parse(&mut self) -> io::Result<(Vec<SocketAddr>, Vec<SocketAddr>, Hosts, Vec<Invalid>)> {
let (mut hosts, mut binds, mut proxys, mut errors) = let mut hosts = Hosts::new();
(Hosts::new(), Vec::new(), Vec::new(), Vec::new()); let mut binds = Vec::new();
let mut proxy = Vec::new();
let mut errors = Vec::new();
for (n, line) in self.content.lines().enumerate() { for (n, line) in self.content.lines().enumerate() {
// ignore // ignore
@@ -129,7 +139,7 @@ impl Config {
// proxy // proxy
if let Some(addr) = cap_socket_addr(&REG_PROXY, &line) { if let Some(addr) = cap_socket_addr(&REG_PROXY, &line) {
match addr { match addr {
Ok(addr) => proxys.push(addr), Ok(addr) => proxy.push(addr),
Err(err) => { Err(err) => {
errors.push(Invalid { errors.push(Invalid {
line: n + 1, line: n + 1,
@@ -146,7 +156,7 @@ impl Config {
if let Some(m) = cap.name("val") { if let Some(m) = cap.name("val") {
let (b, p, h, e) = Config::new(m.as_str())?.parse()?; let (b, p, h, e) = Config::new(m.as_str())?.parse()?;
binds.extend(b); binds.extend(b);
proxys.extend(p); proxy.extend(p);
hosts.extend(h); hosts.extend(h);
errors.extend(e); errors.extend(e);
} else { } else {
@@ -156,7 +166,7 @@ impl Config {
} }
// host // host
if let Some(d) = cap_ip_addr(&REG_DOMAIN_IP, &line) { if let Some(d) = cap_ip_addr(&line) {
match d { match d {
Ok((domain, ip)) => hosts.push(domain, ip), Ok((domain, ip)) => hosts.push(domain, ip),
Err(err) => { Err(err) => {
@@ -177,7 +187,7 @@ impl Config {
}); });
} }
Ok((binds, proxys, hosts, errors)) Ok((binds, proxy, hosts, errors))
} }
} }

View File

@@ -3,9 +3,7 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind, Result};
use std::io::Result;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
pub struct BytePacketBuffer { pub struct BytePacketBuffer {

View File

@@ -14,6 +14,7 @@ use async_std::task;
use config::{Config, Hosts, Invalid, InvalidType}; use config::{Config, Hosts, Invalid, InvalidType};
use dirs; use dirs;
use lib::*; use lib::*;
use regex::Regex;
use std::env; use std::env;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
@@ -90,6 +91,18 @@ fn main() {
if values.len() != 2 { if values.len() != 2 {
exit!("'add' value: [DOMAIN] [IP]"); exit!("'add' value: [DOMAIN] [IP]");
} }
if let Err(err) = Regex::new(values[0]) {
exit!(
"Cannot resolve '{}' to regular expression\n{:?}",
values[0],
err
);
}
if let Err(_) = values[1].parse::<IpAddr>() {
exit!("Cannot resolve '{}' to ip address", values[1]);
}
let mut config = match Config::new(&config_path) { let mut config = match Config::new(&config_path) {
Ok(c) => c, Ok(c) => c,
Err(err) => exit!("Failed to read config file: {:?}\n{:?}", &config_path, err), Err(err) => exit!("Failed to read config file: {:?}\n{:?}", &config_path, err),
@@ -103,6 +116,7 @@ fn main() {
if value.is_empty() { if value.is_empty() {
exit!("'rm' value: [DOMAIN | IP]"); exit!("'rm' value: [DOMAIN | IP]");
} }
log!("todo");
} }
} }
"ls" => { "ls" => {
@@ -114,7 +128,7 @@ fn main() {
} }
} }
for (domain, ip) in hosts.iter() { for (domain, ip) in hosts.iter() {
println!("{:domain$} {}", domain.as_str(), ip, domain = n); log!("{:domain$} {}", domain.as_str(), ip, domain = n);
} }
} }
"config" => { "config" => {
@@ -135,7 +149,7 @@ fn main() {
Ok(p) => p.display().to_string(), Ok(p) => p.display().to_string(),
Err(err) => exit!("Failed to get directory\n{:?}", err), Err(err) => exit!("Failed to get directory\n{:?}", err),
}; };
println!("Binary: {}\nConfig: {:?}", binary, config_path); log!("Binary: {}\nConfig: {:?}", binary, config_path);
} }
"help" => { "help" => {
app.help(); app.help();
@@ -150,21 +164,39 @@ fn main() {
return; return;
} }
let (_, mut binds, proxys, hosts) = config_parse(&config_path); let (_, mut binds, proxy, hosts) = config_parse(&config_path);
if binds.is_empty() { if binds.is_empty() {
warn!("Will bind the default address '{}'", DEFAULT_BIND); warn!("Will bind the default address '{}'", DEFAULT_BIND);
binds.push(DEFAULT_BIND.parse().unwrap()); binds.push(DEFAULT_BIND.parse().unwrap());
} }
if proxys.is_empty() { if proxy.is_empty() {
warn!( warn!(
"Will use the default proxy address '{}'", "Will use the default proxy address '{}'",
DEFAULT_PROXY.join(", ") DEFAULT_PROXY.join(", ")
); );
} }
update_config(proxys, hosts);
task::spawn(watch_config(config_path)); update_config(proxy, hosts);
task::block_on(run_server(binds));
// Run server
for addr in binds {
task::spawn(run_server(addr.clone()));
}
// watch config
task::block_on(watch_config(config_path));
}
fn update_config(mut proxy: Vec<SocketAddr>, hosts: Hosts) {
if proxy.is_empty() {
proxy = DEFAULT_PROXY
.iter()
.map(|p| p.parse().unwrap())
.collect::<Vec<SocketAddr>>();
}
unsafe {
PROXY = proxy;
HOSTS = Some(hosts);
};
} }
fn config_parse(file: &PathBuf) -> (Config, Vec<SocketAddr>, Vec<SocketAddr>, Hosts) { fn config_parse(file: &PathBuf) -> (Config, Vec<SocketAddr>, Vec<SocketAddr>, Hosts) {
@@ -173,13 +205,13 @@ fn config_parse(file: &PathBuf) -> (Config, Vec<SocketAddr>, Vec<SocketAddr>, Ho
Err(err) => exit!("Failed to read config file: {:?}\n{:?}", file, err), Err(err) => exit!("Failed to read config file: {:?}\n{:?}", file, err),
}; };
let (binds, proxys, hosts, errors) = match config.parse() { let (binds, proxy, hosts, errors) = match config.parse() {
Ok(d) => d, Ok(d) => d,
Err(err) => exit!("Parsing config file failed\n{:?}", err), Err(err) => exit!("Parsing config file failed\n{:?}", err),
}; };
output_invalid(errors); output_invalid(errors);
(config, binds, proxys, hosts) (config, binds, proxy, hosts)
} }
fn output_invalid(errors: Vec<Invalid>) { fn output_invalid(errors: Vec<Invalid>) {
@@ -212,23 +244,7 @@ async fn watch_config(p: PathBuf) {
.await; .await;
} }
fn update_config(mut proxy: Vec<SocketAddr>, hosts: Hosts) { async fn run_server(addr: SocketAddr) {
if proxy.is_empty() {
proxy = DEFAULT_PROXY
.iter()
.map(|p| p.parse().unwrap())
.collect::<Vec<SocketAddr>>();
}
unsafe {
PROXY = proxy;
HOSTS = Some(hosts);
};
}
async fn run_server(binds: Vec<SocketAddr>) {
let mut tasks = vec![];
for addr in binds {
let task = task::spawn(async move {
let socket = match UdpSocket::bind(&addr).await { let socket = match UdpSocket::bind(&addr).await {
Ok(socket) => { Ok(socket) => {
log!("Start listening to '{}'", addr); log!("Start listening to '{}'", addr);
@@ -236,10 +252,16 @@ async fn run_server(binds: Vec<SocketAddr>) {
} }
Err(err) => exit!("Binding '{}' failed\n{:?}", addr, err), Err(err) => exit!("Binding '{}' failed\n{:?}", addr, err),
}; };
loop { loop {
let mut req = BytePacketBuffer::new(); let mut req = BytePacketBuffer::new();
match socket.recv_from(&mut req.buf).await { let (len, src) = match socket.recv_from(&mut req.buf).await {
Ok((len, src)) => { Ok(r) => r,
Err(err) => {
error!("Failed to receive message\n{:?}", err);
continue;
}
};
let res = match handle(req, len).await { let res = match handle(req, len).await {
Ok(data) => data, Ok(data) => data,
Err(err) => { Err(err) => {
@@ -251,17 +273,6 @@ async fn run_server(binds: Vec<SocketAddr>) {
error!("Replying to '{}' failed\n{:?}", &src, err); error!("Replying to '{}' failed\n{:?}", &src, err);
} }
} }
Err(err) => {
error!("Failed to receive message\n{:?}", err);
}
}
}
});
tasks.push(task);
}
for task in tasks {
task.await;
}
} }
async fn proxy(buf: &[u8]) -> io::Result<Vec<u8>> { async fn proxy(buf: &[u8]) -> io::Result<Vec<u8>> {
@@ -332,17 +343,20 @@ async fn handle(mut req: BytePacketBuffer, len: usize) -> io::Result<Vec<u8>> {
log!("Query: {} Type: {:?}", query.name, query.qtype); log!("Query: {} Type: {:?}", query.name, query.qtype);
if let Some(answer) = get_answer(&query.name, query.qtype) { // Whether to proxy
let answer = match get_answer(&query.name, query.qtype) {
Some(a) => a,
None => return proxy(&req.buf[..len]).await,
};
request.header.recursion_desired = true; request.header.recursion_desired = true;
request.header.recursion_available = true; request.header.recursion_available = true;
request.header.response = true; request.header.response = true;
request.answers.push(answer); request.answers.push(answer);
let mut res_buffer = BytePacketBuffer::new(); let mut res_buffer = BytePacketBuffer::new();
request.write(&mut res_buffer)?; request.write(&mut res_buffer)?;
let len = res_buffer.pos(); let len = res_buffer.pos();
let data = res_buffer.get_range(0, len)?; let data = res_buffer.get_range(0, len)?;
Ok(data.to_vec()) Ok(data.to_vec())
} else {
proxy(&req.buf[..len]).await
}
} }