init commit from github.com/dizda/fast-socks5

This commit is contained in:
2022-06-08 00:40:54 +08:00
parent c7c77aaec0
commit c6b81bd13b
14 changed files with 3119 additions and 6 deletions

650
src/client.rs Normal file
View File

@@ -0,0 +1,650 @@
#[forbid(unsafe_code)]
use crate::read_exact;
use crate::util::target_addr::{read_address, TargetAddr, ToTargetAddr};
use crate::{
consts, new_udp_header, parse_udp_request, AuthenticationMethod, ReplyError, Result,
Socks5Command, SocksError,
};
use anyhow::Context;
use std::io;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
use std::task::Poll;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
const MAX_ADDR_LEN: usize = 260;
#[derive(Debug)]
pub struct Config {
/// Avoid useless roundtrips if we don't need the Authentication layer
/// make sure to also activate it on the server side.
skip_auth: bool,
}
impl Default for Config {
fn default() -> Self {
Config { skip_auth: false }
}
}
impl Config {
pub fn set_skip_auth(&mut self, value: bool) -> &mut Self {
self.skip_auth = value;
self
}
}
/// A SOCKS5 client.
/// `Socks5Stream` implements [`AsyncRead`] and [`AsyncWrite`].
#[derive(Debug)]
pub struct Socks5Stream<S: AsyncRead + AsyncWrite + Unpin> {
socket: S,
target_addr: Option<TargetAddr>,
config: Config,
}
impl<S> Socks5Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
/// Possibility to use a stream already created rather than
/// creating a whole new `TcpStream::connect()`.
pub async fn use_stream(
socket: S,
auth: Option<AuthenticationMethod>,
config: Config,
) -> Result<Self> {
let mut stream = Socks5Stream {
socket,
config,
target_addr: None,
};
// Auth none is always used by default.
let mut methods = vec![AuthenticationMethod::None];
if let Some(method) = auth {
// add any other method if supplied
methods.push(method);
}
// Handshake Lifecycle
if !stream.config.skip_auth {
let methods = stream.send_version_and_methods(methods).await?;
stream.which_method_accepted(methods).await?;
} else {
debug!("skipping auth");
}
Ok(stream)
}
pub async fn request(
&mut self,
cmd: Socks5Command,
target_addr: TargetAddr,
) -> Result<TargetAddr> {
self.target_addr = Some(target_addr);
// Request Lifecycle
info!("Requesting headers `{:?}`...", &self.target_addr);
self.request_header(cmd).await?;
let bind_addr = self.read_request_reply().await?;
Ok(bind_addr)
}
/// Decide to whether or not, accept the authentication method
/// A client send a list of methods that he supports, he could send
///
/// - 0: Non auth
/// - 2: Auth with username/password
///
/// Altogether, then the server choose to use of of these,
/// or deny the handshake (thus the connection).
///
/// # Examples
/// ```text
/// {SOCKS Version, methods-length}
/// eg. (non-auth) {5, 2}
/// eg. (auth) {5, 3}
/// ```
///
async fn send_version_and_methods(
&mut self,
methods: Vec<AuthenticationMethod>,
) -> Result<Vec<AuthenticationMethod>> {
debug!(
"Send version and method len [{}, {}]",
consts::SOCKS5_VERSION,
methods.len()
);
// write the first 2 bytes which contains the SOCKS version and the methods len()
self.socket
.write(&[consts::SOCKS5_VERSION, methods.len() as u8])
.await
.context("Couldn't write SOCKS version & methods len")?;
let auth = methods.iter().map(|l| l.as_u8()).collect::<Vec<_>>();
debug!("client auth methods supported: {:?}", &auth);
self.socket
.write(&auth)
.await
.context("Couldn't write supported auth methods")?;
// Return methods available
Ok(methods)
}
/// Decide to whether or not, accept the authentication method.
/// Don't forget that the methods list sent by the client, contains one or more methods.
///
/// # Request
///
/// Client send an array of 3 entries: [0, 1, 2]
/// ```text
/// {SOCKS Version, Authentication chosen}
/// eg. (non-auth) {5, 0}
/// eg. (GSSAPI) {5, 1}
/// eg. (auth) {5, 2}
/// ```
///
/// # Response
/// ```text
/// eg. (accept non-auth) {5, 0x00}
/// eg. (non-acceptable) {5, 0xff}
/// ```
///
async fn which_method_accepted(&mut self, methods: Vec<AuthenticationMethod>) -> Result<()> {
let [version, method] =
read_exact!(self.socket, [0u8; 2]).context("Can't get chosen auth method")?;
debug!(
"Socks version ({version}), method chosen: {method}.",
version = version,
method = method,
);
if version != consts::SOCKS5_VERSION {
return Err(SocksError::UnsupportedSocksVersion(version));
}
match method {
consts::SOCKS5_AUTH_METHOD_NONE => info!("No auth will be used"),
consts::SOCKS5_AUTH_METHOD_PASSWORD => self.use_password_auth(methods).await?,
_ => {
debug!("Don't support this auth method, reply with (0xff)");
self.socket
.write(&[
consts::SOCKS5_VERSION,
consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE,
])
.await
.context("Can't write that the methods are unsupported.")?;
return Err(SocksError::AuthMethodUnacceptable(vec![method]));
}
}
Ok(())
}
async fn use_password_auth(&mut self, methods: Vec<AuthenticationMethod>) -> Result<()> {
info!("Password will be used");
let (username, password) = match methods[1] {
AuthenticationMethod::None => unreachable!(),
AuthenticationMethod::Password {
ref username,
ref password,
} => (username, password),
};
let user_bytes = username.as_bytes();
let pass_bytes = password.as_bytes();
// send username len
self.socket
.write(&[1, user_bytes.len() as u8])
.await
.context("Can't send username len")?;
self.socket
.write(user_bytes)
.await
.context("Can't send username")?;
// send password len
self.socket
.write(&[pass_bytes.len() as u8])
.await
.context("Can't send password len")?;
self.socket
.write(pass_bytes)
.await
.context("Can't send password")?;
// Check the server reply, if whether it approved the auth or not
let [version, is_success] =
read_exact!(self.socket, [0u8; 2]).context("Can't read is_success")?;
debug!(
"Auth: [version: {version}, is_success: {is_success}]",
version = version,
is_success = is_success,
);
if is_success != consts::SOCKS5_REPLY_SUCCEEDED {
return Err(SocksError::AuthenticationRejected(format!(
"Authentication with username `{}`, rejected.",
username
)));
}
Ok(())
}
/// Decide to whether or not, accept the authentication method.
/// Don't forget that the methods list sent by the client, contains one or more methods.
///
/// # Request
/// ```test
/// +----+-----+-------+------+----------+----------+
/// |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
/// +----+-----+-------+------+----------+----------+
/// | 1 | 1 | 1 | 1 | Variable | 2 |
/// +----+-----+-------+------+----------+----------+
/// ```
///
/// # Help
///
/// To debug request use a netcat server with hexadecimal output to parse the hidden bytes:
///
/// ```bash
/// $ nc -k -l 80 | hexdump -C
/// ```
///
async fn request_header(&mut self, cmd: Socks5Command) -> Result<()> {
let mut packet = [0u8; MAX_ADDR_LEN + 3];
let padding; // maximum len of the headers sent
// build our request packet with (socks version, Command, reserved)
packet[..3].copy_from_slice(&[consts::SOCKS5_VERSION, cmd.as_u8(), 0x00]);
match self.target_addr.as_ref() {
None => {
if cmd == Socks5Command::UDPAssociate {
debug!("UDPAssociate without target_addr, fallback to zeros.");
padding = 10;
packet[3] = 0x01;
packet[4..8].copy_from_slice(&[0, 0, 0, 0]); // ip
packet[8..padding].copy_from_slice(&[0, 0]); // port
} else {
return Err(anyhow::Error::msg("target addr should be present").into());
}
}
Some(target_addr) => match target_addr {
TargetAddr::Ip(SocketAddr::V4(addr)) => {
debug!("TargetAddr::IpV4");
padding = 10;
packet[3] = 0x01;
debug!("addr ip {:?}", (*addr.ip()).octets());
packet[4..8].copy_from_slice(&(addr.ip()).octets()); // ip
packet[8..padding].copy_from_slice(&addr.port().to_be_bytes());
// port
}
TargetAddr::Ip(SocketAddr::V6(addr)) => {
debug!("TargetAddr::IpV6");
padding = 22;
packet[3] = 0x04;
debug!("addr ip {:?}", (*addr.ip()).octets());
packet[4..20].copy_from_slice(&(addr.ip()).octets()); // ip
packet[20..padding].copy_from_slice(&addr.port().to_be_bytes());
// port
}
TargetAddr::Domain(ref domain, port) => {
debug!("TargetAddr::Domain");
if domain.len() > u8::MAX as usize {
return Err(SocksError::ExceededMaxDomainLen(domain.len()));
}
padding = 5 + domain.len() + 2;
packet[3] = 0x03; // Specify domain type
packet[4] = domain.len() as u8; // domain length
packet[5..(5 + domain.len())].copy_from_slice(domain.as_bytes()); // domain content
packet[(5 + domain.len())..padding].copy_from_slice(&port.to_be_bytes());
// port content (.to_be_bytes() convert from u16 to u8 type)
}
},
}
debug!("Bytes long version: {:?}", &packet[..]);
debug!("Bytes shorted version: {:?}", &packet[..padding]);
debug!("Padding: {}", &padding);
// we limit the end of the packet right after the domain + port number, we don't need to print
// useless 0 bytes, otherwise other protocol won't understand the request (like HTTP servers).
self.socket
.write(&packet[..padding])
.await
.context("Can't write request header's packet.")?;
self.socket
.flush()
.await
.context("Can't flush request header's packet")?;
Ok(())
}
/// The server send a confirmation (reply) that he had successfully connected (or not) to the
/// remote server.
async fn read_request_reply(&mut self) -> Result<TargetAddr> {
let [version, reply, rsv, address_type] =
read_exact!(self.socket, [0u8; 4]).context("Received malformed reply")?;
debug!(
"Reply received: [version: {version}, reply: {reply}, rsv: {rsv}, address_type: {address_type}]",
version = version,
reply = reply,
rsv = rsv,
address_type = address_type,
);
if version != consts::SOCKS5_VERSION {
return Err(SocksError::UnsupportedSocksVersion(version));
}
if reply != consts::SOCKS5_REPLY_SUCCEEDED {
return Err(ReplyError::from_u8(reply).into()); // Convert reply received into correct error
}
let address = read_address(&mut self.socket, address_type).await?;
info!("Remote server bind on {}.", address);
Ok(address)
}
pub fn get_socket(self) -> S {
self.socket
}
pub fn get_socket_ref(&self) -> &S {
&self.socket
}
pub fn get_socket_mut(&mut self) -> &mut S {
&mut self.socket
}
}
/// A SOCKS5 UDP client.
#[derive(Debug)]
pub struct Socks5Datagram<S: AsyncRead + AsyncWrite + Unpin> {
socket: UdpSocket,
// keeps the session alive
#[allow(dead_code)]
stream: Socks5Stream<S>,
proxy_addr: Option<TargetAddr>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> Socks5Datagram<S> {
/// Creates a UDP socket bound to the specified address which will have its
/// traffic routed through the specified proxy.
///
/// # Arguments
/// * `backing_socket` - The underlying socket carrying the socks5 traffic.
/// * `client_bind_addr` - A socket address indicates the binding source address used to
/// communicate with the socks5 server.
///
/// # Examples
/// ```ignore
/// let backing_socket = TcpStream::connect("127.0.0.1:1080").await.unwrap();
/// let tunnel = client::Socks5Datagram::bind(backing_socket, "[::]:0")
/// .await
/// .unwrap();
/// ```
pub async fn bind<U>(backing_socket: S, client_bind_addr: U) -> Result<Socks5Datagram<S>>
where
U: ToSocketAddrs,
{
Self::bind_internal(backing_socket, client_bind_addr, None).await
}
/// Creates a UDP socket bound to the specified address which will have its
/// traffic routed through the specified proxy. The given username and password
/// is used to authenticate to the SOCKS proxy.
pub async fn bind_with_password<U>(
backing_socket: S,
client_bind_addr: U,
username: &str,
password: &str,
) -> Result<Socks5Datagram<S>>
where
U: ToSocketAddrs,
{
let auth = AuthenticationMethod::Password {
username: username.to_owned(),
password: password.to_owned(),
};
Self::bind_internal(backing_socket, client_bind_addr, Some(auth)).await
}
async fn bind_internal<U>(
backing_socket: S,
client_bind_addr: U,
auth: Option<AuthenticationMethod>,
) -> Result<Socks5Datagram<S>>
where
U: ToSocketAddrs,
{
let client_bind_addr = client_bind_addr
.to_socket_addrs()?
.next()
.context("unreachable")?;
let out_sock = UdpSocket::bind(client_bind_addr).await?;
info!("UdpSocket client socket bind to {}", client_bind_addr);
// Init socks5 stream.
let mut proxy_stream =
Socks5Stream::use_stream(backing_socket, auth, Config::default()).await?;
// we don't know what our IP is from the perspective of the proxy, so
// don't try to pass `addr` in here.
let client_src = TargetAddr::Ip("[::]:0".parse().unwrap());
let proxy_addr = proxy_stream
.request(Socks5Command::UDPAssociate, client_src)
.await?;
let proxy_addr_resolved = proxy_addr
.to_socket_addrs()?
.next()
.context("unreachable")?;
out_sock.connect(proxy_addr_resolved).await?;
info!("UdpSocket client connected to {}", proxy_addr_resolved);
Ok(Socks5Datagram {
socket: out_sock,
stream: proxy_stream,
proxy_addr: Some(proxy_addr),
})
}
/// Like `UdpSocket::send_to`.
///
/// # Note
///
/// The SOCKS protocol inserts a header at the beginning of the message. The
/// header will be 10 bytes for an IPv4 address, 22 bytes for an IPv6
/// address, and 7 bytes plus the length of the domain for a domain address.
pub async fn send_to<A>(&self, data: &[u8], addr: A) -> Result<usize>
where
A: ToTargetAddr,
{
let mut buf = new_udp_header(addr)?;
buf.extend_from_slice(data);
return Ok(self.socket.send(&buf).await?);
}
/// Like `UdpSocket::recv_from`.
pub async fn recv_from(&self, data_store: &mut [u8]) -> Result<(usize, TargetAddr)> {
let mut buf = [0u8; 0x10000];
let (size, _) = self.socket.recv_from(&mut buf).await?;
let (frag, target_addr, data) = parse_udp_request(&mut buf[..size]).await?;
if frag != 0 {
return Err(SocksError::Other(anyhow::anyhow!(
"Unsupported frag value."
)));
}
data_store[..data.len()].copy_from_slice(data);
Ok((data.len(), target_addr))
}
/// Returns the address of the proxy-side UDP socket through which all
/// messages will be routed.
pub fn proxy_addr(&self) -> Result<&TargetAddr> {
Ok(self
.proxy_addr
.as_ref()
.context("proxy addr is not ready")?)
}
/// Returns a shared reference to the inner socket.
pub fn get_ref(&self) -> &UdpSocket {
&self.socket
}
/// Returns a mutable reference to the inner socket.
pub fn get_mut(&mut self) -> &mut UdpSocket {
&mut self.socket
}
}
/// Api if you want to use TcpStream to create a new connection to the SOCKS5 server.
impl Socks5Stream<TcpStream> {
/// Connects to a target server through a SOCKS5 proxy.
pub async fn connect<T>(
socks_server: T,
target_addr: String,
target_port: u16,
config: Config,
) -> Result<Self>
where
T: ToSocketAddrs,
{
Self::connect_raw(
Socks5Command::TCPConnect,
socks_server,
target_addr,
target_port,
None,
config,
)
.await
}
/// Connect with credentials
pub async fn connect_with_password<T>(
socks_server: T,
target_addr: String,
target_port: u16,
username: String,
password: String,
config: Config,
) -> Result<Self>
where
T: ToSocketAddrs,
{
let auth = AuthenticationMethod::Password { username, password };
Self::connect_raw(
Socks5Command::TCPConnect,
socks_server,
target_addr,
target_port,
Some(auth),
config,
)
.await
}
/// Process clients SOCKS requests
/// This is the entry point where a whole request is processed.
pub async fn connect_raw<T>(
cmd: Socks5Command,
socks_server: T,
target_addr: String,
target_port: u16,
auth: Option<AuthenticationMethod>,
config: Config,
) -> Result<Self>
where
T: ToSocketAddrs,
{
let socket = TcpStream::connect(
socks_server
.to_socket_addrs()?
.next()
.context("unreachable")?,
)
.await?;
info!("Connected @ {}", &socket.peer_addr()?);
// Specify the target, here domain name, dns will be resolved on the server side
let target_addr = (target_addr.as_str(), target_port)
.to_target_addr()
.context("Can't convert address to TargetAddr format")?;
// upgrade the TcpStream to Socks5Stream
let mut socks_stream = Self::use_stream(socket, auth, config).await?;
socks_stream.request(cmd, target_addr).await?;
Ok(socks_stream)
}
}
/// Allow us to read directly from the struct
impl<S> AsyncRead for Socks5Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.socket).poll_read(context, buf)
}
}
/// Allow us to write directly into the struct
impl<S> AsyncWrite for Socks5Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.socket).poll_write(context, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.socket).poll_flush(context)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.socket).poll_shutdown(context)
}
}

