Update examples with modernized code
This commit is contained in:
@@ -11,6 +11,8 @@ pub struct BytePacketBuffer {
|
||||
}
|
||||
|
||||
impl BytePacketBuffer {
|
||||
/// This gives us a fresh buffer for holding the packet contents, and a
|
||||
/// field for keeping track of where we are.
|
||||
pub fn new() -> BytePacketBuffer {
|
||||
BytePacketBuffer {
|
||||
buf: [0; 512],
|
||||
@@ -18,22 +20,26 @@ impl BytePacketBuffer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Current position within buffer
|
||||
fn pos(&self) -> usize {
|
||||
self.pos
|
||||
}
|
||||
|
||||
/// Step the buffer position forward a specific number of steps
|
||||
fn step(&mut self, steps: usize) -> Result<()> {
|
||||
self.pos += steps;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Change the buffer position
|
||||
fn seek(&mut self, pos: usize) -> Result<()> {
|
||||
self.pos = pos;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read a single byte and move the position one step forward
|
||||
fn read(&mut self) -> Result<u8> {
|
||||
if self.pos >= 512 {
|
||||
return Err("End of buffer".into());
|
||||
@@ -44,6 +50,7 @@ impl BytePacketBuffer {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Get a single byte, without changing the buffer position
|
||||
fn get(&mut self, pos: usize) -> Result<u8> {
|
||||
if pos >= 512 {
|
||||
return Err("End of buffer".into());
|
||||
@@ -51,6 +58,7 @@ impl BytePacketBuffer {
|
||||
Ok(self.buf[pos])
|
||||
}
|
||||
|
||||
/// Get a range of bytes
|
||||
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
|
||||
if start + len >= 512 {
|
||||
return Err("End of buffer".into());
|
||||
@@ -58,12 +66,14 @@ impl BytePacketBuffer {
|
||||
Ok(&self.buf[start..start + len as usize])
|
||||
}
|
||||
|
||||
/// Read two bytes, stepping two steps forward
|
||||
fn read_u16(&mut self) -> Result<u16> {
|
||||
let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Read four bytes, stepping four steps forward
|
||||
fn read_u32(&mut self) -> Result<u32> {
|
||||
let res = ((self.read()? as u32) << 24)
|
||||
| ((self.read()? as u32) << 16)
|
||||
@@ -73,13 +83,28 @@ impl BytePacketBuffer {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Read a qname
|
||||
///
|
||||
/// The tricky part: Reading domain names, taking labels into consideration.
|
||||
/// Will take something like [3]www[6]google[3]com[0] and append
|
||||
/// www.google.com to outstr.
|
||||
fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
|
||||
// Since we might encounter jumps, we'll keep track of our position
|
||||
// locally as opposed to using the position within the struct. This
|
||||
// allows us to move the shared position to a point past our current
|
||||
// qname, while keeping track of our progress on the current qname
|
||||
// using this variable.
|
||||
let mut pos = self.pos();
|
||||
let mut jumped = false;
|
||||
|
||||
let mut delim = "";
|
||||
// track whether or not we've jumped
|
||||
let mut jumped = false;
|
||||
let max_jumps = 5;
|
||||
let mut jumps_performed = 0;
|
||||
|
||||
// Our delimiter which we append for each label. Since we don't want a
|
||||
// dot at the beginning of the domain name we'll leave it empty for now
|
||||
// and set it to "." at the end of the first iteration.
|
||||
let mut delim = "";
|
||||
loop {
|
||||
// Dns Packets are untrusted data, so we need to be paranoid. Someone
|
||||
// can craft a packet with a cycle in the jump instructions. This guards
|
||||
@@ -88,42 +113,56 @@ impl BytePacketBuffer {
|
||||
return Err(format!("Limit of {} jumps exceeded", max_jumps).into());
|
||||
}
|
||||
|
||||
// At this point, we're always at the beginning of a label. Recall
|
||||
// that labels start with a length byte.
|
||||
let len = self.get(pos)?;
|
||||
|
||||
// A two byte sequence, where the two highest bits of the first byte is
|
||||
// set, represents a offset relative to the start of the buffer. We
|
||||
// handle this by jumping to the offset, setting a flag to indicate
|
||||
// that we shouldn't update the shared buffer position once done.
|
||||
// If len has the two most significant bit are set, it represents a
|
||||
// jump to some other offset in the packet:
|
||||
if (len & 0xC0) == 0xC0 {
|
||||
// When a jump is performed, we only modify the shared buffer
|
||||
// position once, and avoid making the change later on.
|
||||
// Update the buffer position to a point past the current
|
||||
// label. We don't need to touch it any further.
|
||||
if !jumped {
|
||||
self.seek(pos + 2)?;
|
||||
}
|
||||
|
||||
// Read another byte, calculate offset and perform the jump by
|
||||
// updating our local position variable
|
||||
let b2 = self.get(pos + 1)? as u16;
|
||||
let offset = (((len as u16) ^ 0xC0) << 8) | b2;
|
||||
pos = offset as usize;
|
||||
|
||||
// Indicate that a jump was performed.
|
||||
jumped = true;
|
||||
jumps_performed += 1;
|
||||
|
||||
continue;
|
||||
}
|
||||
// The base scenario, where we're reading a single label and
|
||||
// appending it to the output:
|
||||
else {
|
||||
// Move a single byte forward to move past the length byte.
|
||||
pos += 1;
|
||||
|
||||
pos += 1;
|
||||
// Domain names are terminated by an empty label of length 0,
|
||||
// so if the length is zero we're done.
|
||||
if len == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
// Names are terminated by an empty label of length 0
|
||||
if len == 0 {
|
||||
break;
|
||||
// Append the delimiter to our output buffer first.
|
||||
outstr.push_str(delim);
|
||||
|
||||
// Extract the actual ASCII bytes for this label and append them
|
||||
// to the output buffer.
|
||||
let str_buffer = self.get_range(pos, len as usize)?;
|
||||
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
|
||||
|
||||
delim = ".";
|
||||
|
||||
// Move forward the full length of the label.
|
||||
pos += len as usize;
|
||||
}
|
||||
|
||||
outstr.push_str(delim);
|
||||
|
||||
let str_buffer = self.get_range(pos, len as usize)?;
|
||||
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
|
||||
|
||||
delim = ".";
|
||||
|
||||
pos += len as usize;
|
||||
}
|
||||
|
||||
if !jumped {
|
||||
@@ -386,19 +425,19 @@ fn main() -> Result<()> {
|
||||
f.read(&mut buffer.buf)?;
|
||||
|
||||
let packet = DnsPacket::from_buffer(&mut buffer)?;
|
||||
println!("{:?}", packet.header);
|
||||
println!("{:#?}", packet.header);
|
||||
|
||||
for q in packet.questions {
|
||||
println!("{:?}", q);
|
||||
println!("{:#?}", q);
|
||||
}
|
||||
for rec in packet.answers {
|
||||
println!("{:?}", rec);
|
||||
println!("{:#?}", rec);
|
||||
}
|
||||
for rec in packet.authorities {
|
||||
println!("{:?}", rec);
|
||||
println!("{:#?}", rec);
|
||||
}
|
||||
for rec in packet.resources {
|
||||
println!("{:?}", rec);
|
||||
println!("{:#?}", rec);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -164,9 +164,7 @@ impl BytePacketBuffer {
|
||||
}
|
||||
|
||||
fn write_qname(&mut self, qname: &str) -> Result<()> {
|
||||
let split_str = qname.split('.').collect::<Vec<&str>>();
|
||||
|
||||
for label in split_str {
|
||||
for label in qname.split('.') {
|
||||
let len = label.len();
|
||||
if len > 0x34 {
|
||||
return Err("Single label exceeds 63 characters of length".into());
|
||||
@@ -521,12 +519,18 @@ impl DnsPacket {
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let qname = "www.yahoo.com";
|
||||
// Perform an A query for google.com
|
||||
let qname = "google.com";
|
||||
let qtype = QueryType::A;
|
||||
|
||||
// Using googles public DNS server
|
||||
let server = ("8.8.8.8", 53);
|
||||
|
||||
// Bind a UDP socket to an arbitrary port
|
||||
let socket = UdpSocket::bind(("0.0.0.0", 43210))?;
|
||||
|
||||
// Build our query packet. It's important that we remember to set the
|
||||
// `recursion_desired` flag. As noted earlier, the packet id is arbitrary.
|
||||
let mut packet = DnsPacket::new();
|
||||
|
||||
packet.header.id = 6666;
|
||||
@@ -536,27 +540,34 @@ fn main() -> Result<()> {
|
||||
.questions
|
||||
.push(DnsQuestion::new(qname.to_string(), qtype));
|
||||
|
||||
// Use our new write method to write the packet to a buffer...
|
||||
let mut req_buffer = BytePacketBuffer::new();
|
||||
packet.write(&mut req_buffer)?;
|
||||
|
||||
// ...and send it off to the server using our socket:
|
||||
socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?;
|
||||
|
||||
// To prepare for receiving the response, we'll create a new `BytePacketBuffer`,
|
||||
// and ask the socket to write the response directly into our buffer.
|
||||
let mut res_buffer = BytePacketBuffer::new();
|
||||
socket.recv_from(&mut res_buffer.buf)?;
|
||||
|
||||
// As per the previous section, `DnsPacket::from_buffer()` is then used to
|
||||
// actually parse the packet after which we can print the response.
|
||||
let res_packet = DnsPacket::from_buffer(&mut res_buffer)?;
|
||||
println!("{:?}", res_packet.header);
|
||||
println!("{:#?}", res_packet.header);
|
||||
|
||||
for q in res_packet.questions {
|
||||
println!("{:?}", q);
|
||||
println!("{:#?}", q);
|
||||
}
|
||||
for rec in res_packet.answers {
|
||||
println!("{:?}", rec);
|
||||
println!("{:#?}", rec);
|
||||
}
|
||||
for rec in res_packet.authorities {
|
||||
println!("{:?}", rec);
|
||||
println!("{:#?}", rec);
|
||||
}
|
||||
for rec in res_packet.resources {
|
||||
println!("{:?}", rec);
|
||||
println!("{:#?}", rec);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -164,9 +164,7 @@ impl BytePacketBuffer {
|
||||
}
|
||||
|
||||
fn write_qname(&mut self, qname: &str) -> Result<()> {
|
||||
let split_str = qname.split('.').collect::<Vec<&str>>();
|
||||
|
||||
for label in split_str {
|
||||
for label in qname.split('.') {
|
||||
let len = label.len();
|
||||
if len > 0x34 {
|
||||
return Err("Single label exceeds 63 characters of length".into());
|
||||
|
||||
@@ -164,9 +164,7 @@ impl BytePacketBuffer {
|
||||
}
|
||||
|
||||
fn write_qname(&mut self, qname: &str) -> Result<()> {
|
||||
let split_str = qname.split('.').collect::<Vec<&str>>();
|
||||
|
||||
for label in split_str {
|
||||
for label in qname.split('.') {
|
||||
let len = label.len();
|
||||
if len > 0x34 {
|
||||
return Err("Single label exceeds 63 characters of length".into());
|
||||
@@ -714,11 +712,17 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacke
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Forward queries to Google's public DNS
|
||||
let server = ("8.8.8.8", 53);
|
||||
|
||||
// Bind an UDP socket on port 2053
|
||||
let socket = UdpSocket::bind(("0.0.0.0", 2053))?;
|
||||
|
||||
// For now, queries are handled sequentially, so an infinite loop for servicing
|
||||
// requests is initiated.
|
||||
loop {
|
||||
// With a socket ready, we can go ahead and read a packet. This will
|
||||
// block until one is received.
|
||||
let mut req_buffer = BytePacketBuffer::new();
|
||||
let (_, src) = match socket.recv_from(&mut req_buffer.buf) {
|
||||
Ok(x) => x,
|
||||
@@ -728,6 +732,16 @@ fn main() -> Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
// Here we use match to safely unwrap the `Result`. If everything's as expected,
|
||||
// the raw bytes are simply returned, and if not it'll abort by restarting the
|
||||
// loop and waiting for the next request. The `recv_from` function will write the
|
||||
// data into the provided buffer, and return the length of the data read as well
|
||||
// as the source address. We're not interested in the length, but we need to keep
|
||||
// track of the source in order to send our reply later on.
|
||||
|
||||
// Next, `DnsPacket::from_buffer` is used to parse the raw bytes into
|
||||
// a `DnsPacket`. It uses the same error handling idiom as the previous statement.
|
||||
|
||||
let request = match DnsPacket::from_buffer(&mut req_buffer) {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
@@ -736,18 +750,29 @@ fn main() -> Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
// Create and initialize the response packet
|
||||
let mut packet = DnsPacket::new();
|
||||
packet.header.id = request.header.id;
|
||||
packet.header.recursion_desired = true;
|
||||
packet.header.recursion_available = true;
|
||||
packet.header.response = true;
|
||||
|
||||
// Being mindful of how unreliable input data from arbitrary senders can be, we
|
||||
// need make sure that a question is actually present. If not, we return `FORMERR`
|
||||
// to indicate that the sender made something wrong.
|
||||
if request.questions.is_empty() {
|
||||
packet.header.rescode = ResultCode::FORMERR;
|
||||
} else {
|
||||
}
|
||||
// Usually a question will be present, though.
|
||||
else {
|
||||
let question = &request.questions[0];
|
||||
println!("Received query: {:?}", question);
|
||||
|
||||
// Since all is set up and as expected, the query can be forwarded to the target
|
||||
// server. There's always the possibility that the query will fail, in which case
|
||||
// the `SERVFAIL` response code is set to indicate as much to the client. If
|
||||
// rather everything goes as planned, the question and response records as copied
|
||||
// into our response packet.
|
||||
if let Ok(result) = lookup(&question.name, question.qtype, server) {
|
||||
packet.questions.push(question.clone());
|
||||
packet.header.rescode = result.header.rescode;
|
||||
@@ -769,6 +794,7 @@ fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// The only thing remaining is to encode our response and send it off!
|
||||
let mut res_buffer = BytePacketBuffer::new();
|
||||
match packet.write(&mut res_buffer) {
|
||||
Ok(_) => {}
|
||||
|
||||
@@ -164,9 +164,7 @@ impl BytePacketBuffer {
|
||||
}
|
||||
|
||||
fn write_qname(&mut self, qname: &str) -> Result<()> {
|
||||
let split_str = qname.split('.').collect::<Vec<&str>>();
|
||||
|
||||
for label in split_str {
|
||||
for label in qname.split('.') {
|
||||
let len = label.len();
|
||||
if len > 0x34 {
|
||||
return Err("Single label exceeds 63 characters of length".into());
|
||||
@@ -690,84 +688,69 @@ impl DnsPacket {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// It's useful to be able to pick a random A record from a packet. When we
|
||||
/// get multiple IP's for a single name, it doesn't matter which one we
|
||||
/// choose, so in those cases we can now pick one at random.
|
||||
pub fn get_random_a(&self) -> Option<String> {
|
||||
if !self.answers.is_empty() {
|
||||
let a_record = &self.answers[0];
|
||||
if let DnsRecord::A { ref addr, .. } = *a_record {
|
||||
return Some(addr.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
self.answers
|
||||
.iter()
|
||||
.filter_map(|record| match record {
|
||||
DnsRecord::A { ref addr, .. } => Some(addr.to_string()),
|
||||
_ => None,
|
||||
})
|
||||
.next()
|
||||
}
|
||||
|
||||
/// A helper function which returns an iterator over all name servers in
|
||||
/// the authorities section, represented as (domain, host) tuples
|
||||
fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
|
||||
self.authorities
|
||||
.iter()
|
||||
// In practice, these are always NS records in well formed packages.
|
||||
// Convert the NS records to a tuple which has only the data we need
|
||||
// to make it easy to work with.
|
||||
.filter_map(|record| match record {
|
||||
DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
|
||||
_ => None,
|
||||
})
|
||||
// Discard servers which aren't authoritative to our query
|
||||
.filter(move |(domain, _)| qname.ends_with(*domain))
|
||||
}
|
||||
|
||||
/// We'll use the fact that name servers often bundle the corresponding
|
||||
/// A records when replying to an NS query to implement a function that
|
||||
/// returns the actual IP for an NS record if possible.
|
||||
pub fn get_resolved_ns(&self, qname: &str) -> Option<String> {
|
||||
let mut new_authorities = Vec::new();
|
||||
for auth in &self.authorities {
|
||||
if let DnsRecord::NS {
|
||||
ref domain,
|
||||
ref host,
|
||||
..
|
||||
} = *auth
|
||||
{
|
||||
if !qname.ends_with(domain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for rsrc in &self.resources {
|
||||
if let DnsRecord::A {
|
||||
ref domain,
|
||||
ref addr,
|
||||
ttl,
|
||||
} = *rsrc
|
||||
{
|
||||
if domain != host {
|
||||
continue;
|
||||
}
|
||||
|
||||
let rec = DnsRecord::A {
|
||||
domain: host.clone(),
|
||||
addr: *addr,
|
||||
ttl: ttl,
|
||||
};
|
||||
|
||||
new_authorities.push(rec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !new_authorities.is_empty() {
|
||||
if let DnsRecord::A { addr, .. } = new_authorities[0] {
|
||||
return Some(addr.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
// Get an iterator over the nameservers in the authorities section
|
||||
self.get_ns(qname)
|
||||
// Now we need to look for a matching A record in the additional
|
||||
// section. Since we just want the first valid record, we can just
|
||||
// build a stream of matching records.
|
||||
.flat_map(|(_, host)| {
|
||||
self.resources
|
||||
.iter()
|
||||
// Filter for A records where the domain match the host
|
||||
// of the NS record that we are currently processing
|
||||
.filter_map(move |record| match record {
|
||||
DnsRecord::A { domain, addr, .. } if domain == host => Some(addr),
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.map(|addr| addr.to_string())
|
||||
// Finally, pick the first valid entry
|
||||
.next()
|
||||
}
|
||||
|
||||
/// However, not all name servers are as that nice. In certain cases there won't
|
||||
/// be any A records in the additional section, and we'll have to perform *another*
|
||||
/// lookup in the midst. For this, we introduce a method for returning the host
|
||||
/// name of an appropriate name server.
|
||||
pub fn get_unresolved_ns(&self, qname: &str) -> Option<String> {
|
||||
let mut new_authorities = Vec::new();
|
||||
for auth in &self.authorities {
|
||||
if let DnsRecord::NS {
|
||||
ref domain,
|
||||
ref host,
|
||||
..
|
||||
} = *auth
|
||||
{
|
||||
if !qname.ends_with(domain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
new_authorities.push(host);
|
||||
}
|
||||
}
|
||||
|
||||
if !new_authorities.is_empty() {
|
||||
return Some(new_authorities[0].clone());
|
||||
}
|
||||
|
||||
None
|
||||
// Get an iterator over the nameservers in the authorities section
|
||||
self.get_ns(qname)
|
||||
.map(|(_, host)| host.to_string())
|
||||
// Finally, pick the first valid entry
|
||||
.next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -794,45 +777,53 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacke
|
||||
}
|
||||
|
||||
fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
|
||||
// For now we're always starting with *a.root-servers.net*.
|
||||
let mut ns = "198.41.0.4".to_string();
|
||||
|
||||
// Start querying name servers
|
||||
// Since it might take an arbitrary number of steps, we enter an unbounded loop.
|
||||
loop {
|
||||
println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns);
|
||||
|
||||
// The next step is to send the query to the active server.
|
||||
let ns_copy = ns.clone();
|
||||
|
||||
let server = (ns_copy.as_str(), 53);
|
||||
let response = lookup(qname, qtype.clone(), server)?;
|
||||
|
||||
// If we've got an actual answer, we're done!
|
||||
// If there are entries in the answer section, and no errors, we are done!
|
||||
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
|
||||
return Ok(response.clone());
|
||||
}
|
||||
|
||||
// We might also get a `NXDOMAIN` reply, which is the authoritative name servers
|
||||
// way of telling us that the name doesn't exist.
|
||||
if response.header.rescode == ResultCode::NXDOMAIN {
|
||||
return Ok(response.clone());
|
||||
}
|
||||
|
||||
// Otherwise, try to find a new nameserver based on NS and a
|
||||
// corresponding A record in the additional section
|
||||
// Otherwise, we'll try to find a new nameserver based on NS and a corresponding A
|
||||
// record in the additional section. If this succeeds, we can switch name server
|
||||
// and retry the loop.
|
||||
if let Some(new_ns) = response.get_resolved_ns(qname) {
|
||||
// If there is such a record, we can retry the loop with that NS
|
||||
ns = new_ns.clone();
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// If not, we'll have to resolve the ip of a NS record
|
||||
// If not, we'll have to resolve the ip of a NS record. If no NS records exist,
|
||||
// we'll go with what the last server told us.
|
||||
let new_ns_name = match response.get_unresolved_ns(qname) {
|
||||
Some(x) => x,
|
||||
None => return Ok(response.clone()),
|
||||
};
|
||||
|
||||
// Recursively resolve the NS
|
||||
// Here we go down the rabbit hole by starting _another_ lookup sequence in the
|
||||
// midst of our current one. Hopefully, this will give us the IP of an appropriate
|
||||
// name server.
|
||||
let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?;
|
||||
|
||||
// Pick a random IP and restart
|
||||
// Finally, we pick a random ip from the result, and restart the loop. If no such
|
||||
// record is available, we again return the last result we got.
|
||||
if let Some(new_ns) = recursive_response.get_random_a() {
|
||||
ns = new_ns.clone();
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user