pub mod header; mod parser_utils; mod question; pub mod record; use crate::{ dns_types::Class, message::{question::Entry, record::Record}, parse::parse_label, RecordType, }; use anyhow::Result as AResult; use bitvec::prelude::*; use header::Header; use nom::{ combinator::{map, map_res, peek}, error::Error, multi::{count, length_value}, number::complete::{be_u16, be_u32, be_u8}, sequence::tuple, IResult, }; use std::{ io::Read, net::{Ipv4Addr, Ipv6Addr}, }; use self::record::{RecordData, SoaData}; /// Defined by the spec /// UDP messages 512 octets or less pub(crate) const MAX_UDP_BYTES: usize = 512; /// Defined by the spec /// labels 63 octets or less const MAX_LABEL_BYTES: usize = 63; /// Defined by the spec /// names 255 octets or less const MAX_NAME_BYTES: usize = 255; const MAX_RECURSION_DEPTH: u8 = 20; #[derive(Debug)] pub struct Message { /// The header section is always present. The header includes fields that /// specify which of the remaining sections are present, and also specify /// whether the message is a query or a response, a standard query or some /// other opcode, etc. pub header: Header, // The question section contains fields that describe a // question to a name server. These fields are a query type (QTYPE), a // query class (QCLASS), and a query domain name (QNAME). pub question: Vec, // The last three // sections have the same format: a possibly empty list of concatenated // resource records (RRs). /// The answer section contains RRs that answer the question pub answer: Vec, /// the authority section contains RRs that point toward an /// authoritative name server; pub authority: Vec, /// the additional records section contains RRs /// which relate to the query, but are not strictly answers for the /// question. pub additional: Vec, } impl Message { pub(crate) fn new_query( id: u16, domain_name: String, record_type: RecordType, ) -> AResult { let name_len = domain_name.len(); if name_len > MAX_NAME_BYTES { anyhow::bail!( "Domain name is {name_len} bytes, which is over the max of {MAX_NAME_BYTES}" ); } let labels: Vec<_> = domain_name.split('.').map(|a| a.to_owned()).collect(); if labels.iter().any(|label| label.len() > MAX_LABEL_BYTES) { anyhow::bail!( "One of the labels in your domain is over the max of {MAX_LABEL_BYTES} bytes" ); } let msg = Message { header: Header::new_query(id), question: vec![Entry::new(labels, record_type)], answer: Vec::new(), authority: Vec::new(), additional: Vec::new(), }; Ok(msg) } fn serialize_bits(&self, bv: &mut BitVec) -> AResult<()> { self.header.serialize(bv); for q in &self.question { q.serialize(bv)?; } Ok(()) } pub fn serialize_bytes(&self) -> AResult> { let mut bv = BitVec::::new(); self.serialize_bits(&mut bv)?; let mut msg_bytes = Vec::with_capacity(MAX_UDP_BYTES); bv.as_bitslice().read_to_end(&mut msg_bytes).unwrap(); Ok(msg_bytes) } pub fn deserialize(input: Vec) -> anyhow::Result { let mp = MsgParser { input }; let slice = &mp.input[..]; let msg: Message = mp.parse_message(slice).unwrap().1; Ok(msg) } } struct MsgParser { input: Vec, } impl MsgParser { /// Returns a parser that can parse DNS record data of the given record type. fn parse_rdata<'i>( &self, record_type: RecordType, ) -> impl FnMut(&'i [u8]) -> IResult<&'i [u8], RecordData> + '_ { move |i| { let recursion_depth = 0; let record = match record_type { RecordType::A => map(tuple((be_u8, be_u8, be_u8, be_u8)), |(a, b, c, d)| { RecordData::A(Ipv4Addr::new(a, b, c, d)) })(i)?, RecordType::Aaaa => map( tuple(( be_u16, be_u16, be_u16, be_u16, be_u16, be_u16, be_u16, be_u16, )), |(a, b, c, d, e, f, g, h)| { RecordData::Aaaa(Ipv6Addr::new(a, b, c, d, e, f, g, h)) }, )(i)?, RecordType::Cname => { map(|i| self.parse_name(i, recursion_depth), RecordData::Cname)(i)? } RecordType::Ns => map(|i| self.parse_name(i, recursion_depth), RecordData::Ns)(i)?, RecordType::Soa => { let (i, mname) = self.parse_name(i, recursion_depth)?; let (i, rname) = self.parse_name(i, recursion_depth)?; let (i, serial) = be_u32(i)?; let (i, refresh) = be_u32(i)?; let (i, retry) = be_u32(i)?; let (i, expire) = be_u32(i)?; let rd = SoaData { mname, rname, serial, refresh, retry, expire, }; (i, RecordData::Soa(rd)) } }; Ok(record) } } /// Parse a domain name. fn parse_name<'i>( &self, mut input: &'i [u8], recursion_depth: u8, ) -> IResult<&'i [u8], String> { let mut name = String::new(); loop { let (i, first_byte) = peek(be_u8)(input)?; input = i; const POINTER_HEADER: u8 = 0b11000000; if first_byte >= POINTER_HEADER { // The message is using Message Compression: // This label is a pointer, and it ends the sequence of labels. // The remaining 14 bits are the offset that the pointer points at. // So, first, examine the 14 bits to find the offset of the next label. let dereference_pointer = |ptr| (ptr - ((POINTER_HEADER as u16) << 8)) as usize; let (i, next_label_offset) = map(be_u16, dereference_pointer)(input)?; if recursion_depth >= MAX_RECURSION_DEPTH { panic!("too many DNS message compression indirections!") } // Now, just parse a name from that offset. let (_, pointed_label) = self .parse_name(&self.input[next_label_offset..], recursion_depth + 1) .unwrap(); name += &pointed_label; input = i; break; } else { // This label is a literal. let (i, label) = parse_label(input)?; input = i; name += &label; // Domain names end with a zero-length terminal label. // (that's why in `dig` the names always end in an unnecessary dot, // e.g. adamchalmers.com.) if label.is_empty() { break; } name.push('.'); } } Ok((input, name)) } fn parse_record<'i>(&self, input: &'i [u8]) -> IResult<&'i [u8], Record, Error<&'i [u8]>> { let (input, name) = self.parse_name(input, 0)?; let (input, record_type) = map_res(be_u16, RecordType::try_from)(input)?; let (input, class) = map_res(be_u16, Class::try_from)(input)?; // RFC defines the max TTL as "positive values of a signed 32 bit number." let max_ttl: isize = i32::MAX.try_into().unwrap(); let (input, ttl) = map_res(be_u32, |ttl| { if (ttl as isize) > max_ttl { Err(format!("TTL {ttl} is too large")) } else { Ok(ttl) } })(input)?; let (i, data) = length_value(be_u16, self.parse_rdata(record_type))(input)?; Ok(( i, Record { name, class, ttl, data, }, )) } fn parse_message<'i>(&self, i: &'i [u8]) -> IResult<&'i [u8], Message, Error<&'i [u8]>> { // The Header parser requires parsing individual bits, because the RFC stores some boolean // flags as single bits, and some numbers as 4-bit numbers. // So, first convert the input from bytestream to bitstream, then run the Header parser, // then convert the bitstream back to a bystream for the following steps. let (i, header) = nom::bits::bits(Header::deserialize)(i)?; // Parse the right number of question sections. let (i, question) = count(question::Entry::deserialize, header.question_count.into())(i)?; // After the question comes the DNS records themselves. Parse the right number of each kind! let (i, answer) = count(|i| self.parse_record(i), header.answer_count.into())(i)?; let (i, authority) = count(|i| self.parse_record(i), header.name_server_count.into())(i)?; let (i, additional) = count( |i| self.parse_record(i), header.additional_records_count.into(), )(i)?; Ok(( i, Message { header, question, answer, authority, additional, }, )) } } #[cfg(test)] mod tests { use std::net::Ipv4Addr; use crate::{ dns_types::Class, message::record::{Record, RecordData}, }; use super::*; #[test] fn test_msg_with_soa_records() { let response_msg = vec![ 190, 125, 129, 128, 0, 1, 0, 0, 0, 1, 0, 0, 11, 115, 101, 114, 105, 111, 117, 115, 101, 97, 116, 115, 3, 99, 111, 109, 0, 0, 5, 0, 1, 192, 12, 0, 6, 0, 1, 0, 0, 1, 44, 0, 53, 4, 100, 110, 115, 49, 3, 112, 48, 51, 5, 110, 115, 111, 110, 101, 3, 110, 101, 116, 0, 10, 104, 111, 115, 116, 109, 97, 115, 116, 101, 114, 192, 54, 97, 247, 178, 208, 0, 0, 168, 192, 0, 0, 2, 88, 0, 9, 58, 128, 0, 0, 1, 44, ]; let msg = Message::deserialize(response_msg).unwrap(); assert_eq!(msg.header.name_server_count, 1); assert_eq!(msg.authority.len(), 1); } #[test] fn test_parse_msg() { let response_msg = vec![ 0, 33, 129, 128, 0, 1, 0, 2, 0, 0, 0, 0, // Header (12 bytes) 4, 98, 108, 111, 103, // blog 12, 97, 100, 97, 109, 99, 104, 97, 108, 109, 101, 114, 115, // adamchalmers 3, 99, 111, 109, // com 0, // . 0, 1, 0, 1, // class, type 192, 12, // Answer #1: name, which is a pointer to byte 12. 0, 1, 0, 1, // class, type 0, 0, 0, 179, // TTL (u32) 0, 4, // rdata length 104, 19, 237, 120, // rdata, an IPv4 192, 12, // Answer #1: name, which is a pointer to byte 12. 0, 1, 0, 1, // class, type 0, 0, 0, 179, // TTL (u32) 0, 4, // rdata length 104, 19, 238, 120, // IPv4 ]; // Try to parse it let actual_msg = Message::deserialize(response_msg).unwrap(); // Was it correct? let name = String::from("blog.adamchalmers.com."); let expected_answers = vec![ Record { name: name.clone(), class: Class::IN, ttl: 179, data: RecordData::A(Ipv4Addr::new(104, 19, 237, 120)), }, Record { name, class: Class::IN, ttl: 179, data: RecordData::A(Ipv4Addr::new(104, 19, 238, 120)), }, ]; let actual_answers = actual_msg.answer; assert_eq!(actual_answers, expected_answers) } }