460
src/lib.rs Normal file
View File

@@ -0,0 +1,460 @@
#[forbid(unsafe_code)]
#[macro_use]
extern crate log;
pub mod client;
pub mod server;
pub mod util;
#[cfg(feature = "socks4")]
pub mod socks4;
use anyhow::Context;
use std::fmt;
use std::io;
use thiserror::Error;
use util::target_addr::read_address;
use util::target_addr::TargetAddr;
use util::target_addr::ToTargetAddr;
use tokio::io::AsyncReadExt;
#[rustfmt::skip]
pub mod consts {
pub const SOCKS5_VERSION: u8 = 0x05;
pub const SOCKS5_AUTH_METHOD_NONE: u8 = 0x00;
pub const SOCKS5_AUTH_METHOD_GSSAPI: u8 = 0x01;
pub const SOCKS5_AUTH_METHOD_PASSWORD: u8 = 0x02;
pub const SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE: u8 = 0xff;
pub const SOCKS5_CMD_TCP_CONNECT: u8 = 0x01;
pub const SOCKS5_CMD_TCP_BIND: u8 = 0x02;
pub const SOCKS5_CMD_UDP_ASSOCIATE: u8 = 0x03;
pub const SOCKS5_ADDR_TYPE_IPV4: u8 = 0x01;
pub const SOCKS5_ADDR_TYPE_DOMAIN_NAME: u8 = 0x03;
pub const SOCKS5_ADDR_TYPE_IPV6: u8 = 0x04;
pub const SOCKS5_REPLY_SUCCEEDED: u8 = 0x00;
pub const SOCKS5_REPLY_GENERAL_FAILURE: u8 = 0x01;
pub const SOCKS5_REPLY_CONNECTION_NOT_ALLOWED: u8 = 0x02;
pub const SOCKS5_REPLY_NETWORK_UNREACHABLE: u8 = 0x03;
pub const SOCKS5_REPLY_HOST_UNREACHABLE: u8 = 0x04;
pub const SOCKS5_REPLY_CONNECTION_REFUSED: u8 = 0x05;
pub const SOCKS5_REPLY_TTL_EXPIRED: u8 = 0x06;
pub const SOCKS5_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07;
pub const SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
}
#[derive(Debug, PartialEq)]
pub enum Socks5Command {
TCPConnect,
TCPBind,
UDPAssociate,
}
#[allow(dead_code)]
impl Socks5Command {
#[inline]
#[rustfmt::skip]
fn as_u8(&self) -> u8 {
match self {
Socks5Command::TCPConnect => consts::SOCKS5_CMD_TCP_CONNECT,
Socks5Command::TCPBind => consts::SOCKS5_CMD_TCP_BIND,
Socks5Command::UDPAssociate => consts::SOCKS5_CMD_UDP_ASSOCIATE,
}
}
#[inline]
#[rustfmt::skip]
fn from_u8(code: u8) -> Option<Socks5Command> {
match code {
consts::SOCKS5_CMD_TCP_CONNECT => Some(Socks5Command::TCPConnect),
consts::SOCKS5_CMD_TCP_BIND => Some(Socks5Command::TCPBind),
consts::SOCKS5_CMD_UDP_ASSOCIATE => Some(Socks5Command::UDPAssociate),
_ => None,
}
}
}
#[derive(Debug, PartialEq)]
pub enum AuthenticationMethod {
None,
Password { username: String, password: String },
}
impl AuthenticationMethod {
#[inline]
#[rustfmt::skip]
fn as_u8(&self) -> u8 {
match self {
AuthenticationMethod::None => consts::SOCKS5_AUTH_METHOD_NONE,
AuthenticationMethod::Password {..} =>
consts::SOCKS5_AUTH_METHOD_PASSWORD
}
}
#[inline]
#[rustfmt::skip]
fn from_u8(code: u8) -> Option<AuthenticationMethod> {
match code {
consts::SOCKS5_AUTH_METHOD_NONE => Some(AuthenticationMethod::None),
consts::SOCKS5_AUTH_METHOD_PASSWORD => Some(AuthenticationMethod::Password { username: "test".to_string(), password: "test".to_string()}),
_ => None,
}
}
}
impl fmt::Display for AuthenticationMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
AuthenticationMethod::None => f.write_str("AuthenticationMethod::None"),
AuthenticationMethod::Password { .. } => f.write_str("AuthenticationMethod::Password"),
}
}
}
//impl Vec<AuthenticationMethod> {
// pub fn as_bytes(&self) -> &[u8] {
// self.iter().map(|l| l.as_u8()).collect()
// }
//}
//
//impl From<&[AuthenticationMethod]> for &[u8] {
// fn from(_: Vec<AuthenticationMethod>) -> Self {
// &[0x00]
// }
//}
#[derive(Error, Debug)]
pub enum SocksError {
#[error("i/o error: {0}")]
Io(#[from] io::Error),
#[error("the data for key `{0}` is not available")]
Redaction(String),
#[error("invalid header (expected {expected:?}, found {found:?})")]
InvalidHeader { expected: String, found: String },
#[error("Auth method unacceptable `{0:?}`.")]
AuthMethodUnacceptable(Vec<u8>),
#[error("Unsupported SOCKS version `{0}`.")]
UnsupportedSocksVersion(u8),
#[error("Domain exceeded max sequence length")]
ExceededMaxDomainLen(usize),
#[error("Authentication failed `{0}`")]
AuthenticationFailed(String),
#[error("Authentication rejected `{0}`")]
AuthenticationRejected(String),
#[error("Error with reply: {0}.")]
ReplyError(#[from] ReplyError),
#[cfg(feature = "socks4")]
#[error("Error with reply: {0}.")]
ReplySocks4Error(#[from] socks4::ReplyError),
#[error("Argument input error: `{0}`.")]
ArgumentInputError(&'static str),
// #[error("Other: `{0}`.")]
#[error(transparent)]
Other(#[from] anyhow::Error),
}
pub type Result<T, E = SocksError> = core::result::Result<T, E>;
/// SOCKS5 reply code
#[derive(Error, Debug, Copy, Clone)]
pub enum ReplyError {
#[error("Succeeded")]
Succeeded,
#[error("General failure")]
GeneralFailure,
#[error("Connection not allowed by ruleset")]
ConnectionNotAllowed,
#[error("Network unreachable")]
NetworkUnreachable,
#[error("Host unreachable")]
HostUnreachable,
#[error("Connection refused")]
ConnectionRefused,
#[error("TTL expired")]
TtlExpired,
#[error("Command not supported")]
CommandNotSupported,
#[error("Address type not supported")]
AddressTypeNotSupported,
// OtherReply(u8),
}
impl ReplyError {
#[inline]
#[rustfmt::skip]
pub fn as_u8(self) -> u8 {
match self {
ReplyError::Succeeded => consts::SOCKS5_REPLY_SUCCEEDED,
ReplyError::GeneralFailure => consts::SOCKS5_REPLY_GENERAL_FAILURE,
ReplyError::ConnectionNotAllowed => consts::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED,
ReplyError::NetworkUnreachable => consts::SOCKS5_REPLY_NETWORK_UNREACHABLE,
ReplyError::HostUnreachable => consts::SOCKS5_REPLY_HOST_UNREACHABLE,
ReplyError::ConnectionRefused => consts::SOCKS5_REPLY_CONNECTION_REFUSED,
ReplyError::TtlExpired => consts::SOCKS5_REPLY_TTL_EXPIRED,
ReplyError::CommandNotSupported => consts::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED,
ReplyError::AddressTypeNotSupported => consts::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED,
// ReplyError::OtherReply(c) => c,
}
}
#[inline]
#[rustfmt::skip]
pub fn from_u8(code: u8) -> ReplyError {
match code {
consts::SOCKS5_REPLY_SUCCEEDED => ReplyError::Succeeded,
consts::SOCKS5_REPLY_GENERAL_FAILURE => ReplyError::GeneralFailure,
consts::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED => ReplyError::ConnectionNotAllowed,
consts::SOCKS5_REPLY_NETWORK_UNREACHABLE => ReplyError::NetworkUnreachable,
consts::SOCKS5_REPLY_HOST_UNREACHABLE => ReplyError::HostUnreachable,
consts::SOCKS5_REPLY_CONNECTION_REFUSED => ReplyError::ConnectionRefused,
consts::SOCKS5_REPLY_TTL_EXPIRED => ReplyError::TtlExpired,
consts::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED => ReplyError::CommandNotSupported,
consts::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => ReplyError::AddressTypeNotSupported,
// _ => ReplyError::OtherReply(code),
_ => unreachable!("ReplyError code unsupported."),
}
}
}
/// Generate UDP header
///
/// # UDP Request header structure.
/// ```text
/// +----+------+------+----------+----------+----------+
/// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
/// +----+------+------+----------+----------+----------+
/// | 2 | 1 | 1 | Variable | 2 | Variable |
/// +----+------+------+----------+----------+----------+
///
/// The fields in the UDP request header are:
///
/// o RSV Reserved X'0000'
/// o FRAG Current fragment number
/// o ATYP address type of following addresses:
/// o IP V4 address: X'01'
/// o DOMAINNAME: X'03'
/// o IP V6 address: X'04'
/// o DST.ADDR desired destination address
/// o DST.PORT desired destination port
/// o DATA user data
/// ```
pub fn new_udp_header<T: ToTargetAddr>(target_addr: T) -> Result<Vec<u8>> {
let mut header = vec![
0, 0, // RSV
0, // FRAG
];
header.append(&mut target_addr.to_target_addr()?.to_be_bytes()?);
Ok(header)
}
/// Parse data from UDP client on raw buffer, return (frag, target_addr, payload).
pub async fn parse_udp_request<'a>(mut req: &'a [u8]) -> Result<(u8, TargetAddr, &'a [u8])> {
let rsv = read_exact!(req, [0u8; 2]).context("Malformed request")?;
if !rsv.eq(&[0u8; 2]) {
return Err(ReplyError::GeneralFailure.into());
}
let [frag, atyp] = read_exact!(req, [0u8; 2]).context("Malformed request")?;
let target_addr = read_address(&mut req, atyp).await.map_err(|e| {
// print explicit error
error!("{:#}", e);
// then convert it to a reply
ReplyError::AddressTypeNotSupported
})?;
Ok((frag, target_addr, req))
}
#[cfg(test)]
mod test {
use anyhow::Result;
use tokio::{
net::{TcpListener, TcpStream, UdpSocket},
sync::oneshot::Sender,
};
use crate::{
client,
server::{self, SimpleUserPassword},
};
use std::{
net::{SocketAddr, ToSocketAddrs},
num::ParseIntError,
sync::Arc,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::oneshot;
use tokio_test::block_on;
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
async fn setup_socks_server(
proxy_addr: &str,
auth: Option<SimpleUserPassword>,
tx: Sender<SocketAddr>,
) -> Result<()> {
let mut config = server::Config::default();
config.set_udp_support(true);
match auth {
None => {}
Some(up) => {
config.set_authentication(up);
}
}
let config = Arc::new(config);
let listener = TcpListener::bind(proxy_addr).await?;
tx.send(listener.local_addr()?).unwrap();
loop {
let (stream, _) = listener.accept().await?;
let mut socks5_socket = server::Socks5Socket::new(stream, config.clone());
socks5_socket.set_reply_ip(proxy_addr.parse::<SocketAddr>().unwrap().ip());
socks5_socket.upgrade_to_socks5().await?;
}
}
async fn google(mut socket: TcpStream) -> Result<()> {
socket.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
let mut result = vec![];
socket.read_to_end(&mut result).await?;
println!("{}", String::from_utf8_lossy(&result));
assert!(result.starts_with(b"HTTP/1.0"));
assert!(result.ends_with(b"</HTML>\r\n") || result.ends_with(b"</html>"));
Ok(())
}
#[test]
fn google_no_auth() {
init();
block_on(async {
let (tx, rx) = oneshot::channel();
tokio::spawn(setup_socks_server("[::1]:0", None, tx));
let socket = client::Socks5Stream::connect(
rx.await.unwrap(),
"google.com".to_owned(),
80,
client::Config::default(),
)
.await
.unwrap();
google(socket.get_socket()).await.unwrap();
});
}
#[test]
fn mock_udp_assosiate_no_auth() {
init();
block_on(async {
const MOCK_ADDRESS: &str = "[::1]:40235";
let (tx, rx) = oneshot::channel();
tokio::spawn(setup_socks_server("[::1]:0", None, tx));
let backing_socket = TcpStream::connect(rx.await.unwrap()).await.unwrap();
// Creates a UDP tunnel which can be used to forward UDP packets, "[::]:0" indicates the
// binding source address used to communicate with the socks5 server.
let tunnel = client::Socks5Datagram::bind(backing_socket, "[::]:0")
.await
.unwrap();
let mock_udp_server = UdpSocket::bind(MOCK_ADDRESS).await.unwrap();
tunnel
.send_to(
b"hello world!",
MOCK_ADDRESS.to_socket_addrs().unwrap().next().unwrap(),
)
.await
.unwrap();
println!("Send packet to {}", MOCK_ADDRESS);
let mut buf = [0; 13];
let (len, addr) = mock_udp_server.recv_from(&mut buf).await.unwrap();
assert_eq!(len, 12);
assert_eq!(&buf[..12], b"hello world!");
mock_udp_server
.send_to(b"hello world!", addr)
.await
.unwrap();
println!("Recieve packet from {}", MOCK_ADDRESS);
let len = tunnel.recv_from(&mut buf).await.unwrap().0;
assert_eq!(len, 12);
assert_eq!(&buf[..12], b"hello world!");
});
}
#[test]
fn dns_udp_assosiate_no_auth() {
init();
block_on(async {
const DNS_SERVER: &str = "1.1.1.1:53";
let (tx, rx) = oneshot::channel();
tokio::spawn(setup_socks_server("[::1]:0", None, tx));
let backing_socket = TcpStream::connect(rx.await.unwrap()).await.unwrap();
// Creates a UDP tunnel which can be used to forward UDP packets, "[::]:0" indicates the
// binding source address used to communicate with the socks5 server.
let tunnel = client::Socks5Datagram::bind(backing_socket, "[::]:0")
.await
.unwrap();
#[rustfmt::skip]
tunnel.send_to(
&decode_hex(&(
"AAAA".to_owned() // ID
+ "0100" // Query parameters
+ "0001" // Number of questions
+ "0000" // Number of answers
+ "0000" // Number of authority records
+ "0000" // Number of additional records
+ "076578616d706c65"// Length + hex("example")
+ "03636f6d00" // Length + hex("com") + zero byte
+ "0001" // QTYPE
+ "0001" // QCLASS
))
.unwrap(),
DNS_SERVER.to_socket_addrs().unwrap().next().unwrap(),
).await.unwrap();
println!("Send packet to {}", DNS_SERVER);
let mut buf = [0; 128];
println!("Recieve packet from {}", DNS_SERVER);
tunnel.recv_from(&mut buf).await.unwrap();
println!("dns response {:?}", buf);
#[rustfmt::skip]
assert!(buf.starts_with(&decode_hex(&(
"AAAA".to_owned() // ID
+ "8180" // FLAGS: RCODE=0, No errors reported
+ "0001" // One question
)).unwrap()));
});
}
fn decode_hex(s: &str) -> Result<Vec<u8>, ParseIntError> {
(0..s.len())
.step_by(2)
.map(|i| u8::from_str_radix(&s[i..i + 2], 16))
.collect()
}
}

791
src/server.rs Normal file
View File

@@ -0,0 +1,791 @@
use crate::new_udp_header;
use crate::parse_udp_request;
use crate::read_exact;
use crate::ready;
use crate::util::target_addr::{read_address, TargetAddr};
use crate::Socks5Command;
use crate::{consts, AuthenticationMethod, ReplyError, Result, SocksError};
use anyhow::Context;
use std::future::Future;
use std::io;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::{SocketAddr, ToSocketAddrs as StdToSocketAddrs};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context as AsyncContext, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::UdpSocket;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs as AsyncToSocketAddrs};
use tokio::time::timeout;
use tokio::try_join;
use tokio_stream::Stream;
#[derive(Clone)]
pub struct Config {
/// Timeout of the command request
request_timeout: u64,
/// Avoid useless roundtrips if we don't need the Authentication layer
skip_auth: bool,
/// Enable dns-resolving
dns_resolve: bool,
/// Enable command execution
execute_command: bool,
/// Enable UDP support
allow_udp: bool,
auth: Option<Arc<dyn Authentication>>,
}
impl Default for Config {
fn default() -> Self {
Config {
request_timeout: 10,
skip_auth: false,
dns_resolve: true,
execute_command: true,
allow_udp: false,
auth: None,
}
}
}
/// Use this trait to handle a custom authentication on your end.
pub trait Authentication: Send + Sync {
fn authenticate(&self, username: &str, password: &str) -> bool;
}
/// Basic user/pass auth method provided.
pub struct SimpleUserPassword {
pub username: String,
pub password: String,
}
impl Authentication for SimpleUserPassword {
fn authenticate(&self, username: &str, password: &str) -> bool {
username == &self.username && password == &self.password
}
}
impl Config {
/// How much time it should wait until the request timeout.
pub fn set_request_timeout(&mut self, n: u64) -> &mut Self {
self.request_timeout = n;
self
}
/// Skip the entire auth/handshake part, which means the server will directly wait for
/// the command request.
pub fn set_skip_auth(&mut self, value: bool) -> &mut Self {
self.skip_auth = value;
self
}
/// Enable authentication
/// 'static lifetime for Authentication avoid us to use `dyn Authentication`
/// and set the Arc before calling the function.
pub fn set_authentication<T: Authentication + 'static>(
&mut self,
authentication: T,
) -> &mut Self {
self.auth = Some(Arc::new(authentication));
self
}
/// Set whether or not to execute commands
pub fn set_execute_command(&mut self, value: bool) -> &mut Self {
self.execute_command = value;
self
}
/// Will the server perform dns resolve
pub fn set_dns_resolve(&mut self, value: bool) -> &mut Self {
self.dns_resolve = value;
self
}
/// Set whether or not to allow udp traffic
pub fn set_udp_support(&mut self, value: bool) -> &mut Self {
self.allow_udp = value;
self
}
}
/// Wrapper of TcpListener
/// Useful if you don't use any existing TcpListener's streams.
pub struct Socks5Server {
listener: TcpListener,
config: Arc<Config>,
}
impl Socks5Server {
pub async fn bind<A: AsyncToSocketAddrs>(addr: A) -> io::Result<Socks5Server> {
let listener = TcpListener::bind(&addr).await?;
let config = Arc::new(Config::default());
Ok(Socks5Server { listener, config })
}
/// Set a custom config
pub fn set_config(&mut self, config: Config) {
self.config = Arc::new(config);
}
/// Can loop on `incoming().next()` to iterate over incoming connections.
pub fn incoming(&self) -> Incoming<'_> {
Incoming(self, None)
}
}
/// `Incoming` implements [`futures::stream::Stream`].
pub struct Incoming<'a>(
&'a Socks5Server,
Option<Pin<Box<dyn Future<Output = io::Result<(TcpStream, SocketAddr)>> + Send + Sync + 'a>>>,
);
/// Iterator for each incoming stream connection
/// this wrapper will convert async_std TcpStream into Socks5Socket.
impl<'a> Stream for Incoming<'a> {
type Item = Result<Socks5Socket<TcpStream>>;
/// this code is mainly borrowed from [`Incoming::poll_next()` of `TcpListener`][tcpListener]
/// [tcpListener]: https://docs.rs/async-std/1.8.0/async_std/net/struct.TcpListener.html#method.incoming
fn poll_next(mut self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll<Option<Self::Item>> {
loop {
if self.1.is_none() {
self.1 = Some(Box::pin(self.0.listener.accept()));
}
if let Some(f) = &mut self.1 {
// early returns if pending
let (socket, peer_addr) = ready!(f.as_mut().poll(cx))?;
self.1 = None;
let local_addr = socket.local_addr()?;
debug!(
"incoming connection from peer {} @ {}",
&peer_addr, &local_addr
);
// Wrap the TcpStream into Socks5Socket
let socket = Socks5Socket::new(socket, self.0.config.clone());
return Poll::Ready(Some(Ok(socket)));
}
}
}
}
/// Wrap TcpStream and contains Socks5 protocol implementation.
pub struct Socks5Socket<T: AsyncRead + AsyncWrite + Unpin> {
inner: T,
config: Arc<Config>,
auth: AuthenticationMethod,
target_addr: Option<TargetAddr>,
cmd: Option<Socks5Command>,
/// Socket address which will be used in the reply message.
reply_ip: Option<IpAddr>,
}
impl<T: AsyncRead + AsyncWrite + Unpin> Socks5Socket<T> {
pub fn new(socket: T, config: Arc<Config>) -> Self {
Socks5Socket {
inner: socket,
config,
auth: AuthenticationMethod::None,
target_addr: None,
cmd: None,
reply_ip: None,
}
}
/// Set the bind IP address in Socks5Reply.
///
/// Only the inner socket owner knows the correct reply bind addr, so leave this field to be
/// populated. For those strict clients, users can use this function to set the correct IP
/// address.
///
/// Most popular SOCKS5 clients [1] [2] ignore BND.ADDR and BND.PORT the reply of command
/// CONNECT, but this field could be useful in some other command, such as UDP ASSOCIATE.
///
/// [1]. https://github.com/chromium/chromium/blob/bd2c7a8b65ec42d806277dd30f138a673dec233a/net/socket/socks5_client_socket.cc#L481
/// [2]. https://github.com/curl/curl/blob/d15692ebbad5e9cfb871b0f7f51a73e43762cee2/lib/socks.c#L978
pub fn set_reply_ip(&mut self, addr: IpAddr) {
self.reply_ip = Some(addr);
}
/// Process clients SOCKS requests
/// This is the entry point where a whole request is processed.
pub async fn upgrade_to_socks5(mut self) -> Result<Socks5Socket<T>> {
trace!("upgrading to socks5...");
// Handshake
if !self.config.skip_auth {
let methods = self.get_methods().await?;
self.can_accept_method(methods).await?;
if self.config.auth.is_some() {
let credentials = self.authenticate().await?;
self.auth = AuthenticationMethod::Password {
username: credentials.0,
password: credentials.1,
};
}
} else {
debug!("skipping auth");
}
match self.request().await {
Ok(_) => {}
Err(SocksError::ReplyError(e)) => {
// If a reply error has been returned, we send it to the client
self.reply_error(&e).await?;
return Err(e.into()); // propagate the error to end this connection's task
}
// if any other errors has been detected, we simply end connection's task
Err(d) => return Err(d),
};
Ok(self)
}
/// Read the authentication method provided by the client.
/// A client send a list of methods that he supports, he could send
///
/// - 0: Non auth
/// - 2: Auth with username/password
///
/// Altogether, then the server choose to use of of these,
/// or deny the handshake (thus the connection).
///
/// # Examples
/// ```text
/// {SOCKS Version, methods-length}
/// eg. (non-auth) {5, 2}
/// eg. (auth) {5, 3}
/// ```
///
async fn get_methods(&mut self) -> Result<Vec<u8>> {
trace!("Socks5Socket: get_methods()");
// read the first 2 bytes which contains the SOCKS version and the methods len()
let [version, methods_len] =
read_exact!(self.inner, [0u8; 2]).context("Can't read methods")?;
debug!(
"Handshake headers: [version: {version}, methods len: {len}]",
version = version,
len = methods_len,
);
if version != consts::SOCKS5_VERSION {
return Err(SocksError::UnsupportedSocksVersion(version));
}
// {METHODS available from the client}
// eg. (non-auth) {0, 1}
// eg. (auth) {0, 1, 2}
let methods = read_exact!(self.inner, vec![0u8; methods_len as usize])
.context("Can't get methods.")?;
debug!("methods supported sent by the client: {:?}", &methods);
// Return methods available
Ok(methods)
}
/// Decide to whether or not, accept the authentication method.
/// Don't forget that the methods list sent by the client, contains one or more methods.
///
/// # Request
///
/// Client send an array of 3 entries: [0, 1, 2]
/// ```text
/// {SOCKS Version, Authentication chosen}
/// eg. (non-auth) {5, 0}
/// eg. (GSSAPI) {5, 1}
/// eg. (auth) {5, 2}
/// ```
///
/// # Response
/// ```text
/// eg. (accept non-auth) {5, 0x00}
/// eg. (non-acceptable) {5, 0xff}
/// ```
///
async fn can_accept_method(&mut self, client_methods: Vec<u8>) -> Result<()> {
let method_supported;
if self.config.auth.is_some() {
method_supported = consts::SOCKS5_AUTH_METHOD_PASSWORD;
} else {
method_supported = consts::SOCKS5_AUTH_METHOD_NONE;
}
if !client_methods.contains(&method_supported) {
debug!("Don't support this auth method, reply with (0xff)");
self.inner
.write(&[
consts::SOCKS5_VERSION,
consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE,
])
.await
.context("Can't reply with method not acceptable.")?;
return Err(SocksError::AuthMethodUnacceptable(client_methods));
}
debug!(
"Reply with method {} ({})",
AuthenticationMethod::from_u8(method_supported).context("Method not supported")?,
method_supported
);
self.inner
.write(&[consts::SOCKS5_VERSION, method_supported])
.await
.context("Can't reply with method auth-none")?;
Ok(())
}
/// Only called if
/// - the client supports authentication via username/password
/// - this server has `Authentication` trait implemented.
async fn authenticate(&mut self) -> Result<(String, String)> {
trace!("Socks5Socket: authenticate()");
let [version, user_len] =
read_exact!(self.inner, [0u8; 2]).context("Can't read user len")?;
debug!(
"Auth: [version: {version}, user len: {len}]",
version = version,
len = user_len,
);
if user_len < 1 {
return Err(SocksError::AuthenticationFailed(format!(
"Username malformed ({} chars)",
user_len
)));
}
let username =
read_exact!(self.inner, vec![0u8; user_len as usize]).context("Can't get username.")?;
debug!("username bytes: {:?}", &username);
let [pass_len] = read_exact!(self.inner, [0u8; 1]).context("Can't read pass len")?;
debug!("Auth: [pass len: {len}]", len = pass_len,);
if pass_len < 1 {
return Err(SocksError::AuthenticationFailed(format!(
"Password malformed ({} chars)",
pass_len
)));
}
let password =
read_exact!(self.inner, vec![0u8; pass_len as usize]).context("Can't get password.")?;
debug!("password bytes: {:?}", &password);
let username = String::from_utf8(username).context("Failed to convert username")?;
let password = String::from_utf8(password).context("Failed to convert password")?;
let auth = self.config.auth.as_ref().context("No auth module")?;
if auth.authenticate(&username, &password) {
self.inner
.write(&[1, consts::SOCKS5_REPLY_SUCCEEDED])
.await
.context("Can't reply auth success")?;
} else {
self.inner
.write(&[1, consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE])
.await
.context("Can't reply with auth method not acceptable.")?;
return Err(SocksError::AuthenticationRejected(format!(
"Authentication with username `{}`, rejected.",
username
)));
}
info!("User `{}` logged successfully.", username);
Ok((username, password))
}
/// Wrapper to principally cover ReplyError types for both functions read & execute request.
async fn request(&mut self) -> Result<()> {
self.read_command().await?;
if self.config.dns_resolve {
self.resolve_dns().await?;
} else {
debug!("Domain won't be resolved because `dns_resolve`'s config has been turned off.")
}
if self.config.execute_command {
self.execute_command().await?;
}
Ok(())
}
/// Reply error to the client with the reply code according to the RFC.
async fn reply_error(&mut self, error: &ReplyError) -> Result<()> {
let reply = new_reply(error, "0.0.0.0:0".parse().unwrap());
debug!("reply error to be written: {:?}", &reply);
self.inner
.write(&reply)
.await
.context("Can't write the reply!")?;
self.inner.flush().await.context("Can't flush the reply!")?;
Ok(())
}
/// Decide to whether or not, accept the authentication method.
/// Don't forget that the methods list sent by the client, contains one or more methods.
///
/// # Request
/// ```text
/// +----+-----+-------+------+----------+----------+
/// |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
/// +----+-----+-------+------+----------+----------+
/// | 1 | 1 | 1 | 1 | Variable | 2 |
/// +----+-----+-------+------+----------+----------+
/// ```
///
/// It the request is correct, it should returns a ['SocketAddr'].
///
async fn read_command(&mut self) -> Result<()> {
let [version, cmd, rsv, address_type] =
read_exact!(self.inner, [0u8; 4]).context("Malformed request")?;
debug!(
"Request: [version: {version}, command: {cmd}, rev: {rsv}, address_type: {address_type}]",
version = version,
cmd = cmd,
rsv = rsv,
address_type = address_type,
);
if version != consts::SOCKS5_VERSION {
return Err(SocksError::UnsupportedSocksVersion(version));
}
match Socks5Command::from_u8(cmd) {
None => return Err(ReplyError::CommandNotSupported.into()),
Some(cmd) => match cmd {
Socks5Command::TCPConnect => {
self.cmd = Some(cmd);
}
Socks5Command::UDPAssociate => {
if !self.config.allow_udp {
return Err(ReplyError::CommandNotSupported.into());
}
self.cmd = Some(cmd);
}
Socks5Command::TCPBind => return Err(ReplyError::CommandNotSupported.into()),
},
}
// Guess address type
let target_addr = read_address(&mut self.inner, address_type)
.await
.map_err(|e| {
// print explicit error
error!("{:#}", e);
// then convert it to a reply
ReplyError::AddressTypeNotSupported
})?;
self.target_addr = Some(target_addr);
debug!("Request target is {}", self.target_addr.as_ref().unwrap());
Ok(())
}
/// This function is public, it can be call manually on your own-willing
/// if config flag has been turned off: `Config::dns_resolve == false`.
pub async fn resolve_dns(&mut self) -> Result<()> {
trace!("resolving dns");
if let Some(target_addr) = self.target_addr.take() {
// decide whether we have to resolve DNS or not
self.target_addr = match target_addr {
TargetAddr::Domain(_, _) => Some(target_addr.resolve_dns().await?),
TargetAddr::Ip(_) => Some(target_addr),
};
}
Ok(())
}
/// Execute the socks5 command that the client wants.
async fn execute_command(&mut self) -> Result<()> {
match &self.cmd {
None => Err(ReplyError::CommandNotSupported.into()),
Some(cmd) => match cmd {
Socks5Command::TCPBind => Err(ReplyError::CommandNotSupported.into()),
Socks5Command::TCPConnect => return self.execute_command_connect().await,
Socks5Command::UDPAssociate => {
if self.config.allow_udp {
return self.execute_command_udp_assoc().await;
} else {
Err(ReplyError::CommandNotSupported.into())
}
}
},
}
}
/// Connect to the target address that the client wants,
/// then forward the data between them (client <=> target address).
async fn execute_command_connect(&mut self) -> Result<()> {
// async-std's ToSocketAddrs doesn't supports external trait implementation
// @see https://github.com/async-rs/async-std/issues/539
let addr = self
.target_addr
.as_ref()
.context("target_addr empty")?
.to_socket_addrs()?
.next()
.context("unreachable")?;
let fut = TcpStream::connect(addr);
let limit = Duration::from_secs(self.config.request_timeout);
// TCP connect with timeout, to avoid memory leak for connection that takes forever
let outbound = match timeout(limit, fut).await {
Ok(e) => match e {
Ok(o) => o,
Err(e) => match e.kind() {
// Match other TCP errors with ReplyError
io::ErrorKind::ConnectionRefused => {
return Err(ReplyError::ConnectionRefused.into())
}
io::ErrorKind::ConnectionAborted => {
return Err(ReplyError::ConnectionNotAllowed.into())
}
io::ErrorKind::ConnectionReset => {
return Err(ReplyError::ConnectionNotAllowed.into())
}
io::ErrorKind::NotConnected => {
return Err(ReplyError::NetworkUnreachable.into())
}
_ => return Err(e.into()), // #[error("General failure")] ?
},
},
// Wrap timeout error in a proper ReplyError
Err(_) => return Err(ReplyError::TtlExpired.into()),
};
debug!("Connected to remote destination");
self.inner
.write(&new_reply(
&ReplyError::Succeeded,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0),
))
.await
.context("Can't write successful reply")?;
self.inner.flush().await.context("Can't flush the reply!")?;
debug!("Wrote success");
transfer(&mut self.inner, outbound).await
}
/// Bind to a random UDP port, wait for the traffic from
/// the client, and then forward the data to the remote addr.
async fn execute_command_udp_assoc(&mut self) -> Result<()> {
// The DST.ADDR and DST.PORT fields contain the address and port that
// the client expects to use to send UDP datagrams on for the
// association. The server MAY use this information to limit access
// to the association.
// @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928.
//
// We do NOT limit the access from the client currently in this implementation.
let _not_used = self.target_addr.as_ref();
// Listen with UDP6 socket, so the client can connect to it with either
// IPv4 or IPv6.
let peer_sock = UdpSocket::bind("[::]:0").await?;
// Respect the pre-populated reply IP address.
self.inner
.write(&new_reply(
&ReplyError::Succeeded,
SocketAddr::new(
self.reply_ip.context("invalid reply ip")?,
peer_sock.local_addr()?.port(),
),
))
.await
.context("Can't write successful reply")?;
debug!("Wrote success");
transfer_udp(peer_sock).await?;
Ok(())
}
pub fn target_addr(&self) -> Option<&TargetAddr> {
self.target_addr.as_ref()
}
pub fn auth(&self) -> &AuthenticationMethod {
&self.auth
}
}
/// Copy data between two peers
/// Using 2 different generators, because they could be different structs with same traits.
async fn transfer<I, O>(mut inbound: I, mut outbound: O) -> Result<()>
where
I: AsyncRead + AsyncWrite + Unpin,
O: AsyncRead + AsyncWrite + Unpin,
{
match tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await {
Ok(res) => info!("transfer closed ({}, {})", res.0, res.1),
Err(err) => error!("transfer error: {:?}", err),
};
Ok(())
}
async fn handle_udp_request(inbound: &UdpSocket, outbound: &UdpSocket) -> Result<()> {
let mut buf = vec![0u8; 0x10000];
loop {
let (size, client_addr) = inbound.recv_from(&mut buf).await?;
debug!("Server recieve udp from {}", client_addr);
inbound.connect(client_addr).await?;
let (frag, target_addr, data) = parse_udp_request(&buf[..size]).await?;
if frag != 0 {
debug!("Discard UDP frag packets sliently.");
return Ok(());
}
debug!("Server forward to packet to {}", target_addr);
let mut target_addr = target_addr
.to_socket_addrs()?
.next()
.context("unreachable")?;
target_addr.set_ip(match target_addr.ip() {
std::net::IpAddr::V4(v4) => std::net::IpAddr::V6(v4.to_ipv6_mapped()),
v6 @ std::net::IpAddr::V6(_) => v6,
});
outbound.send_to(data, target_addr).await?;
}
}
async fn handle_udp_response(inbound: &UdpSocket, outbound: &UdpSocket) -> Result<()> {
let mut buf = vec![0u8; 0x10000];
loop {
let (size, remote_addr) = outbound.recv_from(&mut buf).await?;
debug!("Recieve packet from {}", remote_addr);
let mut data = new_udp_header(remote_addr)?;
data.extend_from_slice(&buf[..size]);
inbound.send(&data).await?;
}
}
async fn transfer_udp(inbound: UdpSocket) -> Result<()> {
let outbound = UdpSocket::bind("[::]:0").await?;
let req_fut = handle_udp_request(&inbound, &outbound);
let res_fut = handle_udp_response(&inbound, &outbound);
match try_join!(req_fut, res_fut) {
Ok(_) => {}
Err(error) => return Err(error),
}
Ok(())
}
/// Allow us to read directly from the struct
impl<T> AsyncRead for Socks5Socket<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_read(context, buf)
}
}
/// Allow us to write directly into the struct
impl<T> AsyncWrite for Socks5Socket<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(context, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(context)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(context)
}
}
/// Generate reply code according to the RFC.
fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
let (addr_type, mut ip_oct, mut port) = match sock_addr {
SocketAddr::V4(sock) => (
consts::SOCKS5_ADDR_TYPE_IPV4,
sock.ip().octets().to_vec(),
sock.port().to_be_bytes().to_vec(),
),
SocketAddr::V6(sock) => (
consts::SOCKS5_ADDR_TYPE_IPV6,
sock.ip().octets().to_vec(),
sock.port().to_be_bytes().to_vec(),
),
};
let mut reply = vec![
consts::SOCKS5_VERSION,
error.as_u8(), // transform the error into byte code
0x00, // reserved
addr_type, // address type (ipv4, v6, domain)
];
reply.append(&mut ip_oct);
reply.append(&mut port);
reply
}
#[cfg(test)]
mod test {
use crate::server::Socks5Server;
use tokio_test::block_on;
#[test]
fn test_bind() {
let f = async {
let _server = Socks5Server::bind("127.0.0.1:1080").await.unwrap();
};
block_on(f);
}
}

