From cadbcd9429a310e76a89b9d76d1b8e274c69d80c Mon Sep 17 00:00:00 2001 From: Hatter Jiang Date: Sat, 30 Mar 2024 19:10:32 +0800 Subject: [PATCH] feat: v0.2.0-rc, optimize code --- src/app.rs | 63 ++++++++++++++++++++++++-------------------------- src/cert.rs | 13 +++++++---- src/main.rs | 26 +++++++++++++-------- src/service.rs | 45 ++++++++++++++++++++---------------- 4 files changed, 81 insertions(+), 66 deletions(-) diff --git a/src/app.rs b/src/app.rs index cd2fd36..31f9b4b 100644 --- a/src/app.rs +++ b/src/app.rs @@ -19,13 +19,13 @@ use super::service::HostConfig; pub struct ProxyApp { tls: bool, lookup_dns: bool, - host_configs: Vec, + host_config_map: HashMap, dns_resolver: TokioAsyncResolver, - dns_resolver_cache_map: RwLock>, + dns_resolver_cache: RwLock>, } impl ProxyApp { - pub fn new(tls: bool, lookup_dns: bool, host_configs: Vec) -> Self { + pub fn new(tls: bool, lookup_dns: bool, host_config_map: HashMap) -> Self { let dns_resolver = TokioAsyncResolver::tokio( ResolverConfig::default(), ResolverOpts::default(), @@ -33,16 +33,16 @@ impl ProxyApp { ProxyApp { tls, lookup_dns, - host_configs, + host_config_map, dns_resolver, - dns_resolver_cache_map: Default::default(), + dns_resolver_cache: Default::default(), } } // just only support IPv4 async fn lookup_ipv4(&self, hostname: &str) -> Option { { - if let Some(ipv4_address) = self.dns_resolver_cache_map.read().await.get(hostname) { + if let Some(ipv4_address) = self.dns_resolver_cache.read().await.get(hostname) { log::info!("DNS cached {} --> {}", hostname, ipv4_address); return Some(ipv4_address.to_string()); } @@ -55,7 +55,7 @@ impl ProxyApp { if let Some(RData::A(a)) = record.data() { let ipv4_address = a.0.to_string(); { - self.dns_resolver_cache_map.write().await + self.dns_resolver_cache.write().await .insert(hostname.to_string(), ipv4_address.clone()); } log::info!("DNS found {} --> {}", hostname, ipv4_address); @@ -90,51 +90,48 @@ impl ProxyHttp for ProxyApp { fn new_ctx(&self) {} async fn upstream_peer(&self, session: &mut Session, _ctx: &mut ()) -> Result> { - let host_header = session - .get_header(HeaderName::from_static("host")) - .unwrap() - .to_str() - .expect("get host from http header failed"); - log::info!("host header: {host_header}"); + let host_header = match session.get_header(HeaderName::from_static("host")) { + None => return Err(Error::new(ErrorType::HTTPStatus(400))), + Some(host_header) => host_header, + }; + let host = match host_header.to_str() { + Ok(host) => host, + Err(_) => return Err(Error::new(ErrorType::HTTPStatus(400))), + }; + log::info!("Find host header: {}", host); - let hostname = if host_header.contains(':') { - host_header.chars().take_while(|c| c != &':').collect() + let hostname = if host.contains(':') { + host.chars().take_while(|c| c != &':').collect() } else { - host_header.to_string() + host.to_string() }; - if hostname == "localhost" { - return Err(Error::new(ErrorType::CustomCode("bad host", 400))); + if hostname == "127.0.0.1" || hostname == "localhost" { + return Err(Error::new(ErrorType::HTTPStatus(404))); } - let host_config = self - .host_configs - .iter() - .find(|x| x.proxy_hostname == hostname); - if let Some(host_config) = host_config { - let peer = HttpPeer::new( + if let Some(host_config) = self.host_config_map.get(&hostname) { + log::info!("Find peer: {} --> {}", hostname, host_config.proxy_addr); + return Ok(Box::new(HttpPeer::new( host_config.proxy_addr.as_str(), host_config.proxy_tls, - host_config.proxy_hostname.clone(), - ); - log::info!("Find peer: {} --> {}", hostname, host_config.proxy_addr); - return Ok(Box::new(peer)); + host_config.proxy_servername.clone(), + ))); } if self.lookup_dns { if let Some(address) = self.lookup_ipv4(&hostname).await { let peer_addr = format!("{}:{}", address, if self.tls { 443 } else { 80 }); - let peer = HttpPeer::new( + log::info!("DNS peer: {} --> {}", hostname, peer_addr); + return Ok(Box::new(HttpPeer::new( &peer_addr, self.tls, hostname.to_string(), - ); - log::info!("Generate peer: {} --> {}", hostname, peer_addr); - return Ok(Box::new(peer)); + ))); } } - panic!("Cannot find peer: {}", hostname); + Err(Error::new(ErrorType::CustomCode("bad host", 400))) } async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result diff --git a/src/cert.rs b/src/cert.rs index 1aafc8f..591bb9c 100644 --- a/src/cert.rs +++ b/src/cert.rs @@ -25,11 +25,14 @@ pub fn load_certificate(cert_fn: &str, key_fn: &str) -> Result<(Certificate, Str Ok((cert, cert_pem)) } -pub fn issue_certificate(intermediate_certificate: &Certificate, domain: &str) -> Result { +pub fn issue_certificate(issuer_certificate: &Certificate, domain: &str) -> Result { let cert = new_end_entity(domain)?; log::info!("New certificate for: {} -> {}", domain, hex::encode(cert.get_key_identifier())); - let cert_pem = cert.serialize_pem_with_signer(intermediate_certificate).map_err(|e| format!("Sign cert failed: {}", e))?; + + let cert_pem = cert.serialize_pem_with_signer(issuer_certificate) + .map_err(|e| format!("Sign cert failed: {}", e))?; let key_pem = cert.serialize_private_key_pem(); + Ok(Cert { cert_pem, key_pem, @@ -71,7 +74,9 @@ fn new_end_entity(domain: &str) -> Result { } fn validity_period() -> Result<(OffsetDateTime, OffsetDateTime), String> { - let start = OffsetDateTime::now_utc().checked_sub(Duration::hours(1)).expect("SHOULD NOT HAPPEN!"); - let end = OffsetDateTime::now_utc().checked_add(Duration::days(90)).expect("SHOULD NOT HAPPEN!"); + let start = OffsetDateTime::now_utc().checked_sub(Duration::hours(1)) + .expect("SHOULD NOT HAPPEN!"); + let end = OffsetDateTime::now_utc().checked_add(Duration::days(90)) + .expect("SHOULD NOT HAPPEN!"); Ok((start, end)) } diff --git a/src/main.rs b/src/main.rs index 797a924..839d626 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use log::LevelFilter; use pingora::{ server::{configuration::Opt, Server}, @@ -32,6 +34,10 @@ pub fn main() { services.push(Box::new(prometheus_service_http)); } + if services.is_empty() { + panic!("No services is configured!"); + } + log::info!("start listen..."); my_server.add_services(services); my_server.run_forever(); @@ -47,8 +53,9 @@ fn build_services(server: &Server, proxy_config: &ProxyConfig) -> Vec> = vec![]; for group in &proxy_config.groups { let listen_address = format!("0.0.0.0:{}", group.port); - let host_configs = build_host_configs(group); + let host_config_map = build_host_config_map(group); let lookup_dns = group.lookup_dns.unwrap_or(false); + log::info!("Listen at: {}, tls: {}, lookup_dns: {}", listen_address, group.tls.is_some(), lookup_dns); match &group.tls { @@ -57,7 +64,7 @@ fn build_services(server: &Server, proxy_config: &ProxyConfig) -> Vec { @@ -66,7 +73,7 @@ fn build_services(server: &Server, proxy_config: &ProxyConfig) -> Vec Vec Vec { - let mut host_configs = vec![]; +fn build_host_config_map(group: &ProxyGroup) -> HashMap { + let mut host_config_map = HashMap::new(); if let Some(proxy_map) = &group.proxy_map { for (hostname, proxy_item) in proxy_map { - host_configs.push(HostConfig { + let host_config = HostConfig { proxy_addr: proxy_item.address.clone(), proxy_tls: proxy_item.tls.unwrap_or(false), - proxy_hostname: proxy_item.sni.clone().unwrap_or_else(|| hostname.clone()), - }); + proxy_servername: proxy_item.sni.clone().unwrap_or_else(|| hostname.clone()), + }; + host_config_map.insert(hostname.to_string(), host_config); } } - host_configs + host_config_map } \ No newline at end of file diff --git a/src/service.rs b/src/service.rs index abeb973..33468f9 100644 --- a/src/service.rs +++ b/src/service.rs @@ -17,30 +17,30 @@ use crate::cert::Cert; use crate::config::ProxyTls; struct Callback { - intermediate_certificate: Certificate, - intermediate_certificate_pem: String, - certificate_cache_map: RwLock>, + issuer_cert: Certificate, + issuer_cert_pem: String, + issued_cert_cache: RwLock>, } impl Callback { fn new(proxy_tls: &ProxyTls) -> Result { let (cert, cert_pem) = cert::load_certificate(&proxy_tls.issuer_cert, &proxy_tls.issuer_key)?; Ok(Self { - intermediate_certificate: cert, - intermediate_certificate_pem: cert_pem, - certificate_cache_map: Default::default(), + issuer_cert: cert, + issuer_cert_pem: cert_pem, + issued_cert_cache: Default::default(), }) } async fn issue_certificate(&self, hostname: &str) -> Result { { - if let Some(cert) = self.certificate_cache_map.read().await.get(hostname) { + if let Some(cert) = self.issued_cert_cache.read().await.get(hostname) { return Ok(cert.clone()); } } - let cert = cert::issue_certificate(&self.intermediate_certificate, hostname)?; + let cert = cert::issue_certificate(&self.issuer_cert, hostname)?; { - self.certificate_cache_map.write().await.insert(hostname.to_string(), cert.clone()); + self.issued_cert_cache.write().await.insert(hostname.to_string(), cert.clone()); } Ok(cert) } @@ -49,7 +49,8 @@ impl Callback { #[async_trait] impl TlsAccept for Callback { async fn certificate_callback(&self, ssl: &mut SslRef) -> () { - let sni_provided = ssl.servername(NameType::HOST_NAME).expect("get sni failed").to_string(); + let sni_provided = ssl.servername(NameType::HOST_NAME) + .unwrap_or("127.0.0.1").to_string(); log::info!("SNI provided: {}", sni_provided); let cert = self.issue_certificate(&sni_provided).await @@ -57,15 +58,15 @@ impl TlsAccept for Callback { let x509_cert = X509::from_pem(cert.cert_pem.as_bytes()) .unwrap_or_else(|e| panic!("parse cert: {} failed: {}", cert.cert_pem, e)); - let x509_intermediate_cert = X509::from_pem(self.intermediate_certificate_pem.as_bytes()) - .unwrap_or_else(|e| panic!("parse intermediate cert: {} failed: {}", self.intermediate_certificate_pem, e)); + let x509_intermediate_cert = X509::from_pem(self.issuer_cert_pem.as_bytes()) + .unwrap_or_else(|e| panic!("parse issuer cert: {} failed: {}", self.issuer_cert_pem, e)); let private_key = PKey::private_key_from_pem(cert.key_pem.as_bytes()) .unwrap_or_else(|e| panic!("parse key: {} failed: {}", cert.key_pem, e)); ext::ssl_use_certificate(ssl, &x509_cert) .unwrap_or_else(|e| panic!("apply certificate for: {} failed: {}", sni_provided, e)); ext::ssl_add_chain_cert(ssl, &x509_intermediate_cert) - .unwrap_or_else(|e| panic!("apply intermediate certificate for: {} failed: {}", sni_provided, e)); + .unwrap_or_else(|e| panic!("apply issuer certificate for: {} failed: {}", sni_provided, e)); ext::ssl_use_private_key(ssl, &private_key) .unwrap_or_else(|e| panic!("apply key for: {} failed: {}", sni_provided, e)); } @@ -75,16 +76,16 @@ impl TlsAccept for Callback { pub struct HostConfig { pub proxy_addr: String, pub proxy_tls: bool, - pub proxy_hostname: String, + pub proxy_servername: String, } pub fn proxy_service_tcp( server_conf: &Arc, listen_addr: &str, lookup_dns: bool, - host_configs: Vec, + host_config_map: HashMap, ) -> impl pingora::services::Service { - let proxy_app = ProxyApp::new(false, lookup_dns, host_configs); + let proxy_app = ProxyApp::new(false, lookup_dns, host_config_map); let mut service = http_proxy_service(server_conf, proxy_app); service.add_tcp(listen_addr); @@ -97,13 +98,17 @@ pub fn proxy_service_tls( listen_addr: &str, lookup_dns: bool, proxy_tls: &ProxyTls, - host_configs: Vec, + host_config_map: HashMap, ) -> impl pingora::services::Service { - let proxy_app = ProxyApp::new(true, lookup_dns, host_configs); + let proxy_app = ProxyApp::new(true, lookup_dns, host_config_map); let mut service = http_proxy_service(server_conf, proxy_app); - let cb = Box::new(Callback::new(proxy_tls).unwrap()); - let tls_settings = TlsSettings::with_callbacks(cb).unwrap(); + let cb = Box::new(Callback::new(proxy_tls).unwrap_or_else(|e| { + panic!("Init SSL callback failed: {}", e); + })); + let tls_settings = TlsSettings::with_callbacks(cb).unwrap_or_else(|e| { + panic!("Init SSL settings failed: {}", e); + }); service.add_tls_with_settings(listen_addr, None, tls_settings); service