From 92eae7471303e81d2568300dc63e8003a80d7c2e Mon Sep 17 00:00:00 2001 From: keepsimple1 Date: Sat, 10 Aug 2024 14:38:04 -0700 Subject: [PATCH] add support for Known Answer Suppression part 2: multi-packet: querier side (#232) --- src/dns_parser.rs | 275 +++++++++++++++++++++++++++++++++--------- src/service_daemon.rs | 47 ++++---- 2 files changed, 242 insertions(+), 80 deletions(-) diff --git a/src/dns_parser.rs b/src/dns_parser.rs index 1c69ab9..38f04f5 100644 --- a/src/dns_parser.rs +++ b/src/dns_parser.rs @@ -1,8 +1,8 @@ //! DNS parsing utility. //! //! [DnsIncoming] is the logic representation of an incoming DNS packet. -//! [DnsOutgoing] is the logic representation of an outgoing DNS packet. -//! [DnsOutPacket] is the encoded packet for [DnsOutgoing]. +//! [DnsOutgoing] is the logic representation of an outgoing DNS message of one or more packets. +//! [DnsOutPacket] is the encoded one packet for [DnsOutgoing]. #[cfg(feature = "logging")] use crate::log::debug; @@ -37,6 +37,8 @@ pub const CLASS_CACHE_FLUSH: u16 = 0x8000; /// Reference: RFC6762: https://datatracker.ietf.org/doc/html/rfc6762#section-17 pub const MAX_MSG_ABSOLUTE: usize = 8972; +const MSG_HEADER_LEN: usize = 12; + // Definitions for DNS message header "flags" field // // The "flags" field is 16-bit long, in this format: @@ -48,9 +50,23 @@ pub const MAX_MSG_ABSOLUTE: usize = 8972; pub const FLAGS_QR_MASK: u16 = 0x8000; // mask for query/response bit pub const FLAGS_QR_QUERY: u16 = 0x0000; pub const FLAGS_QR_RESPONSE: u16 = 0x8000; -pub const FLAGS_AA: u16 = 0x0400; // mask for Authoritative answer bit -pub type DnsRecordBox = Box; +/// mask for Authoritative answer bit +pub const FLAGS_AA: u16 = 0x0400; + +/// mask for TC(Truncated) bit +/// +/// 2024-08-10: currently this flag is only supported on the querier side, +/// not supported on the responder side. I.e. the responder only +/// handles the first packet and ignore this bit. Since the +/// additional packets have 0 questions, the processing of them +/// is no-op. +/// In practice, this means the responder supports Known-Answer +/// only with single packet, not multi-packet. The querier supports +/// both single packet and multi-packet. +pub const FLAGS_TC: u16 = 0x0200; + +pub(crate) type DnsRecordBox = Box; #[inline] pub const fn ip_address_to_type(address: &IpAddr) -> u16 { @@ -220,7 +236,7 @@ impl PartialEq for DnsRecord { } } -pub trait DnsRecordExt: fmt::Debug { +pub(crate) trait DnsRecordExt: fmt::Debug { fn get_record(&self) -> &DnsRecord; fn get_record_mut(&mut self) -> &mut DnsRecord; fn write(&self, packet: &mut DnsOutPacket); @@ -636,18 +652,26 @@ enum PacketState { Finished = 1, } -pub struct DnsOutPacket { - pub(crate) data: Vec>, +/// A single packet for outgoing DNS message. +pub(crate) struct DnsOutPacket { + /// All bytes in `data` concatenated is the actual packet on the wire. + data: Vec>, + + /// Current logical size of the packet. It starts with the size of the mandatory header. size: usize, + + /// An internal state, not defined by DNS. state: PacketState, - names: HashMap, // k: name, v: offset + + /// k: name, v: offset + names: HashMap, } impl DnsOutPacket { - pub(crate) fn new() -> Self { + fn new() -> Self { Self { data: Vec::new(), - size: 12, + size: MSG_HEADER_LEN, // Header is mandatory. state: PacketState::Init, names: HashMap::new(), } @@ -660,12 +684,9 @@ impl DnsOutPacket { } /// Writes a record (answer, authoritative answer, additional) - /// Returns true if a record is written successfully, otherwise false. + /// Returns false if the packet exceeds the max size with this record, nothing is written to the packet. + /// otherwise returns true. fn write_record(&mut self, record_ext: &dyn DnsRecordExt, now: u64) -> bool { - if self.state == PacketState::Finished { - return false; - } - let start_data_length = self.data.len(); let start_size = self.size; @@ -803,9 +824,53 @@ impl DnsOutPacket { self.data.push(vec![byte]); self.size += 1; } + + /// Writes the header fields and finish the packet. + /// This function should be only called when finishing a packet. + /// + /// The header format is based on RFC 1035 section 4.1.1: + /// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 + // + // 1 1 1 1 1 1 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // | ID | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // | QDCOUNT | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // | ANCOUNT | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // | NSCOUNT | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // | ARCOUNT | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // + fn write_header( + &mut self, + id: u16, + flags: u16, + q_count: u16, + a_count: u16, + auth_count: u16, + addi_count: u16, + ) { + self.insert_short(0, addi_count); + self.insert_short(0, auth_count); + self.insert_short(0, a_count); + self.insert_short(0, q_count); + self.insert_short(0, flags); + self.insert_short(0, id); + + // Adjust the size as it was already initialized to include the header. + self.size -= MSG_HEADER_LEN; + + self.state = PacketState::Finished; + } } -/// Representation of an outgoing packet. The actual encoded packet +/// Representation of one or more outgoing packet(s). The actual encoded packet /// is [DnsOutPacket]. pub(crate) struct DnsOutgoing { flags: u16, @@ -836,7 +901,7 @@ impl DnsOutgoing { (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY } - const fn _is_response(&self) -> bool { + const fn is_response(&self) -> bool { (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE } @@ -1000,45 +1065,84 @@ impl DnsOutgoing { self.questions.push(q); } - pub(crate) fn to_packet_data(&self) -> Vec { + /// Returns a list of actual DNS packet data to be sent on the wire. + pub(crate) fn to_data_on_wire(&self) -> Vec> { + let packet_list = self.to_packets(); + packet_list.iter().map(|p| p.data.concat()).collect() + } + + /// Encode self into one or more packets. + pub(crate) fn to_packets(&self) -> Vec { + let mut packet_list = Vec::new(); let mut packet = DnsOutPacket::new(); - if packet.state != PacketState::Finished { - for question in self.questions.iter() { - packet.write_question(question); - } - let mut answer_count = 0; - for (answer, time) in self.answers.iter() { - if packet.write_record(answer.as_ref(), *time) { - answer_count += 1; - } - } + let mut question_count = self.questions.len() as u16; + let mut answer_count = 0; + let mut auth_count = 0; + let mut addi_count = 0; + let id = if self.multicast { 0 } else { self.id }; - let mut auth_count = 0; - for auth in self.authorities.iter() { - auth_count += u16::from(packet.write_record(auth, 0)); + for question in self.questions.iter() { + packet.write_question(question); + } + + for (answer, time) in self.answers.iter() { + if packet.write_record(answer.as_ref(), *time) { + answer_count += 1; } + } - let mut addi_count = 0; - for addi in self.additionals.iter() { - addi_count += u16::from(packet.write_record(addi.as_ref(), 0)); + for auth in self.authorities.iter() { + auth_count += u16::from(packet.write_record(auth, 0)); + } + + for addi in self.additionals.iter() { + if packet.write_record(addi.as_ref(), 0) { + addi_count += 1; + continue; } - packet.state = PacketState::Finished; - - packet.insert_short(0, addi_count); - packet.insert_short(0, auth_count); - packet.insert_short(0, answer_count); - packet.insert_short(0, self.questions.len() as u16); - packet.insert_short(0, self.flags); - if self.multicast { - packet.insert_short(0, 0); - } else { - packet.insert_short(0, self.id); + // No more processing for response packets. + if self.is_response() { + break; } + + // For query, the current packet exceeds its max size due to known answers, + // need to truncate. + + // finish the current packet first. + packet.write_header( + id, + self.flags | FLAGS_TC, + question_count, + answer_count, + auth_count, + addi_count, + ); + + packet_list.push(packet); + + // create a new packet and reset counts. + packet = DnsOutPacket::new(); + packet.write_record(addi.as_ref(), 0); + + question_count = 0; + answer_count = 0; + auth_count = 0; + addi_count = 1; } - packet.data.concat() + packet.write_header( + id, + self.flags, + question_count, + answer_count, + auth_count, + addi_count, + ); + + packet_list.push(packet); + packet_list } } @@ -1059,8 +1163,6 @@ pub struct DnsIncoming { } impl DnsIncoming { - const HEADER_LEN: usize = 12; - pub(crate) fn new(data: Vec) -> Result { let mut incoming = Self { offset: 0, @@ -1090,7 +1192,7 @@ impl DnsIncoming { } fn read_header(&mut self) -> Result<()> { - if self.data.len() < Self::HEADER_LEN { + if self.data.len() < MSG_HEADER_LEN { return Err(Error::Msg(format!( "DNS incoming: header is too short: {} bytes", self.data.len() @@ -1105,7 +1207,7 @@ impl DnsIncoming { self.num_authorities = u16_from_be_slice(&data[8..10]); self.num_additionals = u16_from_be_slice(&data[10..12]); - self.offset = Self::HEADER_LEN; + self.offset = MSG_HEADER_LEN; debug!( "read_header: id {}, {} questions {} answers {} authorities {} additionals", @@ -1491,12 +1593,10 @@ const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 { #[cfg(test)] mod tests { - use crate::dns_parser::get_expiration_time; - use super::{ - current_time_millis, DnsIncoming, DnsNSec, DnsOutgoing, DnsRecordExt, DnsSrv, - CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_QR_QUERY, FLAGS_QR_RESPONSE, TYPE_A, TYPE_AAAA, - TYPE_PTR, + current_time_millis, get_expiration_time, DnsIncoming, DnsNSec, DnsOutgoing, DnsPointer, + DnsRecordExt, DnsSrv, CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_QR_QUERY, FLAGS_QR_RESPONSE, + MSG_HEADER_LEN, TYPE_A, TYPE_AAAA, TYPE_PTR, }; #[test] @@ -1504,7 +1604,7 @@ mod tests { let name = "test_read"; let mut out = DnsOutgoing::new(FLAGS_QR_QUERY); out.add_question(name, TYPE_PTR); - let data = out.to_packet_data(); + let data = out.to_data_on_wire().remove(0); // construct invalid data. let max_len = data.len() as u8; @@ -1550,7 +1650,7 @@ mod tests { 9000, "instance1".to_string(), )); - let data = response.to_packet_data(); + let data = response.to_data_on_wire().remove(0); let mut data_too_short = data.clone(); // verify the original data is good. @@ -1558,7 +1658,7 @@ mod tests { assert!(incoming.is_ok()); // verify that truncated data will cause an error. - data_too_short.truncate(DnsIncoming::HEADER_LEN + name.len() + 2); + data_too_short.truncate(MSG_HEADER_LEN + name.len() + 2); let invalid = DnsIncoming::new(data_too_short); assert!(invalid.is_err()); if let Err(e) = invalid { @@ -1580,7 +1680,7 @@ mod tests { 9000, host.to_string(), )); - let data = response.to_packet_data(); + let data = response.to_data_on_wire().remove(0); let data_len = data.len(); let mut data_too_short = data.clone(); @@ -1644,4 +1744,61 @@ mod tests { let new_refresh = get_expiration_time(dns_record.get_created(), dns_record.ttl, 85); assert_eq!(new_refresh, dns_record.get_refresh_time()); } + + #[test] + fn test_packet_size() { + let mut outgoing = DnsOutgoing::new(FLAGS_QR_QUERY); + outgoing.add_question("test_packet_size", TYPE_PTR); + + let packet = outgoing.to_packets().remove(0); + println!("packet size: {}", packet.size); + let data = packet.data.concat(); + println!("data size: {}", data.len()); + + assert_eq!(packet.size, data.len()); + } + + #[test] + fn test_querier_known_answer_multi_packet() { + let mut query = DnsOutgoing::new(FLAGS_QR_QUERY); + let name = "test_multi_packet._udp.local."; + query.add_question(name, TYPE_PTR); + + let known_answer_count = 400; + for i in 0..known_answer_count { + let alias = format!("instance{}.{}", i, name); + let answer = DnsPointer::new(name, TYPE_PTR, CLASS_IN, 0, alias); + query.add_additional_answer(answer); + } + + let mut packets = query.to_data_on_wire(); + println!("packets count: {}", packets.len()); + assert_eq!(packets.len(), 2); + + let first_packet = packets.remove(0); + println!("first packet size: {}", first_packet.len()); + + let incoming1 = DnsIncoming::new(first_packet).unwrap(); + println!( + "first packet know answer count: {}, question count: {}", + incoming1.num_additionals, incoming1.num_questions + ); + + let second_packet = packets.remove(0); + println!("second packet size: {}", second_packet.len()); + + let incoming2 = DnsIncoming::new(second_packet).unwrap(); + println!( + "second packet known answer count: {}, question count: {}", + incoming2.num_additionals, incoming2.num_questions + ); + + assert_eq!( + incoming1.num_additionals + incoming2.num_additionals, + known_answer_count + ); + + assert_eq!(incoming1.num_questions, 1); + assert_eq!(incoming2.num_questions, 0); + } } diff --git a/src/service_daemon.rs b/src/service_daemon.rs index 454539b..2b332a6 100644 --- a/src/service_daemon.rs +++ b/src/service_daemon.rs @@ -698,9 +698,11 @@ fn new_socket_bind(intf: &Interface) -> Result { // Test if we can send packets successfully. let multicast_addr = SocketAddrV4::new(GROUP_ADDR_V4, MDNS_PORT).into(); - let test_packet = DnsOutgoing::new(0).to_packet_data(); - sock.send_to(&test_packet, &multicast_addr) - .map_err(|e| e_fmt!("send multicast packet on addr {}: {}", ip, e))?; + let test_packets = DnsOutgoing::new(0).to_data_on_wire(); + for packet in test_packets { + sock.send_to(&packet, &multicast_addr) + .map_err(|e| e_fmt!("send multicast packet on addr {}: {}", ip, e))?; + } Ok(sock) } IpAddr::V6(ip) => { @@ -1290,7 +1292,7 @@ impl Zeroconf { ); } - broadcast_dns_on_intf(&out, intf_sock); + send_dns_outgoing(&out, intf_sock); true } @@ -1351,7 +1353,8 @@ impl Zeroconf { ); } - broadcast_dns_on_intf(&out, intf_sock) + // `out` data is non-empty, hence we can do this. + send_dns_outgoing(&out, intf_sock).remove(0) } /// Binds a channel `listener` to querying mDNS domain type `ty`. @@ -1400,6 +1403,7 @@ impl Zeroconf { } } + // Send the query on one interface per subnet. let mut subnet_set: HashSet = HashSet::new(); for (_, intf_sock) in self.intf_socks.iter() { if !intf_sock.intf.is_link_local() { @@ -1409,7 +1413,7 @@ impl Zeroconf { } subnet_set.insert(subnet); } - broadcast_dns_on_intf(&out, intf_sock); + send_dns_outgoing(&out, intf_sock); } } @@ -1948,7 +1952,7 @@ impl Zeroconf { if !out.answers.is_empty() { out.id = msg.id; - broadcast_dns_on_intf(&out, intf_sock); + send_dns_outgoing(&out, intf_sock); self.increase_counter(Counter::Respond, 1); } @@ -2136,7 +2140,7 @@ impl Zeroconf { fn exec_command_unregister_resend(&mut self, packet: Vec, ip: IpAddr) { if let Some(intf_sock) = self.intf_socks.get(&ip) { debug!("UnregisterResend from {}", &ip); - broadcast_on_intf(&packet[..], intf_sock); + multicast_on_intf(&packet[..], intf_sock); self.increase_counter(Counter::UnregisterResend, 1); } } @@ -2486,27 +2490,29 @@ fn my_ip_interfaces() -> Vec { .collect() } -/// Send an outgoing broadcast DNS query or response, and returns the packet bytes. -fn broadcast_dns_on_intf(out: &DnsOutgoing, intf: &IntfSock) -> Vec { +/// Send an outgoing mDNS query or response, and returns the packet bytes. +fn send_dns_outgoing(out: &DnsOutgoing, intf: &IntfSock) -> Vec> { let qtype = if out.is_query() { "query" } else { "response" }; debug!( - "Broadcasting {}: {} questions {} answers {} authorities {} additional", + "Multicasting {}: {} questions {} answers {} authorities {} additional", qtype, out.questions.len(), out.answers.len(), out.authorities.len(), out.additionals.len() ); - let packet = out.to_packet_data(); - broadcast_on_intf(&packet[..], intf); - packet + let packet_list = out.to_data_on_wire(); + for packet in packet_list.iter() { + multicast_on_intf(packet, intf); + } + packet_list } -/// Sends an outgoing broadcast packet, and returns the packet bytes. -fn broadcast_on_intf<'a>(packet: &'a [u8], intf: &IntfSock) -> &'a [u8] { +/// Sends a multicast packet, and returns the packet bytes. +fn multicast_on_intf(packet: &[u8], intf: &IntfSock) { if packet.len() > MAX_MSG_ABSOLUTE { error!("Drop over-sized packet ({})", packet.len()); - return &[]; + return; } let sock: SocketAddr = match intf.intf.addr { @@ -2519,7 +2525,6 @@ fn broadcast_on_intf<'a>(packet: &'a [u8], intf: &IntfSock) -> &'a [u8] { }; send_packet(packet, sock, intf); - packet } /// Sends out `packet` to `addr` on the socket in `intf_sock`. @@ -2544,8 +2549,8 @@ fn valid_instance_name(name: &str) -> bool { #[cfg(test)] mod tests { use super::{ - broadcast_dns_on_intf, check_domain_suffix, check_service_name_length, my_ip_interfaces, - new_socket_bind, valid_instance_name, HostnameResolutionEvent, IntfSock, ServiceDaemon, + check_domain_suffix, check_service_name_length, my_ip_interfaces, new_socket_bind, + send_dns_outgoing, valid_instance_name, HostnameResolutionEvent, IntfSock, ServiceDaemon, ServiceEvent, ServiceInfo, GROUP_ADDR_V4, MDNS_PORT, }; use crate::{ @@ -2698,7 +2703,7 @@ mod tests { intf: intf.clone(), sock: new_socket_bind(&intf).unwrap(), }; - broadcast_dns_on_intf(&packet_buffer, &intf_sock); + send_dns_outgoing(&packet_buffer, &intf_sock); } println!(