363
src/socks4/client.rs Normal file
View File

@@ -0,0 +1,363 @@
#[forbid(unsafe_code)]
use crate::read_exact;
use crate::socks4::{consts, ReplyError, Socks4Command};
use crate::util::target_addr::{TargetAddr, ToTargetAddr};
use crate::{Result, SocksError, SocksError::ReplySocks4Error};
use anyhow::Context;
use std::io;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
use std::task::Poll;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
const MAX_ADDR_LEN: usize = 260;
/// A SOCKS4 client.
/// `Socks4Stream` implements [`AsyncRead`] and [`AsyncWrite`].
#[derive(Debug)]
pub struct Socks4Stream<S: AsyncRead + AsyncWrite + Unpin> {
socket: S,
target_addr: Option<TargetAddr>,
}
impl<S> Socks4Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
/// Possibility to use a stream already created rather than
/// creating a whole new `TcpStream::connect()`.
pub fn use_stream(socket: S) -> Result<Self> {
let stream = Socks4Stream {
socket,
target_addr: None,
};
Ok(stream)
}
/// https://www.openssh.com/txt/socks4.protocol
/// https://www.openssh.com/txt/socks4a.protocol
///
/// 1) CONNECT
///
/// +----+----+----+----+----+----+----+----+----+----+....+----+
/// | VN | CD | DSTPORT | DSTIP | USERID |NULL|
/// +----+----+----+----+----+----+----+----+----+----+....+----+
/// #of bytes 1 1 2 4 variable 1
///
/// VN is the SOCKS protocol version number and should be 4. CD is the
/// SOCKS command code and should be 1 for CONNECT request. NULL is a byte
/// of all zero bits.
///
/// The SOCKS server checks to see whether such a request should be granted
/// based on any combination of source IP address, destination IP address,
/// destination port number, the userid, and information it may obtain by
/// consulting IDENT, cf. RFC 1413. If the request is granted, the SOCKS
/// server makes a connection to the specified port of the destination host.
/// A reply packet is sent to the client when this connection is established,
/// or when the request is rejected or the operation fails.
///
/// Response:
///
/// +----+----+----+----+----+----+----+----+
/// | VN | CD | DSTPORT | DSTIP |
/// +----+----+----+----+----+----+----+----+
/// #of bytes 1 1 2 4
///
/// VN is the version of the reply code and should be 0. CD is the result
/// code with one of the following values:
///
/// 90: request granted
/// 91: request rejected or failed
/// 92: request rejected because SOCKS server cannot connect to
/// identd on the client
/// 93: request rejected because the client program and identd
/// report different user-ids
///
pub async fn request(
&mut self,
cmd: Socks4Command,
target_addr: TargetAddr,
resolve_locally: bool,
) -> Result<()> {
let resolved = if target_addr.is_domain() && resolve_locally {
target_addr.resolve_dns().await?
} else {
target_addr
};
self.target_addr = Some(resolved);
self.send_command_request(&cmd).await?;
self.read_command_request().await?;
Ok(())
}
async fn send_command_request(&mut self, cmd: &Socks4Command) -> Result<()> {
let mut packet = [0u8; MAX_ADDR_LEN];
packet[0] = consts::SOCKS4_VERSION;
packet[1] = cmd.as_u8();
match &self.target_addr {
Some(TargetAddr::Ip(SocketAddr::V4(addr))) => {
packet[2] = (addr.port() >> 8) as u8;
packet[3] = addr.port() as u8;
packet[4..8].copy_from_slice(&(addr.ip()).octets());
Ok(())
}
Some(TargetAddr::Ip(SocketAddr::V6(addr))) => {
error!("IPv6 are not supported: {:?}", addr);
Err(ReplySocks4Error(ReplyError::AddressTypeNotSupported))
}
Some(TargetAddr::Domain(domain, port)) => {
packet[2] = (port >> 8) as u8;
packet[3] = *port as u8;
packet[4..8].copy_from_slice(&[0, 0, 0, 1]);
let domain_bytes = domain.as_bytes();
let offset = 8 + domain_bytes.len();
packet[8..offset].copy_from_slice(domain_bytes);
Ok(())
}
_ => {
panic!("Unreachable case");
}
}?;
self.socket.write_all(&packet).await?;
Ok(())
}
#[rustfmt::skip]
async fn read_command_request(&mut self) -> Result<()> {
let [_, cd] = read_exact!(self.socket, [0u8; 2])?;
let reply = ReplyError::from_u8(cd);
match reply {
ReplyError::Succeeded => Ok(()),
_ => Err(SocksError::ReplySocks4Error(reply))
}
}
pub fn get_socket(self) -> S {
self.socket
}
pub fn get_socket_ref(&self) -> &S {
&self.socket
}
pub fn get_socket_mut(&mut self) -> &mut S {
&mut self.socket
}
}
/// Api if you want to use TcpStream to create a new connection to the SOCKS4 server.
impl Socks4Stream<TcpStream> {
/// Connects to a target server through a SOCKS4 proxy.
pub async fn connect<T>(
socks_server: T,
target_addr: String,
target_port: u16,
resolve_locally: bool,
) -> Result<Self>
where
T: ToSocketAddrs,
{
Self::connect_raw(
Socks4Command::Connect,
socks_server,
target_addr,
target_port,
resolve_locally,
)
.await
}
/// Process clients SOCKS requests
/// This is the entry point where a whole request is processed.
pub async fn connect_raw<T>(
cmd: Socks4Command,
socks_server: T,
target_addr: String,
target_port: u16,
resolve_locally: bool,
) -> Result<Self>
where
T: ToSocketAddrs,
{
let socket = TcpStream::connect(
socks_server
.to_socket_addrs()?
.next()
.context("unreachable")?,
)
.await?;
info!("Connected @ {}", &socket.peer_addr()?);
// Specify the target, here domain name, dns will be resolved on the server side
let target_addr = (target_addr.as_str(), target_port)
.to_target_addr()
.context("Can't convert address to TargetAddr format")?;
// upgrade the TcpStream to Socks4Stream
let mut socks_stream = Self::use_stream(socket)?;
socks_stream
.request(cmd, target_addr, resolve_locally)
.await?;
Ok(socks_stream)
}
}
/// Allow us to read directly from the struct
impl<S> AsyncRead for Socks4Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.socket).poll_read(context, buf)
}
}
/// Allow us to write directly into the struct
impl<S> AsyncWrite for Socks4Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.socket).poll_write(context, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.socket).poll_flush(context)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.socket).poll_shutdown(context)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn get_domain() -> String {
"www.google.com".to_string()
}
async fn get_humans_txt(socks: &mut Socks4Stream<TcpStream>) -> Option<String> {
let headers = format!(
"GET /humans.txt HTTP/1.1\r\n\
Host: {}\r\n\
User-Agent: fast-socks5/0.1.0\r\n\
Accept: */*\r\n\r\n",
get_domain()
);
socks
.write_all(headers.as_bytes())
.await
.expect("should successfully write");
let response = &mut [0u8; 2048];
socks
.read(response)
.await
.expect("should successfully read");
// sometimes google returns body on second request
if response[0] == 0 {
response.copy_from_slice(&[0u8; 2048]);
socks
.read(response)
.await
.expect("should successfully read");
}
let response_str = String::from_utf8_lossy(response);
let response_body = response_str
.split("\n")
.into_iter()
.filter(|x| x.starts_with("Google"))
.last()
.map(|x| x.to_string());
response_body
}
fn assert_response_body(response_body: &String) {
let expected =
"Google is built by a large team of engineers, designers, researchers, robots, \
and others in many different sites across the globe. It is updated continuously, \
and built with more tools and technologies than we can shake a stick at. If you'd \
like to help us out, see careers.google.com.";
assert_eq!(expected, response_body);
}
#[tokio::test]
pub async fn test_use_stream() {
// TODO: replace with local socks4 server
// it requires implementation
let tcp = TcpStream::connect("217.17.56.160:4145")
.await
.expect("should connect to remote");
let mut socks = Socks4Stream::use_stream(tcp).expect("should wrap to socks stream");
socks
.request(
Socks4Command::Connect,
TargetAddr::Domain(get_domain(), 80),
true,
)
.await
.expect("should send connect successfully");
let response_body = get_humans_txt(&mut socks)
.await
.expect("should have response_body");
assert_response_body(&response_body);
}
#[tokio::test]
pub async fn test_use_stream_local_resolve() {
let mut socks = Socks4Stream::connect("217.17.56.160:4145", get_domain(), 80, true)
.await
.expect("should connect successfully to socks4 server");
let response_body = get_humans_txt(&mut socks)
.await
.expect("should have response_body");
assert_response_body(&response_body);
}
// Need to find socks4a supporting proxy or implement
// custom server and test using it
//
// #[tokio::test]
// pub async fn test_use_stream_remote_resolve() {
// let mut socks = Socks4Stream::connect("217.17.56.160:4145", get_domain(), 80, false)
// .await
// .expect("should connect successfully to socks4 server");
//
// let response_body = get_humans_txt(&mut socks)
// .await
// .expect("should have response_body");
//
// assert_response_body(&response_body);
// }
}

