diff --git a/Cargo.toml b/Cargo.toml index dc7a089..14a836c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust_util" -version = "0.6.3" +version = "0.6.5" authors = ["Hatter Jiang "] edition = "2018" description = "Hatter's Rust Util" diff --git a/src/lib.rs b/src/lib.rs index 6db3ce8..4348ae2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ pub mod util_str; pub mod util_size; pub mod util_file; pub mod util_time; +pub mod util_net; /// iff!(condition, result_when_true, result_when_false) #[macro_export] macro_rules! iff { diff --git a/src/util_net.rs b/src/util_net.rs new file mode 100644 index 0000000..58e3dd2 --- /dev/null +++ b/src/util_net.rs @@ -0,0 +1,95 @@ +use std::net::SocketAddr; +use crate::XResult; + +#[derive(Debug, Clone)] +pub enum IpAddressMask { + Ipv4([u8; 4], u8), +} + +impl IpAddressMask { + pub fn parse_ipv4(addr: &str) -> Option { + let addr_mask_parts = addr.split('/').collect::>(); + let (addr_ip, mask) = if addr_mask_parts.len() == 1 { + (addr_mask_parts[0], 32) + } else if addr_mask_parts.len() == 2 { + if let Ok(mask) = addr_mask_parts[1].parse::() { + (addr_mask_parts[0], mask) + } else { + return None; + } + } else { + return None; + }; + let addr_parts = addr_ip.split('.').collect::>(); + if addr_parts.len() != 4 { + return None; + } + let parsed_addr = || -> XResult<[u8; 4]> { + Ok([addr_parts[0].parse::()?, + addr_parts[1].parse::()?, + addr_parts[2].parse::()?, + addr_parts[3].parse::()? + ]) + }; + match parsed_addr() { + Ok(parts) => Some(IpAddressMask::Ipv4(parts, mask)), + Err(_) => None, + } + } + + pub fn to_address(&self) -> String { + match self { + IpAddressMask::Ipv4(ipv4, mask) => { + format!("{}/{}", ipv4.iter().map(|p| p.to_string()).collect::>().join("."), mask) + }, + } + } + + pub fn is_matches(&self, socket_addr: &SocketAddr) -> bool { + match socket_addr { + SocketAddr::V4(socket_addr_v4) => { + let socket_addr_v4_octets = socket_addr_v4.ip().octets(); + match self { + IpAddressMask::Ipv4(self_ipv4_octets, mask) => { + let self_ipv4_u32 = ipv4_to_u32(&self_ipv4_octets); + let addr_ipv4_u32 = ipv4_to_u32(&socket_addr_v4_octets); + let mask_u32 = ipv4_mask(*mask); + self_ipv4_u32 & mask_u32 == addr_ipv4_u32 & mask_u32 + }, + } + }, + SocketAddr::V6(_) => false, + } + } +} + +fn ipv4_mask(mask: u8) -> u32 { + let mut r = 0_u32; + for _ in 0..mask { + r <<= 1; + r |= 1; + } + for _ in mask..32 { + r <<= 1; + } + r +} + +fn ipv4_to_u32(ipv4: &[u8; 4]) -> u32 { + ((ipv4[0] as u32) << (8 * 3)) + ((ipv4[1] as u32) << (8 * 2)) + ((ipv4[2] as u32) << 8) + (ipv4[3] as u32) +} + + +#[test] +fn test_is_matches() { + let addr = SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), 123); + let addr2 = SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 2)), 123); + assert_eq!(true, IpAddressMask::parse_ipv4("127.0.0.1").unwrap().is_matches(&addr)); + assert_eq!(true, IpAddressMask::parse_ipv4("127.0.0.1/32").unwrap().is_matches(&addr)); + assert_eq!(true, IpAddressMask::parse_ipv4("127.0.0.1/31").unwrap().is_matches(&addr)); + assert_eq!(true, IpAddressMask::parse_ipv4("127.0.0.1/30").unwrap().is_matches(&addr)); + assert_eq!(false, IpAddressMask::parse_ipv4("127.0.0.1").unwrap().is_matches(&addr2)); + assert_eq!(false, IpAddressMask::parse_ipv4("127.0.0.1/32").unwrap().is_matches(&addr2)); + assert_eq!(false, IpAddressMask::parse_ipv4("127.0.0.1/31").unwrap().is_matches(&addr2)); + assert_eq!(true, IpAddressMask::parse_ipv4("127.0.0.1/30").unwrap().is_matches(&addr2)); +} \ No newline at end of file