From b1a8533f1fe4ea61fa886676f5346d4d524026a6 Mon Sep 17 00:00:00 2001 From: Hatter Jiang Date: Sun, 19 Jun 2022 11:14:33 +0800 Subject: [PATCH] feat: tcp data copy --- Cargo.toml | 2 +- src/server.rs | 86 +++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 64 insertions(+), 24 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 279276e..539930b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ socks4 = [] [dependencies] log = "0.4" -tokio = { version = "1.17.0", features = ["io-util", "net", "time"] } +tokio = { version = "1.17.0", features = ["io-util", "net", "time", "macros"] } anyhow = "1.0" thiserror = "1.0" tokio-stream = "0.1.8" diff --git a/src/server.rs b/src/server.rs index 55cec0b..344cb74 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,6 +8,7 @@ use crate::{consts, AuthenticationMethod, ReplyError, Result, SocksError}; use anyhow::Context; use std::future::Future; use std::io; +use std::io::{Error, ErrorKind}; use std::net::IpAddr; use std::net::Ipv4Addr; use std::net::{SocketAddr, ToSocketAddrs as StdToSocketAddrs}; @@ -19,7 +20,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::UdpSocket; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs as AsyncToSocketAddrs}; use tokio::time::timeout; -use tokio::{time, try_join}; +use tokio::{select, time, try_join}; use tokio_stream::Stream; use futures::future::try_join; @@ -663,37 +664,76 @@ async fn transfer(inbound: I, outbound: O) -> Result<()> // IO copy timeout 6 HOURS let tcp_io_copy_timeout = Duration::from_secs(6 * 3600); let shutdown_tcp_timeout = Duration::from_secs(60); + let (client_to_server_tx, client_to_server_rx) = tokio::sync::oneshot::channel::(); + let (server_to_client_tx, server_to_client_rx) = tokio::sync::oneshot::channel::(); + let client_to_server = async move { // let copy_result = time::timeout(tcp_io_copy_timeout, tokio::io::copy(&mut ri, &mut wo)); - let copy_result = time::timeout( - tcp_io_copy_timeout, - copy_data(&mut ri, &mut wo), - ); - let r = copy_result.await; - time::timeout(shutdown_tcp_timeout, wo.shutdown()).await.ok(); - match r { - Err(e) => { - error!("TCP copy timeout: {}", e); - Err(e.into()) + let r = select! { + _timeout = time::sleep(tcp_io_copy_timeout) => { + error!("TCP copy client -> server timeout"); + Err(Error::new(ErrorKind::TimedOut, "timeout")) } - Ok(r) => r, + _tcp_break = client_to_server_rx => { + error!("TCP copy client -> server shutdown"); + Err(Error::new(ErrorKind::BrokenPipe, "shutdown")) + } + data_copy_result = copy_data(&mut ri, &mut wo) => { + match data_copy_result { + Err(e) => { + error!("TCP copy client -> server error: {}", e); + Err(e) + }, + Ok(r) => { + info!("TCP copy client -> server success: {} byte(s)", r); + Ok(r) + }, + } + } + }; + info!("Close client to server connection"); + match time::timeout(shutdown_tcp_timeout, wo.shutdown()).await { + Err(e) => warn!("TCP close client -> server timeout: {}", e), + Ok(Err(e)) => warn!("TCP close client -> server error: {}", e), + _ => {}, } + time::sleep(Duration::from_secs(2)).await; + let _ = server_to_client_tx.send(true); + r }; let server_to_client = async move { // let copy_result = time::timeout(tcp_io_copy_timeout, tokio::io::copy(&mut ro, &mut wi)); - let copy_result = time::timeout( - tcp_io_copy_timeout, - copy_data(&mut ro, &mut wi), - ); - let r = copy_result.await; - time::timeout(shutdown_tcp_timeout, wi.shutdown()).await.ok(); - match r { - Err(e) => { - error!("TCP copy timeout: {}", e); - Err(e.into()) + let r = select! { + _timeout = time::sleep(tcp_io_copy_timeout) => { + error!("TCP copy server -> client timeout"); + Err(Error::new(ErrorKind::TimedOut, "timeout")) } - Ok(r) => r, + _tcp_break = server_to_client_rx => { + error!("TCP copy server -> client shutdown"); + Err(Error::new(ErrorKind::BrokenPipe, "shutdown")) + } + data_copy_result = copy_data(&mut ro, &mut wi) => { + match data_copy_result { + Err(e) => { + error!("TCP copy server -> client error: {}", e); + Err(e) + }, + Ok(r) => { + info!("TCP copy server -> client success: {} byte(s)", r); + Ok(r) + }, + } + } + }; + info!("Close server to client connection"); + match time::timeout(shutdown_tcp_timeout, wi.shutdown()).await { + Err(e) => warn!("TCP close server -> client timeout: {}", e), + Ok(Err(e)) => warn!("TCP close server -> client error: {}", e), + _ => {}, } + time::sleep(Duration::from_secs(2)).await; + let _ = client_to_server_tx.send(true); + r }; match try_join(client_to_server, server_to_client).await { Err(e) => Err(SocksError::Io(e)),