88
src/socks4/mod.rs Normal file
View File

@@ -0,0 +1,88 @@
pub mod client;
use thiserror::Error;
#[rustfmt::skip]
pub mod consts {
pub const SOCKS4_VERSION: u8 = 0x04;
pub const SOCKS4_CMD_CONNECT: u8 = 0x01;
pub const SOCKS4_CMD_BIND: u8 = 0x02;
pub const SOCKS4_REPLY_SUCCEEDED: u8 = 0x5a;
pub const SOCKS4_REPLY_FAILED: u8 = 0x5b;
pub const SOCKS4_REPLY_HOST_UNREACHABLE: u8 = 0x5c;
pub const SOCKS4_REPLY_INVALID_USER: u8 = 0x5d;
}
/// SOCKS4 reply code
#[derive(Error, Debug, Copy, Clone)]
pub enum ReplyError {
#[error("Succeeded")]
Succeeded,
#[error("General failure")]
GeneralFailure,
#[error("Host unreachable")]
HostUnreachable,
#[error("Address type not supported")]
AddressTypeNotSupported,
#[error("Invalid user")]
InvalidUser,
#[error("Unknown response")]
UnknownResponse(u8),
}
#[derive(Debug, PartialEq)]
pub enum Socks4Command {
Connect,
Bind,
}
#[allow(dead_code)]
impl Socks4Command {
#[inline]
#[rustfmt::skip]
pub fn as_u8(&self) -> u8 {
match self {
Socks4Command::Connect => consts::SOCKS4_CMD_CONNECT,
Socks4Command::Bind => consts::SOCKS4_CMD_BIND,
}
}
#[inline]
#[rustfmt::skip]
pub fn from_u8(code: u8) -> Option<Socks4Command> {
match code {
consts::SOCKS4_CMD_CONNECT => Some(Socks4Command::Connect),
consts::SOCKS4_CMD_BIND => Some(Socks4Command::Bind),
_ => None,
}
}
}
impl ReplyError {
#[inline]
#[rustfmt::skip]
pub fn as_u8(self) -> u8 {
match self {
ReplyError::Succeeded => consts::SOCKS4_REPLY_SUCCEEDED,
ReplyError::GeneralFailure => consts::SOCKS4_REPLY_FAILED,
ReplyError::HostUnreachable => consts::SOCKS4_REPLY_HOST_UNREACHABLE,
ReplyError::InvalidUser => consts::SOCKS4_REPLY_INVALID_USER,
reply => panic!("Unsupported ReplyStatus: {:?}", reply)
}
}
#[inline]
#[rustfmt::skip]
pub fn from_u8(code: u8) -> ReplyError {
match code {
consts::SOCKS4_REPLY_SUCCEEDED => ReplyError::Succeeded,
consts::SOCKS4_REPLY_FAILED => ReplyError::GeneralFailure,
consts::SOCKS4_REPLY_HOST_UNREACHABLE => ReplyError::HostUnreachable,
consts::SOCKS4_REPLY_INVALID_USER => ReplyError::InvalidUser,
_ => ReplyError::UnknownResponse(code),
}
}
}

