migrate to tokio

This commit is contained in:
wyhaya
2019-11-12 17:17:34 +08:00
parent 54c990725a
commit e168c99635
5 changed files with 659 additions and 374 deletions

View File

@@ -6,22 +6,22 @@ mod lib;
mod watch;
use ace::App;
use async_std::io;
use async_std::net::UdpSocket;
use async_std::task;
use config::{Config, Hosts, Invalid, ParseConfig};
use dirs;
use lib::*;
use regex::Regex;
use std::env;
use std::env::current_exe;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::process::Command;
use std::time::Duration;
use tokio::io::{Error, ErrorKind, Result};
use tokio::net::UdpSocket;
use tokio::prelude::*;
use tokio::timer::Timeout;
use watch::Watch;
const CONFIG_FILE: [&'static str; 2] = [".updns", "config"];
const CONFIG_COMMAND: &'static str = "vim";
const DEFAULT_BIND: &'static str = "0.0.0.0:53";
const DEFAULT_PROXY: [&'static str; 2] = ["8.8.8.8:53", "1.1.1.1:53"];
@@ -60,13 +60,12 @@ macro_rules! warn {
};
}
fn main() {
let cct = format!("Call '{}' to edit the configuration file", CONFIG_COMMAND);
#[tokio::main]
async fn main() {
let app = App::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"))
.cmd("add", "Add a DNS record")
.cmd("rm", "Remove a DNS record")
.cmd("ls", "Print all configured DNS records")
.cmd("config", cct.as_str())
.cmd("config", "Call 'vim' to edit the configuration file")
.cmd("path", "Print related directories")
.cmd("help", "Print help information")
.cmd("version", "Print version information")
@@ -119,27 +118,16 @@ fn main() {
exit!("Cannot resolve '{}' to ip address", values[1]);
}
let mut config = match Config::new(&config_path) {
let mut config = match Config::new(&config_path).await {
Ok(c) => c,
Err(err) => exit!("Failed to read config file {:?}\n{:?}", &config_path, err),
};
if let Err(err) = config.add(&values[0], &values[1]) {
if let Err(err) = config.add(&values[0], &values[1]).await {
exit!("Add record failed\n{:?}", err);
}
}
"rm" => {
if let Some(value) = app.value("rm") {
if value.is_empty() {
exit!("'rm' value: [DOMAIN | IP]");
}
match ask("Confirm delete? Y/N\n") {
Ok(_) => println!("todo"),
Err(err) => exit!("{:?}", err),
}
}
}
"ls" => {
let mut config = config_parse(&config_path);
let mut config = config_parse(&config_path).await;
let mut n = 0;
for (reg, _) in config.hosts.iter() {
if reg.as_str().len() > n {
@@ -151,23 +139,20 @@ fn main() {
}
}
"config" => {
let cmd = Command::new(CONFIG_COMMAND).arg(&config_path).status();
let cmd = Command::new("vim").arg(&config_path).status();
match cmd {
Ok(status) => {
if status.success() {
config_parse(&config_path);
config_parse(&config_path).await;
} else {
println!(
"'{}' exits with a non-zero status code: {:?}",
CONFIG_COMMAND, status
);
println!("'vim' exits with a non-zero status code: {:?}", status);
}
}
Err(err) => exit!("Call '{}' command failed\n{:?}", CONFIG_COMMAND, err),
Err(err) => exit!("Call 'vim' command failed\n{:?}", err),
}
}
"path" => {
let binary = match env::current_exe() {
let binary = match current_exe() {
Ok(p) => p.display().to_string(),
Err(err) => exit!("Failed to get directory\n{:?}", err),
};
@@ -177,20 +162,14 @@ fn main() {
config_path.to_string_lossy()
);
}
"help" => {
app.help();
}
"version" => {
app.version();
}
_ => {
app.error_try("help");
}
"help" => app.help(),
"version" => app.version(),
_ => app.error_try("help"),
}
return;
}
let mut parse = config_parse(&config_path);
let mut parse = config_parse(&config_path).await;
if parse.bind.is_empty() {
warn!("Will bind the default address '{}'", DEFAULT_BIND);
parse.bind.push(DEFAULT_BIND.parse().unwrap());
@@ -206,26 +185,10 @@ fn main() {
// Run server
for addr in parse.bind {
task::spawn(run_server(addr.clone()));
tokio::spawn(run_server(addr.clone()));
}
// watch config
task::block_on(watch_config(config_path, watch_interval));
}
fn ask(text: &str) -> io::Result<bool> {
use std::io;
use std::io::Write;
io::stdout().write(text.as_bytes())?;
io::stdout().flush()?;
let mut s = String::new();
io::stdin().read_line(&mut s)?;
match s.to_uppercase().as_str() {
"Y\n" => Ok(true),
"N\n" => Ok(false),
_ => Ok(ask(&text)?),
}
watch_config(config_path, watch_interval).await;
}
fn update_config(mut proxy: Vec<SocketAddr>, hosts: Hosts, timeout: Option<u64>) {
@@ -245,13 +208,13 @@ fn update_config(mut proxy: Vec<SocketAddr>, hosts: Hosts, timeout: Option<u64>)
};
}
fn config_parse(file: &PathBuf) -> ParseConfig {
let mut config = match Config::new(file) {
async fn config_parse(file: &PathBuf) -> ParseConfig {
let config = match Config::new(file).await {
Ok(c) => c,
Err(err) => exit!("Failed to read config file {:?}\n{:?}", file, err),
};
let parse = match config.parse() {
let parse: ParseConfig = match config.parse().await {
Ok(d) => d,
Err(err) => exit!("Parsing config file failed\n{:?}", err),
};
@@ -265,29 +228,27 @@ fn output_invalid(errors: &Vec<Invalid>) {
error!(
"[line:{}] {} `{}`",
invalid.line,
invalid.kind.as_str(),
invalid.kind.text(),
invalid.source
);
}
}
async fn watch_config(p: PathBuf, t: u64) {
let mut watch = Watch::new(p, t);
watch
.for_each(|c| {
info!("Reload the configuration file: {:?}", &c);
if let Ok(mut config) = Config::new(c) {
if let Ok(parse) = config.parse() {
update_config(parse.proxy, parse.hosts, parse.timeout);
output_invalid(&parse.invalid);
}
let mut watch = Watch::new(&p, t).await;
while let Some(_) = watch.next().await {
info!("Reload the configuration file: {:?}", &p);
if let Ok(config) = Config::new(&p).await {
if let Ok(parse) = config.parse().await {
update_config(parse.proxy, parse.hosts, parse.timeout);
output_invalid(&parse.invalid);
}
})
.await;
}
}
}
async fn run_server(addr: SocketAddr) {
let socket = match UdpSocket::bind(&addr).await {
let mut socket = match UdpSocket::bind(&addr).await {
Ok(socket) => {
info!("Start listening to '{}'", addr);
socket
@@ -317,19 +278,22 @@ async fn run_server(addr: SocketAddr) {
}
}
async fn proxy(buf: &[u8]) -> io::Result<Vec<u8>> {
async fn proxy(buf: &[u8]) -> Result<Vec<u8>> {
let proxy = unsafe { &PROXY };
for addr in proxy.iter() {
let socket = UdpSocket::bind(("0.0.0.0", 0)).await?;
let mut socket = UdpSocket::bind(("0.0.0.0", 0)).await?;
let data = io::timeout(Duration::from_millis(unsafe { TIMEOUT }), async {
socket.send_to(&buf, addr).await?;
let mut res = [0; 512];
let len = socket.recv(&mut res).await?;
Ok(res[..len].to_vec())
})
.await;
let data: Result<Vec<u8>> = Timeout::new(
async {
socket.send_to(&buf, addr).await?;
let mut res = [0; 512];
let len = socket.recv(&mut res).await?;
Ok(res[..len].to_vec())
},
Duration::from_millis(unsafe { TIMEOUT }),
)
.await?;
match data {
Ok(data) => {
@@ -341,8 +305,8 @@ async fn proxy(buf: &[u8]) -> io::Result<Vec<u8>> {
}
}
Err(io::Error::new(
io::ErrorKind::Other,
Err(Error::new(
ErrorKind::Other,
"Proxy server failed to proxy request",
))
}
@@ -375,7 +339,7 @@ fn get_answer(domain: &str, query: QueryType) -> Option<DnsRecord> {
None
}
async fn handle(mut req: BytePacketBuffer, len: usize) -> io::Result<Vec<u8>> {
async fn handle(mut req: BytePacketBuffer, len: usize) -> Result<Vec<u8>> {
let mut request = DnsPacket::from_buffer(&mut req)?;
let query = match request.questions.get(0) {