Files
proxy-inspector/src/app.rs
2024-03-30 15:37:26 +08:00

171 lines
5.4 KiB
Rust

use async_trait::async_trait;
use base64::Engine;
use base64::engine::general_purpose::STANDARD;
use bytes::Bytes;
use http::HeaderName;
use pingora::{Error, ErrorType};
use pingora::http::ResponseHeader;
use pingora::prelude::{HttpPeer, ProxyHttp, Result, Session};
use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
use trust_dns_resolver::proto::rr::RData;
use trust_dns_resolver::TokioAsyncResolver;
use super::service::HostConfig;
/// Pingora proxy application that logs traffic and picks an upstream from the
/// request's `Host` header.
///
/// Upstreams are looked up in the static `host_configs` table first and, when
/// `lookup_dns` is set, resolved through DNS as a fallback.
pub struct ProxyApp {
    /// Static hostname -> upstream routing table, consulted before DNS.
    host_configs: Vec<HostConfig>,
    /// Whether hosts missing from `host_configs` may be resolved via DNS.
    lookup_dns: bool,
    /// Whether DNS-resolved upstreams are contacted over TLS (port 443) or plain
    /// HTTP (port 80).
    tls: bool,
    /// Resolver used for the DNS fallback.
    tokio_async_resolver: TokioAsyncResolver,
}
impl ProxyApp {
    /// Creates the proxy application.
    ///
    /// * `tls` - whether DNS-resolved upstreams are reached over TLS.
    /// * `lookup_dns` - whether to fall back to DNS for hosts not in `host_configs`.
    /// * `host_configs` - static hostname -> upstream routing table.
    ///
    /// The DNS resolver is built from the default resolver config and options.
    pub fn new(tls: bool, lookup_dns: bool, host_configs: Vec<HostConfig>) -> Self {
        let tokio_async_resolver =
            TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
        ProxyApp {
            tls,
            lookup_dns,
            host_configs,
            tokio_async_resolver,
        }
    }

    /// Resolves `hostname` and returns its first IPv4 address as a string.
    ///
    /// Returns `None` when the lookup fails or the answer contains no A record.
    async fn lookup_ipv4(&self, hostname: &str) -> Option<String> {
        let lookup = self.tokio_async_resolver.ipv4_lookup(hostname).await;
        log::debug!("lookup {} --> {:#?}", hostname, lookup);
        let ips = lookup.ok()?;
        // `Ipv4Lookup::iter` yields only A-record payloads. The raw record list
        // (`as_lookup().records()`) can start with other record types such as a
        // CNAME, so inspecting only its first entry could miss a valid address.
        let first = ips.iter().next()?;
        Some(first.0.to_string())
    }
}
#[async_trait]
impl ProxyHttp for ProxyApp {
type CTX = ();
fn new_ctx(&self) {}
async fn upstream_peer(&self, session: &mut Session, _ctx: &mut ()) -> Result<Box<HttpPeer>> {
let host_header = session
.get_header(HeaderName::from_static("host"))
.unwrap()
.to_str()
.expect("get host from http header failed");
log::info!("host header: {host_header}");
let hostname = if host_header.contains(':') {
host_header.chars().take_while(|c| c != &':').collect()
} else {
host_header.to_string()
};
if hostname == "localhost" {
return Err(Error::new(ErrorType::CustomCode("bad host", 400)));
}
let host_config = self
.host_configs
.iter()
.find(|x| x.proxy_hostname == hostname);
if let Some(host_config) = host_config {
let peer = HttpPeer::new(
host_config.proxy_addr.as_str(),
host_config.proxy_tls,
host_config.proxy_hostname.clone(),
);
log::info!("Find peer: {} --> {:?}", hostname, host_config.proxy_addr);
return Ok(Box::new(peer));
}
if self.lookup_dns {
if let Some(address) = self.lookup_ipv4(&hostname).await {
let peer = HttpPeer::new(
format!("{}:{}", address, if self.tls { 443 } else { 80 }),
self.tls,
hostname.to_string(),
);
log::info!("Find peer: {} --> {:?}", hostname, address);
return Ok(Box::new(peer));
}
}
panic!("Cannot find peer: {}", hostname);
}
async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool>
where Self::CTX: Send + Sync,
{
let request_header = session.req_header();
let mut req = String::with_capacity(512);
req.push_str(request_header.method.as_str());
req.push(' ');
req.push_str(&request_header.uri.to_string());
req.push(' ');
req.push_str(&format!("{:?}\n", request_header.version));
let header_len = request_header.headers.len();
request_header.headers.iter().enumerate().for_each(|(i, (n, v))| {
req.push_str(
&format!("{}: {}{}",
n.as_str(),
v.to_str().unwrap_or("ERROR!BAD-VALUE!"),
if i < header_len - 1 { "\n" } else { "" }
)
);
});
let body = match session.read_request_body().await {
Ok(Some(body_bytes)) => Some(STANDARD.encode(body_bytes)),
_ => None,
};
log::info!("Request:\n{}\n\n{}", req, body.unwrap_or_else(|| "<None>".into()));
Ok(false)
}
async fn response_filter(
&self,
_session: &mut Session,
upstream_response: &mut ResponseHeader,
_ctx: &mut Self::CTX,
) -> Result<()>
where Self::CTX: Send + Sync,
{
let mut resp = String::new();
resp.push_str(&format!("version: {}\n", upstream_response.status));
resp.push_str(&format!("headers: {:#?}", upstream_response.headers));
log::info!("Response: {}", resp);
Ok(())
}
fn upstream_response_body_filter(
&self,
_session: &mut Session,
body: &Option<Bytes>,
end_of_stream: bool,
_ctx: &mut Self::CTX,
) {
log::info!("Body {}: [[[{}]]] ",end_of_stream,
body.as_ref().map(|bytes|
String::from_utf8_lossy(bytes.iter().as_slice())
).unwrap_or_else(|| "<none>".into())
);
}
}