diff --git a/Cargo.toml b/Cargo.toml index 2877e18..279276e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ tokio = { version = "1.17.0", features = ["io-util", "net", "time"] } anyhow = "1.0" thiserror = "1.0" tokio-stream = "0.1.8" +futures = "0.3" # Dependencies for examples and tests [dev-dependencies] diff --git a/src/server.rs b/src/server.rs index e8e3887..1d9590a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -19,8 +19,9 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::UdpSocket; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs as AsyncToSocketAddrs}; use tokio::time::timeout; -use tokio::try_join; +use tokio::{time, try_join}; use tokio_stream::Stream; +use futures::future::try_join; #[derive(Clone)] pub struct Config { @@ -140,7 +141,7 @@ impl Socks5Server { /// `Incoming` implements [`futures::stream::Stream`]. pub struct Incoming<'a>( &'a Socks5Server, - Option> + Send + Sync + 'a>>>, + Option> + Send + Sync + 'a>>>, ); /// Iterator for each incoming stream connection @@ -559,16 +560,16 @@ impl Socks5Socket { Err(e) => match e.kind() { // Match other TCP errors with ReplyError io::ErrorKind::ConnectionRefused => { - return Err(ReplyError::ConnectionRefused.into()) + return Err(ReplyError::ConnectionRefused.into()); } io::ErrorKind::ConnectionAborted => { - return Err(ReplyError::ConnectionNotAllowed.into()) + return Err(ReplyError::ConnectionNotAllowed.into()); } io::ErrorKind::ConnectionReset => { - return Err(ReplyError::ConnectionNotAllowed.into()) + return Err(ReplyError::ConnectionNotAllowed.into()); } io::ErrorKind::NotConnected => { - return Err(ReplyError::NetworkUnreachable.into()) + return Err(ReplyError::NetworkUnreachable.into()); } _ => return Err(e.into()), // #[error("General failure")] ? }, @@ -640,17 +641,66 @@ impl Socks5Socket { /// Copy data between two peers /// Using 2 different generators, because they could be different structs with same traits. -async fn transfer(mut inbound: I, mut outbound: O) -> Result<()> -where - I: AsyncRead + AsyncWrite + Unpin, - O: AsyncRead + AsyncWrite + Unpin, +async fn transfer(inbound: I, outbound: O) -> Result<()> + where + I: AsyncRead + AsyncWrite + Unpin, + O: AsyncRead + AsyncWrite + Unpin, { - match tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await { - Ok(res) => info!("transfer closed ({}, {})", res.0, res.1), - Err(err) => error!("transfer error: {:?}", err), - }; + // match tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await { + // Ok(res) => info!("transfer closed ({}, {})", res.0, res.1), + // Err(err) => error!("transfer error: {:?}", err), + // }; + // Ok(()) - Ok(()) + // if let (Ok(ref in_peer_addr), Ok(ref in_local_addr), Ok(ref out_local_addr), Ok(ref out_peer_addr)) + // = (inbound.peer_addr(), inbound.local_addr(), outbound.local_addr(), outbound.peer_addr()) { + // information!("[conn {}] New tcp connection: {} -> [{} * {}] -> {}", + // &conn_count, in_peer_addr, in_local_addr, out_local_addr, out_peer_addr); + // } + + let (mut ri, mut wi) = tokio::io::split(inbound); + let (mut ro, mut wo) = tokio::io::split(outbound); + // IO copy timeout 12 HOURS + let tcp_io_copy_timeout = Duration::from_secs(12 * 3600); + let client_to_server = async { + // let copy_result = time::timeout(tcp_io_copy_timeout, tokio::io::copy(&mut ri, &mut wo)); + let copy_result = time::timeout( + tcp_io_copy_timeout, + copy_data(&mut ri, &mut wo), + ); + let r = copy_result.await; + wo.shutdown().await.ok(); + match r { + Err(e) => { + error!("TCP copy timeout: {}", e); + Err(e.into()) + } + Ok(r) => r, + } + }; + let server_to_client = async { + // let copy_result = time::timeout(tcp_io_copy_timeout, tokio::io::copy(&mut ro, &mut wi)); + let copy_result = time::timeout( + tcp_io_copy_timeout, + copy_data(&mut ro, &mut wi), + ); + let r = copy_result.await; + wi.shutdown().await.ok(); + match r { + Err(e) => { + error!("TCP copy timeout: {}", e); + Err(e.into()) + } + Ok(r) => r, + } + }; + match try_join(client_to_server, server_to_client).await { + Err(e) => Err(SocksError::Io(e)), + Ok((upstream_bytes, downstream_bytes)) => { + info!("Finished, proxy-in: {} bytes, proxy-out: {} bytes", upstream_bytes, downstream_bytes); + Ok(()) + } + } } async fn handle_udp_request(inbound: &UdpSocket, outbound: &UdpSocket) -> Result<()> { @@ -708,8 +758,8 @@ async fn transfer_udp(inbound: UdpSocket) -> Result<()> { /// Allow us to read directly from the struct impl AsyncRead for Socks5Socket -where - T: AsyncRead + AsyncWrite + Unpin, + where + T: AsyncRead + AsyncWrite + Unpin, { fn poll_read( mut self: Pin<&mut Self>, @@ -722,8 +772,8 @@ where /// Allow us to write directly into the struct impl AsyncWrite for Socks5Socket -where - T: AsyncRead + AsyncWrite + Unpin, + where + T: AsyncRead + AsyncWrite + Unpin, { fn poll_write( mut self: Pin<&mut Self>, @@ -775,6 +825,25 @@ fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec { reply } +async fn copy_data<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> tokio::io::Result + where R: AsyncRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized { + let mut total_copied_bytes = 0_u64; + let mut buff = vec![0; 1024 * 4]; + loop { + match reader.read(&mut buff).await { + Ok(0) => return Ok(total_copied_bytes), + Ok(n) => match writer.write_all(&buff[..n]).await { + Ok(_) => { + total_copied_bytes += n as u64; + } + Err(e) => return Err(e), + } + Err(e) => return Err(e), + } + } +} + #[cfg(test)] mod test { use crate::server::Socks5Server;