use std::io::{Error, ErrorKind}; use std::time::Duration; use futures::future::try_join; use quinn::{RecvStream, SendStream}; use rust_util::util_msg; use rust_util::util_msg::MessageType; use tokio::{select, time}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; #[derive(Debug)] enum StreamDirection { Up, Down, } pub async fn transfer_for_server_to_remote(recv: RecvStream, send: SendStream, remote_addr: &str, local_addr: &str, proxy_addr: String, conn_count: String) -> Result<(), String> { let mut outbound = match TcpStream::connect(&proxy_addr).await { Ok(outbound) => outbound, Err(e) => { return Err(format!("[conn {}] Failed to connect to: {}, err: {}", &conn_count, &proxy_addr, e)); } }; if let (in_peer_addr, in_local_addr, Ok(ref out_local_addr), Ok(ref out_peer_addr)) = (remote_addr, local_addr, outbound.local_addr(), outbound.peer_addr()) { let peer = format!("{} -> [{} * {}] -> {}", in_peer_addr, in_local_addr, out_local_addr, out_peer_addr); information!("[conn {}] New server-remote tcp connection: {}", &conn_count, peer); } let (mut ri, mut wi) = (recv, send); let (mut ro, mut wo) = outbound.split(); inner_transfer(&mut ri, &mut wi, &mut ro, &mut wo, conn_count).await } pub async fn transfer_for_client_to_server(mut inbound: TcpStream, recv: RecvStream, send: SendStream, remote_addr: &str, local_addr: &str, conn_count: String) -> Result<(), String> { if let (Ok(ref in_peer_addr), Ok(ref in_local_addr), out_local_addr, out_peer_addr) = (inbound.peer_addr(), inbound.local_addr(), local_addr, remote_addr) { let peer = format!("{} -> [{} * {}] -> {}", in_peer_addr, in_local_addr, out_local_addr, out_peer_addr); information!("[conn {}] New client-server tcp connection: {}", &conn_count, peer); } let (mut ri, mut wi) = inbound.split(); let (mut ro, mut wo) = (recv, send); inner_transfer(&mut ri, &mut wi, &mut ro, &mut wo, conn_count).await } async fn inner_transfer<'a, R1, W1, R2, W2>(mut ri: &'a mut R1, mut wi: &'a mut W1, mut ro: &'a mut R2, mut wo: &'a mut W2, conn_count: String) -> Result<(), String> where R1: AsyncRead + Unpin + ?Sized, W1: AsyncWrite + Unpin + ?Sized, R2: AsyncRead + Unpin + ?Sized, W2: AsyncWrite + Unpin + ?Sized { // IO copy timeout 6 HOURS let tcp_io_copy_timeout = Duration::from_secs(6 * 3600); let shutdown_tcp_timeout = Duration::from_secs(60); let (client_to_server_tx, client_to_server_rx) = tokio::sync::oneshot::channel::(); let (server_to_client_tx, server_to_client_rx) = tokio::sync::oneshot::channel::(); let client_to_server = async { let r = select! { _timeout = time::sleep(tcp_io_copy_timeout) => { failure!("[conn {}] TCP copy client -> server timeout", &conn_count); Err(Error::new(ErrorKind::TimedOut, "timeout")) } _tcp_break = client_to_server_rx => { failure!("[conn {}] TCP copy client -> server shutdown", &conn_count); Err(Error::new(ErrorKind::BrokenPipe, "shutdown")) } data_copy_result = copy_data(&mut ri, &mut wo, StreamDirection::Up, &conn_count) => { match data_copy_result { Err(e) => { failure!("[conn {}] TCP copy client -> server error: {}", &conn_count, e); Err(e) }, Ok(r) => { information!("[conn {}] TCP copy client -> server success: {} byte(s)", &conn_count, r); Ok(r) }, } } }; information!("[conn {}] Close client to server connection", &conn_count); match time::timeout(shutdown_tcp_timeout, wo.shutdown()).await { Err(e) => warning!("[conn {}] TCP close client -> server timeout: {}", &conn_count, e), Ok(Err(e)) => warning!("[conn {}] TCP close client -> server error: {}", &conn_count, e), _ => {} } time::sleep(Duration::from_secs(2)).await; let _ = server_to_client_tx.send(true); r }; let server_to_client = async { let r = select! { _timeout = time::sleep(tcp_io_copy_timeout) => { failure!("[conn {}] TCP copy server -> client timeout", &conn_count); Err(Error::new(ErrorKind::TimedOut, "timeout")) } _tcp_break = server_to_client_rx => { failure!("[conn {}] TCP copy server -> client shutdown", &conn_count); Err(Error::new(ErrorKind::BrokenPipe, "shutdown")) } data_copy_result = copy_data(&mut ro, &mut wi, StreamDirection::Down, &conn_count) => { match data_copy_result { Err(e) => { failure!("[conn {}] TCP copy server -> client error: {}", &conn_count, e); Err(e) }, Ok(r) => { information!("[conn {}] TCP copy server -> client success: {} byte(s)", &conn_count, r); Ok(r) }, } } }; information!("[conn {}] Close server to client connection", &conn_count); match time::timeout(shutdown_tcp_timeout, wi.shutdown()).await { Err(e) => warning!("[conn {}] TCP close server -> client timeout: {}", &conn_count, e), Ok(Err(e)) => warning!("[conn {}] TCP close server -> client error: {}", &conn_count, e), _ => {} } time::sleep(Duration::from_secs(2)).await; let _ = client_to_server_tx.send(true); r }; let r = match try_join(client_to_server, server_to_client).await { Err(e) => Err(format!("[conn {}] Failed try_join: {}", &conn_count, e)), Ok((upstream_bytes, downstream_bytes)) => { information!("[conn {}] Finished, proxy-in: {} bytes, proxy-out: {} bytes", &conn_count, upstream_bytes, downstream_bytes ); Ok(()) } }; r } // fn is_in_peer_addr_matches(inbound: &TcpStream, allow_ips: &[IpAddressMask]) -> bool { // if let Ok(ref in_peer_addr) = inbound.peer_addr() { // if let Some(ip_filter) = &*GLOBAL_TEMP_IP_FILTER.read().unwrap() { // if let Some(ip) = ip_filter.is_ip_address_matches(in_peer_addr) { // return true; // } // } // if !allow_ips.is_empty() { // // ONLY if allow ips is not config, returns true // // If default deny IPs, should config like: "allow_ips": ["127.0.0.1"] // return allow_ips.iter().any(|ip| ip.is_matches(in_peer_addr)); // } // } // true // } async fn copy_data<'a, R, W>( reader: &'a mut R, writer: &'a mut W, direction: StreamDirection, conn_count: &str) -> tokio::io::Result where R: AsyncRead + Unpin + ?Sized, W: AsyncWrite + Unpin + ?Sized { let mut total_copied_bytes = 0_u64; let mut buff = vec![0; 1024 * 4]; loop { match reader.read(&mut buff).await { Ok(0) => return Ok(total_copied_bytes), Ok(n) => { util_msg::when(MessageType::DEBUG, || { debugging!("[conn {}] Direction: {:?}, {} bytes: {:02x?}", conn_count, direction, n, &buff[..n]); debugging!("[conn {}] Direction: {:?}, {} string: {}", conn_count, direction, n, String::from_utf8_lossy(&buff[..n]).to_string()); }); match writer.write_all(&buff[..n]).await { Ok(_) => { total_copied_bytes += n as u64; } Err(e) => return Err(e), } } Err(e) => return Err(e), } } }