Update examples with modernized code

This commit is contained in:
Emil Hernvall
2020-06-18 01:47:09 +02:00
parent 31369696d9
commit f815075ae4
10 changed files with 708 additions and 627 deletions

View File

@@ -11,6 +11,8 @@ pub struct BytePacketBuffer {
}
impl BytePacketBuffer {
/// This gives us a fresh buffer for holding the packet contents, and a
/// field for keeping track of where we are.
pub fn new() -> BytePacketBuffer {
BytePacketBuffer {
buf: [0; 512],
@@ -18,22 +20,26 @@ impl BytePacketBuffer {
}
}
/// Current position within buffer
fn pos(&self) -> usize {
self.pos
}
/// Step the buffer position forward a specific number of steps
fn step(&mut self, steps: usize) -> Result<()> {
self.pos += steps;
Ok(())
}
/// Change the buffer position
fn seek(&mut self, pos: usize) -> Result<()> {
self.pos = pos;
Ok(())
}
/// Read a single byte and move the position one step forward
fn read(&mut self) -> Result<u8> {
if self.pos >= 512 {
return Err("End of buffer".into());
@@ -44,6 +50,7 @@ impl BytePacketBuffer {
Ok(res)
}
/// Get a single byte, without changing the buffer position
fn get(&mut self, pos: usize) -> Result<u8> {
if pos >= 512 {
return Err("End of buffer".into());
@@ -51,6 +58,7 @@ impl BytePacketBuffer {
Ok(self.buf[pos])
}
/// Get a range of bytes
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
if start + len >= 512 {
return Err("End of buffer".into());
@@ -58,12 +66,14 @@ impl BytePacketBuffer {
Ok(&self.buf[start..start + len as usize])
}
/// Read two bytes, stepping two steps forward
fn read_u16(&mut self) -> Result<u16> {
let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
Ok(res)
}
/// Read four bytes, stepping four steps forward
fn read_u32(&mut self) -> Result<u32> {
let res = ((self.read()? as u32) << 24)
| ((self.read()? as u32) << 16)
@@ -73,13 +83,28 @@ impl BytePacketBuffer {
Ok(res)
}
/// Read a qname
///
/// The tricky part: Reading domain names, taking labels into consideration.
/// Will take something like [3]www[6]google[3]com[0] and append
/// www.google.com to outstr.
fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
// Since we might encounter jumps, we'll keep track of our position
// locally as opposed to using the position within the struct. This
// allows us to move the shared position to a point past our current
// qname, while keeping track of our progress on the current qname
// using this variable.
let mut pos = self.pos();
let mut jumped = false;
let mut delim = "";
// track whether or not we've jumped
let mut jumped = false;
let max_jumps = 5;
let mut jumps_performed = 0;
// Our delimiter which we append for each label. Since we don't want a
// dot at the beginning of the domain name we'll leave it empty for now
// and set it to "." at the end of the first iteration.
let mut delim = "";
loop {
// Dns Packets are untrusted data, so we need to be paranoid. Someone
// can craft a packet with a cycle in the jump instructions. This guards
@@ -88,42 +113,56 @@ impl BytePacketBuffer {
return Err(format!("Limit of {} jumps exceeded", max_jumps).into());
}
// At this point, we're always at the beginning of a label. Recall
// that labels start with a length byte.
let len = self.get(pos)?;
// A two byte sequence, where the two highest bits of the first byte is
// set, represents a offset relative to the start of the buffer. We
// handle this by jumping to the offset, setting a flag to indicate
// that we shouldn't update the shared buffer position once done.
// If len has the two most significant bit are set, it represents a
// jump to some other offset in the packet:
if (len & 0xC0) == 0xC0 {
// When a jump is performed, we only modify the shared buffer
// position once, and avoid making the change later on.
// Update the buffer position to a point past the current
// label. We don't need to touch it any further.
if !jumped {
self.seek(pos + 2)?;
}
// Read another byte, calculate offset and perform the jump by
// updating our local position variable
let b2 = self.get(pos + 1)? as u16;
let offset = (((len as u16) ^ 0xC0) << 8) | b2;
pos = offset as usize;
// Indicate that a jump was performed.
jumped = true;
jumps_performed += 1;
continue;
}
// The base scenario, where we're reading a single label and
// appending it to the output:
else {
// Move a single byte forward to move past the length byte.
pos += 1;
pos += 1;
// Domain names are terminated by an empty label of length 0,
// so if the length is zero we're done.
if len == 0 {
break;
}
// Names are terminated by an empty label of length 0
if len == 0 {
break;
// Append the delimiter to our output buffer first.
outstr.push_str(delim);
// Extract the actual ASCII bytes for this label and append them
// to the output buffer.
let str_buffer = self.get_range(pos, len as usize)?;
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
delim = ".";
// Move forward the full length of the label.
pos += len as usize;
}
outstr.push_str(delim);
let str_buffer = self.get_range(pos, len as usize)?;
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
delim = ".";
pos += len as usize;
}
if !jumped {
@@ -386,19 +425,19 @@ fn main() -> Result<()> {
f.read(&mut buffer.buf)?;
let packet = DnsPacket::from_buffer(&mut buffer)?;
println!("{:?}", packet.header);
println!("{:#?}", packet.header);
for q in packet.questions {
println!("{:?}", q);
println!("{:#?}", q);
}
for rec in packet.answers {
println!("{:?}", rec);
println!("{:#?}", rec);
}
for rec in packet.authorities {
println!("{:?}", rec);
println!("{:#?}", rec);
}
for rec in packet.resources {
println!("{:?}", rec);
println!("{:#?}", rec);
}
Ok(())

View File

@@ -164,9 +164,7 @@ impl BytePacketBuffer {
}
fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>();
for label in split_str {
for label in qname.split('.') {
let len = label.len();
if len > 0x34 {
return Err("Single label exceeds 63 characters of length".into());
@@ -521,12 +519,18 @@ impl DnsPacket {
}
fn main() -> Result<()> {
let qname = "www.yahoo.com";
// Perform an A query for google.com
let qname = "google.com";
let qtype = QueryType::A;
// Using googles public DNS server
let server = ("8.8.8.8", 53);
// Bind a UDP socket to an arbitrary port
let socket = UdpSocket::bind(("0.0.0.0", 43210))?;
// Build our query packet. It's important that we remember to set the
// `recursion_desired` flag. As noted earlier, the packet id is arbitrary.
let mut packet = DnsPacket::new();
packet.header.id = 6666;
@@ -536,27 +540,34 @@ fn main() -> Result<()> {
.questions
.push(DnsQuestion::new(qname.to_string(), qtype));
// Use our new write method to write the packet to a buffer...
let mut req_buffer = BytePacketBuffer::new();
packet.write(&mut req_buffer)?;
// ...and send it off to the server using our socket:
socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?;
// To prepare for receiving the response, we'll create a new `BytePacketBuffer`,
// and ask the socket to write the response directly into our buffer.
let mut res_buffer = BytePacketBuffer::new();
socket.recv_from(&mut res_buffer.buf)?;
// As per the previous section, `DnsPacket::from_buffer()` is then used to
// actually parse the packet after which we can print the response.
let res_packet = DnsPacket::from_buffer(&mut res_buffer)?;
println!("{:?}", res_packet.header);
println!("{:#?}", res_packet.header);
for q in res_packet.questions {
println!("{:?}", q);
println!("{:#?}", q);
}
for rec in res_packet.answers {
println!("{:?}", rec);
println!("{:#?}", rec);
}
for rec in res_packet.authorities {
println!("{:?}", rec);
println!("{:#?}", rec);
}
for rec in res_packet.resources {
println!("{:?}", rec);
println!("{:#?}", rec);
}
Ok(())

View File

@@ -164,9 +164,7 @@ impl BytePacketBuffer {
}
fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>();
for label in split_str {
for label in qname.split('.') {
let len = label.len();
if len > 0x34 {
return Err("Single label exceeds 63 characters of length".into());

View File

@@ -164,9 +164,7 @@ impl BytePacketBuffer {
}
fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>();
for label in split_str {
for label in qname.split('.') {
let len = label.len();
if len > 0x34 {
return Err("Single label exceeds 63 characters of length".into());
@@ -714,11 +712,17 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacke
}
fn main() -> Result<()> {
// Forward queries to Google's public DNS
let server = ("8.8.8.8", 53);
// Bind an UDP socket on port 2053
let socket = UdpSocket::bind(("0.0.0.0", 2053))?;
// For now, queries are handled sequentially, so an infinite loop for servicing
// requests is initiated.
loop {
// With a socket ready, we can go ahead and read a packet. This will
// block until one is received.
let mut req_buffer = BytePacketBuffer::new();
let (_, src) = match socket.recv_from(&mut req_buffer.buf) {
Ok(x) => x,
@@ -728,6 +732,16 @@ fn main() -> Result<()> {
}
};
// Here we use match to safely unwrap the `Result`. If everything's as expected,
// the raw bytes are simply returned, and if not it'll abort by restarting the
// loop and waiting for the next request. The `recv_from` function will write the
// data into the provided buffer, and return the length of the data read as well
// as the source address. We're not interested in the length, but we need to keep
// track of the source in order to send our reply later on.
// Next, `DnsPacket::from_buffer` is used to parse the raw bytes into
// a `DnsPacket`. It uses the same error handling idiom as the previous statement.
let request = match DnsPacket::from_buffer(&mut req_buffer) {
Ok(x) => x,
Err(e) => {
@@ -736,18 +750,29 @@ fn main() -> Result<()> {
}
};
// Create and initialize the response packet
let mut packet = DnsPacket::new();
packet.header.id = request.header.id;
packet.header.recursion_desired = true;
packet.header.recursion_available = true;
packet.header.response = true;
// Being mindful of how unreliable input data from arbitrary senders can be, we
// need make sure that a question is actually present. If not, we return `FORMERR`
// to indicate that the sender made something wrong.
if request.questions.is_empty() {
packet.header.rescode = ResultCode::FORMERR;
} else {
}
// Usually a question will be present, though.
else {
let question = &request.questions[0];
println!("Received query: {:?}", question);
// Since all is set up and as expected, the query can be forwarded to the target
// server. There's always the possibility that the query will fail, in which case
// the `SERVFAIL` response code is set to indicate as much to the client. If
// rather everything goes as planned, the question and response records as copied
// into our response packet.
if let Ok(result) = lookup(&question.name, question.qtype, server) {
packet.questions.push(question.clone());
packet.header.rescode = result.header.rescode;
@@ -769,6 +794,7 @@ fn main() -> Result<()> {
}
}
// The only thing remaining is to encode our response and send it off!
let mut res_buffer = BytePacketBuffer::new();
match packet.write(&mut res_buffer) {
Ok(_) => {}

View File

@@ -164,9 +164,7 @@ impl BytePacketBuffer {
}
fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>();
for label in split_str {
for label in qname.split('.') {
let len = label.len();
if len > 0x34 {
return Err("Single label exceeds 63 characters of length".into());
@@ -690,84 +688,69 @@ impl DnsPacket {
Ok(())
}
/// It's useful to be able to pick a random A record from a packet. When we
/// get multiple IP's for a single name, it doesn't matter which one we
/// choose, so in those cases we can now pick one at random.
pub fn get_random_a(&self) -> Option<String> {
if !self.answers.is_empty() {
let a_record = &self.answers[0];
if let DnsRecord::A { ref addr, .. } = *a_record {
return Some(addr.to_string());
}
}
None
self.answers
.iter()
.filter_map(|record| match record {
DnsRecord::A { ref addr, .. } => Some(addr.to_string()),
_ => None,
})
.next()
}
/// A helper function which returns an iterator over all name servers in
/// the authorities section, represented as (domain, host) tuples
fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
self.authorities
.iter()
// In practice, these are always NS records in well formed packages.
// Convert the NS records to a tuple which has only the data we need
// to make it easy to work with.
.filter_map(|record| match record {
DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
_ => None,
})
// Discard servers which aren't authoritative to our query
.filter(move |(domain, _)| qname.ends_with(*domain))
}
/// We'll use the fact that name servers often bundle the corresponding
/// A records when replying to an NS query to implement a function that
/// returns the actual IP for an NS record if possible.
pub fn get_resolved_ns(&self, qname: &str) -> Option<String> {
let mut new_authorities = Vec::new();
for auth in &self.authorities {
if let DnsRecord::NS {
ref domain,
ref host,
..
} = *auth
{
if !qname.ends_with(domain) {
continue;
}
for rsrc in &self.resources {
if let DnsRecord::A {
ref domain,
ref addr,
ttl,
} = *rsrc
{
if domain != host {
continue;
}
let rec = DnsRecord::A {
domain: host.clone(),
addr: *addr,
ttl: ttl,
};
new_authorities.push(rec);
}
}
}
}
if !new_authorities.is_empty() {
if let DnsRecord::A { addr, .. } = new_authorities[0] {
return Some(addr.to_string());
}
}
None
// Get an iterator over the nameservers in the authorities section
self.get_ns(qname)
// Now we need to look for a matching A record in the additional
// section. Since we just want the first valid record, we can just
// build a stream of matching records.
.flat_map(|(_, host)| {
self.resources
.iter()
// Filter for A records where the domain match the host
// of the NS record that we are currently processing
.filter_map(move |record| match record {
DnsRecord::A { domain, addr, .. } if domain == host => Some(addr),
_ => None,
})
})
.map(|addr| addr.to_string())
// Finally, pick the first valid entry
.next()
}
/// However, not all name servers are as that nice. In certain cases there won't
/// be any A records in the additional section, and we'll have to perform *another*
/// lookup in the midst. For this, we introduce a method for returning the host
/// name of an appropriate name server.
pub fn get_unresolved_ns(&self, qname: &str) -> Option<String> {
let mut new_authorities = Vec::new();
for auth in &self.authorities {
if let DnsRecord::NS {
ref domain,
ref host,
..
} = *auth
{
if !qname.ends_with(domain) {
continue;
}
new_authorities.push(host);
}
}
if !new_authorities.is_empty() {
return Some(new_authorities[0].clone());
}
None
// Get an iterator over the nameservers in the authorities section
self.get_ns(qname)
.map(|(_, host)| host.to_string())
// Finally, pick the first valid entry
.next()
}
}
@@ -794,45 +777,53 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacke
}
fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// For now we're always starting with *a.root-servers.net*.
let mut ns = "198.41.0.4".to_string();
// Start querying name servers
// Since it might take an arbitrary number of steps, we enter an unbounded loop.
loop {
println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns);
// The next step is to send the query to the active server.
let ns_copy = ns.clone();
let server = (ns_copy.as_str(), 53);
let response = lookup(qname, qtype.clone(), server)?;
// If we've got an actual answer, we're done!
// If there are entries in the answer section, and no errors, we are done!
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
return Ok(response.clone());
}
// We might also get a `NXDOMAIN` reply, which is the authoritative name servers
// way of telling us that the name doesn't exist.
if response.header.rescode == ResultCode::NXDOMAIN {
return Ok(response.clone());
}
// Otherwise, try to find a new nameserver based on NS and a
// corresponding A record in the additional section
// Otherwise, we'll try to find a new nameserver based on NS and a corresponding A
// record in the additional section. If this succeeds, we can switch name server
// and retry the loop.
if let Some(new_ns) = response.get_resolved_ns(qname) {
// If there is such a record, we can retry the loop with that NS
ns = new_ns.clone();
continue;
}
// If not, we'll have to resolve the ip of a NS record
// If not, we'll have to resolve the ip of a NS record. If no NS records exist,
// we'll go with what the last server told us.
let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x,
None => return Ok(response.clone()),
};
// Recursively resolve the NS
// Here we go down the rabbit hole by starting _another_ lookup sequence in the
// midst of our current one. Hopefully, this will give us the IP of an appropriate
// name server.
let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?;
// Pick a random IP and restart
// Finally, we pick a random ip from the result, and restart the loop. If no such
// record is available, we again return the last result we got.
if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns.clone();
} else {