2
src/util/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod stream;
pub mod target_addr;

40
src/util/stream.rs Normal file
View File

@@ -0,0 +1,40 @@
/// Easy to destructure bytes buffers by naming each fields:
///
/// # Examples (before)
///
/// ```ignore
/// let mut buf = [0u8; 2];
/// stream.read_exact(&mut buf).await?;
/// let [version, method_len] = buf;
///
/// assert_eq!(version, 0x05);
/// ```
///
/// # Examples (after)
///
/// ```ignore
/// let [version, method_len] = read_exact!(stream, [0u8; 2]);
///
/// assert_eq!(version, 0x05);
/// ```
#[macro_export]
macro_rules! read_exact {
($stream: expr, $array: expr) => {{
let mut x = $array;
// $stream
// .read_exact(&mut x)
// .await
// .map_err(|_| io_err("lol"))?;
$stream.read_exact(&mut x).await.map(|_| x)
}};
}
#[macro_export]
macro_rules! ready {
($e:expr $(,)?) => {
match $e {
std::task::Poll::Ready(t) => t,
std::task::Poll::Pending => return std::task::Poll::Pending,
}
};
}

241
src/util/target_addr.rs Normal file
View File

@@ -0,0 +1,241 @@
use crate::consts;
use crate::consts::SOCKS5_ADDR_TYPE_IPV4;
use crate::read_exact;
use crate::SocksError;
use anyhow::Context;
use std::fmt;
use std::io;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::vec::IntoIter;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio::net::lookup_host;
/// SOCKS5 reply code
#[derive(Error, Debug)]
pub enum AddrError {
#[error("DNS Resolution failed")]
DNSResolutionFailed,
#[error("Can't read IPv4")]
IPv4Unreadable,
#[error("Can't read IPv6")]
IPv6Unreadable,
#[error("Can't read port number")]
PortNumberUnreadable,
#[error("Can't read domain len")]
DomainLenUnreadable,
#[error("Can't read Domain content")]
DomainContentUnreadable,
#[error("Malformed UTF-8")]
Utf8,
#[error("Unknown address type")]
IncorrectAddressType,
#[error("{0}")]
Custom(String),
}
/// A description of a connection target.
#[derive(Debug, Clone)]
pub enum TargetAddr {
/// Connect to an IP address.
Ip(SocketAddr),
/// Connect to a fully qualified domain name.
///
/// The domain name will be passed along to the proxy server and DNS lookup
/// will happen there.
Domain(String, u16),
}
impl TargetAddr {
pub async fn resolve_dns(self) -> anyhow::Result<TargetAddr> {
match self {
TargetAddr::Ip(ip) => Ok(TargetAddr::Ip(ip)),
TargetAddr::Domain(domain, port) => {
debug!("Attempt to DNS resolve the domain {}...", &domain);
let socket_addr = lookup_host((&domain[..], port))
.await
.context(AddrError::DNSResolutionFailed)?
.next()
.ok_or(AddrError::Custom(
"Can't fetch DNS to the domain.".to_string(),
))?;
debug!("domain name resolved to {}", socket_addr);
// has been converted to an ip
Ok(TargetAddr::Ip(socket_addr))
}
}
}
pub fn is_ip(&self) -> bool {
match self {
TargetAddr::Ip(_) => true,
_ => false,
}
}
pub fn is_domain(&self) -> bool {
!self.is_ip()
}
pub fn to_be_bytes(&self) -> anyhow::Result<Vec<u8>> {
let mut buf = vec![];
match self {
TargetAddr::Ip(SocketAddr::V4(addr)) => {
debug!("TargetAddr::IpV4");
buf.extend_from_slice(&[SOCKS5_ADDR_TYPE_IPV4]);
debug!("addr ip {:?}", (*addr.ip()).octets());
buf.extend_from_slice(&(addr.ip()).octets()); // ip
buf.extend_from_slice(&addr.port().to_be_bytes()); // port
}
TargetAddr::Ip(SocketAddr::V6(addr)) => {
debug!("TargetAddr::IpV6");
buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_IPV6]);
debug!("addr ip {:?}", (*addr.ip()).octets());
buf.extend_from_slice(&(addr.ip()).octets()); // ip
buf.extend_from_slice(&addr.port().to_be_bytes()); // port
}
TargetAddr::Domain(ref domain, port) => {
debug!("TargetAddr::Domain");
if domain.len() > u8::max_value() as usize {
return Err(SocksError::ExceededMaxDomainLen(domain.len()).into());
}
buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME, domain.len() as u8]);
buf.extend_from_slice(domain.as_bytes()); // domain content
buf.extend_from_slice(&port.to_be_bytes());
// port content (.to_be_bytes() convert from u16 to u8 type)
}
}
Ok(buf)
}
}
// async-std ToSocketAddrs doesn't supports external trait implementation
// @see https://github.com/async-rs/async-std/issues/539
impl std::net::ToSocketAddrs for TargetAddr {
type Iter = IntoIter<SocketAddr>;
fn to_socket_addrs(&self) -> io::Result<IntoIter<SocketAddr>> {
match *self {
TargetAddr::Ip(addr) => Ok(vec![addr].into_iter()),
TargetAddr::Domain(_, _) => Err(io::Error::new(
io::ErrorKind::Other,
"Domain name has to be explicitly resolved, please use TargetAddr::resolve_dns().",
)),
}
}
}
impl fmt::Display for TargetAddr {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
TargetAddr::Ip(ref addr) => write!(f, "{}", addr),
TargetAddr::Domain(ref addr, ref port) => write!(f, "{}:{}", addr, port),
}
}
}
/// A trait for objects that can be converted to `TargetAddr`.
pub trait ToTargetAddr {
/// Converts the value of `self` to a `TargetAddr`.
fn to_target_addr(&self) -> io::Result<TargetAddr>;
}
impl<'a> ToTargetAddr for (&'a str, u16) {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
// try to parse as an IP first
if let Ok(addr) = self.0.parse::<Ipv4Addr>() {
return (addr, self.1).to_target_addr();
}
if let Ok(addr) = self.0.parse::<Ipv6Addr>() {
return (addr, self.1).to_target_addr();
}
Ok(TargetAddr::Domain(self.0.to_owned(), self.1))
}
}
impl ToTargetAddr for SocketAddr {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
Ok(TargetAddr::Ip(*self))
}
}
impl ToTargetAddr for SocketAddrV4 {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
SocketAddr::V4(*self).to_target_addr()
}
}
impl ToTargetAddr for SocketAddrV6 {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
SocketAddr::V6(*self).to_target_addr()
}
}
impl ToTargetAddr for (Ipv4Addr, u16) {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
SocketAddrV4::new(self.0, self.1).to_target_addr()
}
}
impl ToTargetAddr for (Ipv6Addr, u16) {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
SocketAddrV6::new(self.0, self.1, 0, 0).to_target_addr()
}
}
#[derive(Debug)]
pub enum Addr {
V4([u8; 4]),
V6([u8; 16]),
Domain(String), // Vec<[u8]> or Box<[u8]> or String ?
}
/// This function is used by the client & the server
pub async fn read_address<T: AsyncRead + Unpin>(
stream: &mut T,
atyp: u8,
) -> anyhow::Result<TargetAddr> {
let addr = match atyp {
consts::SOCKS5_ADDR_TYPE_IPV4 => {
debug!("Address type `IPv4`");
Addr::V4(read_exact!(stream, [0u8; 4]).context(AddrError::IPv4Unreadable)?)
}
consts::SOCKS5_ADDR_TYPE_IPV6 => {
debug!("Address type `IPv6`");
Addr::V6(read_exact!(stream, [0u8; 16]).context(AddrError::IPv6Unreadable)?)
}
consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
debug!("Address type `domain`");
let len = read_exact!(stream, [0]).context(AddrError::DomainLenUnreadable)?[0];
let domain = read_exact!(stream, vec![0u8; len as usize])
.context(AddrError::DomainContentUnreadable)?;
// make sure the bytes are correct utf8 string
let domain = String::from_utf8(domain).context(AddrError::Utf8)?;
Addr::Domain(domain)
}
_ => return Err(anyhow::anyhow!(AddrError::IncorrectAddressType)),
};
// Find port number
let port = read_exact!(stream, [0u8; 2]).context(AddrError::PortNumberUnreadable)?;
// Convert (u8 * 2) into u16
let port = (port[0] as u16) << 8 | port[1] as u16;
// Merge ADDRESS + PORT into a TargetAddr
let addr: TargetAddr = match addr {
Addr::V4([a, b, c, d]) => (Ipv4Addr::new(a, b, c, d), port).to_target_addr()?,
Addr::V6(x) => (Ipv6Addr::from(x), port).to_target_addr()?,
Addr::Domain(domain) => TargetAddr::Domain(domain, port),
};
Ok(addr)
}