diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 6a6c6d5..788bba7 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -182,9 +182,7 @@ fn main() -> Result<()> { let ip = resolve_ip_or_domain(&args.ip, &mut extra_options)?; #[cfg(feature = "packet_capture")] - unsafe { - gamedig::capture::simple_setup_capture(args.capture.clone()); - } + gamedig::capture::setup_capture(args.capture); // Query the server using game definition, parsed IP, and user command line // flags. diff --git a/crates/lib/Cargo.toml b/crates/lib/Cargo.toml index 6b17d0c..8cc7cdb 100644 --- a/crates/lib/Cargo.toml +++ b/crates/lib/Cargo.toml @@ -17,13 +17,13 @@ rust-version = "1.65.0" categories = ["parser-implementations", "parsing", "network-programming", "encoding"] [features] -default = ["games", "services", "game_defs"] +default = ["games", "services", "game_defs", "packet_capture",] games = [] services = [] game_defs = ["dep:phf", "games"] serde = ["dep:serde", "serde/derive"] clap = ["dep:clap"] -packet_capture = ["dep:pcap-file", "dep:pnet_packet"] +packet_capture = ["dep:pcap-file", "dep:pnet_packet", "dep:lazy_static"] [dependencies] byteorder = "1.5" @@ -40,6 +40,7 @@ clap = { version = "4.1.11", optional = true, features = ["derive"] } pcap-file = { version = "2.0", optional = true } pnet_packet = { version = "0.34", optional = true } +lazy_static = { version = "1.4", optional = true } # Examples [[example]] diff --git a/crates/lib/src/capture.rs b/crates/lib/src/capture.rs deleted file mode 100644 index 9c7a7fa..0000000 --- a/crates/lib/src/capture.rs +++ /dev/null @@ -1,453 +0,0 @@ -use std::io::Write; -use std::net::{IpAddr, SocketAddr}; - -use crate::GDResult; - -use pcap_file::pcapng::{blocks::enhanced_packet::EnhancedPacketOption, PcapNgBlock}; -use pnet_packet::{ - ethernet::{EtherType, MutableEthernetPacket}, - ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}, - ipv4::MutableIpv4Packet, - ipv6::MutableIpv6Packet, - tcp::{MutableTcpPacket, TcpFlags}, - udp::MutableUdpPacket, - PacketSize, -}; - -/// Info about a packet we have sent or recieved. -#[derive(Clone, Debug, PartialEq)] -pub struct PacketInfo<'a> { - pub direction: PacketDirection, - pub protocol: PacketProtocol, - pub remote_address: &'a SocketAddr, - pub local_address: &'a SocketAddr, -} - -/// The direction of a packet. -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum PacketDirection { - /// The packet is coming from us, destined for a server. - Send, - /// A server has sent this packet to us. - Receive, -} - -/// The protocol of a packet. -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum PacketProtocol { - TCP, - UDP, -} - -/// Trait for objects that can write packet captures. -pub trait CaptureWriter { - fn write(&mut self, packet: &PacketInfo, data: &[u8]) -> crate::GDResult<()>; - fn new_connect(&mut self, packet: &PacketInfo) -> crate::GDResult<()>; - // TODO: Tcp FIN when socket ends -} - -// Packet size constants -const PACKET_SIZE: usize = 5012; -const HEADER_SIZE_ETHERNET: usize = 14; -const HEADER_SIZE_IP4: usize = 20; -const HEADER_SIZE_IP6: usize = 40; -const HEADER_SIZE_UDP: usize = 4; - -/// A writer that does nothing -struct NullWriter; -impl CaptureWriter for NullWriter { - fn write(&mut self, _: &PacketInfo, _: &[u8]) -> GDResult<()> { Ok(()) } - fn new_connect(&mut self, _: &PacketInfo) -> GDResult<()> { Ok(()) } -} - -/// Writer that writes to pcap file -struct PcapWriter { - writer: pcap_file::pcapng::PcapNgWriter, - start_time: std::time::Instant, - send_seq: u32, - rec_seq: u32, - has_sent_handshake: bool, - stream_count: u32, -} -impl PcapWriter { - fn new(writer: pcap_file::pcapng::PcapNgWriter) -> Self { - Self { - writer, - start_time: std::time::Instant::now(), - send_seq: 0, - rec_seq: 0, - has_sent_handshake: false, - stream_count: 0, - } - } -} - -impl CaptureWriter for PcapWriter { - fn write(&mut self, info: &PacketInfo, data: &[u8]) -> GDResult<()> { - self.write_transport_packet(info, data); - - Ok(()) - } - - fn new_connect(&mut self, packet: &PacketInfo) -> GDResult<()> { - match packet.protocol { - PacketProtocol::TCP => { - self.write_tcp_handshake(packet); - } - PacketProtocol::UDP => {} - } - - self.stream_count = self.stream_count.wrapping_add(1); - - Ok(()) - } -} - -impl PcapWriter { - /// Encode the transport layer packet with a payload and write it. - fn write_transport_packet(&mut self, info: &PacketInfo, payload: &[u8]) { - let mut buf = vec![0; PACKET_SIZE - usize::max(HEADER_SIZE_IP4, HEADER_SIZE_IP6) - HEADER_SIZE_ETHERNET]; - - let (source_port, dest_port) = match info.direction { - PacketDirection::Send => (info.local_address.port(), info.remote_address.port()), - PacketDirection::Receive => (info.remote_address.port(), info.local_address.port()), - }; - - match info.protocol { - PacketProtocol::TCP => { - let buf_size = { - let mut tcp = MutableTcpPacket::new(&mut buf).unwrap(); - tcp.set_source(source_port); - tcp.set_destination(dest_port); - tcp.set_payload(payload); - tcp.set_data_offset(5); - tcp.set_window(43440); - match info.direction { - PacketDirection::Send => { - tcp.set_sequence(self.send_seq); - tcp.set_acknowledgement(self.rec_seq); - - self.send_seq = self.send_seq.wrapping_add(payload.len() as u32); - } - PacketDirection::Receive => { - tcp.set_sequence(self.rec_seq); - tcp.set_acknowledgement(self.send_seq); - - self.rec_seq = self.rec_seq.wrapping_add(payload.len() as u32); - } - } - tcp.set_flags(TcpFlags::PSH | TcpFlags::ACK); - - tcp.packet_size() - }; - - self.write_transport_payload( - info, - IpNextHeaderProtocols::Tcp, - &buf[.. buf_size + payload.len()], - vec![], - ); - - let mut info = info.clone(); - let buf_size = { - let mut tcp = MutableTcpPacket::new(&mut buf).unwrap(); - tcp.set_source(dest_port); - tcp.set_destination(source_port); - tcp.set_data_offset(5); - tcp.set_window(43440); - match &info.direction { - PacketDirection::Send => { - tcp.set_sequence(self.rec_seq); - tcp.set_acknowledgement(self.send_seq); - - info.direction = PacketDirection::Receive; - } - PacketDirection::Receive => { - tcp.set_sequence(self.send_seq); - tcp.set_acknowledgement(self.rec_seq); - - info.direction = PacketDirection::Send; - } - } - tcp.set_flags(TcpFlags::ACK); - - tcp.packet_size() - }; - - self.write_transport_payload( - &info, - IpNextHeaderProtocols::Tcp, - &buf[.. buf_size], - vec![EnhancedPacketOption::Comment("Generated TCP ack".into())], - ); - } - PacketProtocol::UDP => { - let buf_size = { - let mut udp = MutableUdpPacket::new(&mut buf).unwrap(); - udp.set_source(source_port); - udp.set_destination(dest_port); - udp.set_length((payload.len() + HEADER_SIZE_UDP) as u16); - udp.set_payload(payload); - - udp.packet_size() - }; - - self.write_transport_payload( - info, - IpNextHeaderProtocols::Udp, - &buf[.. buf_size + payload.len()], - vec![], - ); - } - } - } - - /// Encode a network layer (IP) packet with a payload. - fn encode_ip_packet( - &self, - buf: &mut [u8], - info: &PacketInfo, - protocol: IpNextHeaderProtocol, - payload: &[u8], - ) -> (usize, EtherType) { - match (info.local_address.ip(), info.remote_address.ip()) { - (IpAddr::V4(local_address), IpAddr::V4(remote_address)) => { - let (source, destination) = if info.direction == PacketDirection::Send { - (local_address, remote_address) - } else { - (remote_address, local_address) - }; - - let header_size = HEADER_SIZE_IP4 + (32 / 8); - - let mut ip = MutableIpv4Packet::new(buf).unwrap(); - ip.set_version(4); - ip.set_total_length((payload.len() + header_size) as u16); - ip.set_next_level_protocol(protocol); - // https://en.wikipedia.org/wiki/Internet_Protocol_version_4#Total_Length - - ip.set_header_length((header_size / 4) as u8); - ip.set_source(source); - ip.set_destination(destination); - ip.set_payload(payload); - ip.set_ttl(64); - ip.set_flags(pnet_packet::ipv4::Ipv4Flags::DontFragment); - - let mut options_writer = - pnet_packet::ipv4::MutableIpv4OptionPacket::new(ip.get_options_raw_mut()).unwrap(); - options_writer.set_copied(1); - options_writer.set_class(0); - options_writer.set_number(pnet_packet::ipv4::Ipv4OptionNumbers::SID); - options_writer.set_length(&[4]); - options_writer.set_data(&(self.stream_count as u16).to_be_bytes()); - - ip.set_checksum(pnet_packet::ipv4::checksum(&ip.to_immutable())); - - (ip.packet_size(), pnet_packet::ethernet::EtherTypes::Ipv4) - } - (IpAddr::V6(local_address), IpAddr::V6(remote_address)) => { - let (source, destination) = match info.direction { - PacketDirection::Send => (local_address, remote_address), - PacketDirection::Receive => (remote_address, local_address), - }; - - let mut ip = MutableIpv6Packet::new(buf).unwrap(); - ip.set_version(6); - ip.set_payload_length(payload.len() as u16); - ip.set_next_header(protocol); - ip.set_source(source); - ip.set_destination(destination); - ip.set_hop_limit(64); - ip.set_payload(payload); - ip.set_flow_label(self.stream_count); - - (ip.packet_size(), pnet_packet::ethernet::EtherTypes::Ipv6) - } - _ => unreachable!(), - } - } - - /// Encode a physical layer (ethernet) packet with a payload. - fn encode_ethernet_packet( - &self, - buf: &mut [u8], - ethertype: pnet_packet::ethernet::EtherType, - payload: &[u8], - ) -> usize { - let mut ethernet = MutableEthernetPacket::new(buf).unwrap(); - ethernet.set_ethertype(ethertype); - ethernet.set_payload(payload); - - ethernet.packet_size() - } - - /// Write a TCP handshake. - fn write_tcp_handshake(&mut self, info: &PacketInfo) { - let (source_port, dest_port) = (info.local_address.port(), info.remote_address.port()); - - let mut info = info.clone(); - info.direction = PacketDirection::Send; - let mut buf = vec![0; PACKET_SIZE]; - // Add a generated comment to all packets - let options = vec![ - pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketOption::Comment("Generated TCP handshake".into()), - ]; - - // SYN - let buf_size = { - let mut tcp = MutableTcpPacket::new(&mut buf).unwrap(); - self.send_seq = 500; - tcp.set_sequence(self.send_seq); - tcp.set_flags(TcpFlags::SYN); - tcp.set_source(source_port); - tcp.set_destination(dest_port); - tcp.set_window(43440); - tcp.set_data_offset(5); - - tcp.packet_size() - }; - self.write_transport_payload( - &info, - IpNextHeaderProtocols::Tcp, - &buf[.. buf_size], - options.clone(), - ); - - // SYN + ACK - info.direction = PacketDirection::Receive; - let buf_size = { - let mut tcp = MutableTcpPacket::new(&mut buf).unwrap(); - self.send_seq = self.send_seq.wrapping_add(1); - tcp.set_acknowledgement(self.send_seq); - self.rec_seq = 1000; - tcp.set_sequence(self.rec_seq); - tcp.set_flags(TcpFlags::SYN | TcpFlags::ACK); - tcp.set_source(dest_port); - tcp.set_destination(source_port); - tcp.set_window(43440); - tcp.set_data_offset(5); - - tcp.packet_size() - }; - self.write_transport_payload( - &info, - IpNextHeaderProtocols::Tcp, - &buf[.. buf_size], - options.clone(), - ); - - // ACK - info.direction = PacketDirection::Send; - let buf_size = { - let mut tcp = MutableTcpPacket::new(&mut buf).unwrap(); - tcp.set_sequence(self.send_seq); - self.rec_seq = self.rec_seq.wrapping_add(1); - tcp.set_acknowledgement(self.rec_seq); - tcp.set_flags(TcpFlags::ACK); - tcp.set_source(source_port); - tcp.set_destination(dest_port); - tcp.set_window(43440); - tcp.set_data_offset(5); - - tcp.packet_size() - }; - self.write_transport_payload( - &info, - IpNextHeaderProtocols::Tcp, - &buf[.. buf_size], - options, - ); - - self.has_sent_handshake = true; - } - - /// Take a transport layer packet as a buffer and write it after encoding - /// all the layers under it. - fn write_transport_payload( - &mut self, - info: &PacketInfo, - protocol: IpNextHeaderProtocol, - payload: &[u8], - options: Vec, - ) { - let mut network_packet = vec![0; PACKET_SIZE - HEADER_SIZE_ETHERNET]; - let (network_size, ethertype) = self.encode_ip_packet(&mut network_packet, info, protocol, payload); - let network_size = network_size + payload.len(); - network_packet.truncate(network_size); - - let mut physical_packet = vec![0; PACKET_SIZE]; - let physical_size = - self.encode_ethernet_packet(&mut physical_packet, ethertype, &network_packet) + network_size; - - physical_packet.truncate(physical_size); - - self.writer - .write_block( - &pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketBlock { - original_len: physical_size as u32, - data: physical_packet.into(), - interface_id: 0, - timestamp: self.start_time.elapsed(), - options, - } - .into_block(), - ) - .unwrap(); - } -} - -/// Setup the static capture into a file or to nowhere. -/// -/// This leaks the writer. -/// -/// # Panics -/// - If this is called more than once (OnceLock used internally). -/// -/// # Safety -/// The safety of this function has not been evaluated yet, and -/// testing has only been done with limited CLI use cases. -pub unsafe fn simple_setup_capture(file_name: Option) { - let writer: Box = if let Some(file_name) = file_name { - let file = std::fs::OpenOptions::new() - .create_new(true) - .write(true) - .open(file_name) - .unwrap(); - let mut writer = pcap_file::pcapng::PcapNgWriter::new(file).unwrap(); - - // Write headers - writer - .write_block( - &pcap_file::pcapng::blocks::interface_description::InterfaceDescriptionBlock { - linktype: pcap_file::DataLink::ETHERNET, - snaplen: 0xFFFF, - options: vec![], - } - .into_block(), - ) - .unwrap(); - - let writer = PcapWriter::new(writer); - Box::new(writer) - } else { - Box::new(NullWriter) - }; - setup_capture(writer); -} - -/// Set a capture writer to handle packet send/recieve data. -/// -/// This leaks the writer. -/// -/// # Panics -/// - If this is called more than once (OnceLock used internally). -/// -/// # Safety -/// The safety of this function has not been evaluated yet, and -/// testing has only been done with limited CLI use cases. -pub unsafe fn setup_capture(writer: Box) { - // TODO: safety - unsafe { - crate::socket::capture::set_writer(writer); - } -} diff --git a/crates/lib/src/capture/mod.rs b/crates/lib/src/capture/mod.rs new file mode 100644 index 0000000..b659dca --- /dev/null +++ b/crates/lib/src/capture/mod.rs @@ -0,0 +1,43 @@ +pub(crate) mod packet; +mod pcap; +pub mod writer; + +use pcap_file::pcapng::PcapNgBlock; +use writer::Writer; + +use self::pcap::Pcap; + +pub fn setup_capture(file_name: Option) { + if let Some(file_name) = file_name { + let file = std::fs::OpenOptions::new() + .create_new(true) + .write(true) + .open(file_name) + .unwrap(); + + let mut pcap_writer = pcap_file::pcapng::PcapNgWriter::new(file).unwrap(); + + // Write headers + pcap_writer.write_block( + &pcap_file::pcapng::blocks::interface_description::InterfaceDescriptionBlock { + linktype: pcap_file::DataLink::ETHERNET, + snaplen: 0xFFFF, + options: vec![], + } + .into_block(), + ); + + let writer = Box::new(Pcap::new(pcap_writer)); + attach(writer) + } else { + // Do nothing + } +} + +/// Attaches a writer to the capture module. +/// +/// # Errors +/// Returns an `io::Error` if the writer is already set. +fn attach(writer: Box) { + crate::socket::capture::set_writer(writer); +} diff --git a/crates/lib/src/capture/packet.rs b/crates/lib/src/capture/packet.rs new file mode 100644 index 0000000..4f726fc --- /dev/null +++ b/crates/lib/src/capture/packet.rs @@ -0,0 +1,254 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +/// Size of a standard network packet. +pub(crate) const PACKET_SIZE: usize = 5012; +/// Size of an Ethernet header. +pub(crate) const HEADER_SIZE_ETHERNET: usize = 14; +/// Size of an IPv4 header. +pub(crate) const HEADER_SIZE_IP4: usize = 20; +/// Size of an IPv6 header. +pub(crate) const HEADER_SIZE_IP6: usize = 40; +/// Size of a UDP header. +pub(crate) const HEADER_SIZE_UDP: usize = 4; + +/// Represents the direction of a network packet. +#[derive(Clone, Copy, Debug, PartialEq)] +pub(crate) enum Direction { + /// Packet is outgoing (sent by us). + Send, + /// Packet is incoming (received by us). + Receive, +} + +/// Defines the protocol of a network packet. +#[derive(Clone, Copy, Debug, PartialEq)] +pub(crate) enum Protocol { + /// Transmission Control Protocol. + TCP, + /// User Datagram Protocol. + UDP, +} + +/// Trait for handling different types of IP addresses (IPv4, IPv6). +pub trait IpAddress: Sized { + /// Creates an instance from a standard `IpAddr`, returning `None` if the types are incompatible. + fn from_std(ip: IpAddr) -> Option; +} + +/// Represents a captured network packet with metadata. +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct CapturePacket<'a> { + /// Direction of the packet (Send/Receive). + pub(crate) direction: Direction, + /// Protocol of the packet (TCP/UDP). + pub(crate) protocol: Protocol, + /// Remote socket address. + pub(crate) remote_address: &'a SocketAddr, + /// Local socket address. + pub(crate) local_address: &'a SocketAddr, +} + +impl CapturePacket<'_> { + /// Retrieves the local and remote ports based on the packet's direction. + /// + /// Returns: + /// - (u16, u16): Tuple of (source port, destination port). + pub(crate) const fn ports_by_direction(&self) -> (u16, u16) { + let (local, remote) = (self.local_address.port(), self.remote_address.port()); + self.direction.order(local, remote) + } + + /// Retrieves the local and remote IP addresses. + /// + /// Returns: + /// - (IpAddr, IpAddr): Tuple of (local IP, remote IP). + pub(crate) fn ip_addr(&self) -> (IpAddr, IpAddr) { + let (local, remote) = (self.local_address.ip(), self.remote_address.ip()); + (local, remote) + } + + /// Retrieves IP addresses based on the packet's direction. + /// + /// Returns: + /// - (IpAddr, IpAddr): Tuple of (source IP, destination IP). + pub(crate) fn ip_addr_by_direction(&self) -> (IpAddr, IpAddr) { + let (local, remote) = self.ip_addr(); + self.direction.order(local, remote) + } + + /// Retrieves IP addresses of a specific type (IPv4 or IPv6) based on the packet's direction. + /// + /// Panics if the IP type of the addresses does not match the requested type. + /// + /// Returns: + /// - (T, T): Tuple of (source IP, destination IP) of the specified type in order. + pub(crate) fn ipvt_by_direction(&self) -> (T, T) { + let (local, remote) = ( + T::from_std(self.local_address.ip()).expect("Incorrect IP type for local address"), + T::from_std(self.remote_address.ip()).expect("Incorrect IP type for remote address"), + ); + + self.direction.order(local, remote) + } +} + +impl Direction { + /// Orders two elements (source and destination) based on the packet's direction. + /// + /// Returns: + /// - (T, T): Ordered tuple (source, destination). + pub(crate) const fn order(&self, source: T, remote: T) -> (T, T) { + match self { + Direction::Send => (source, remote), + Direction::Receive => (remote, source), + } + } +} + +/// Implements the `IpAddress` trait for `Ipv4Addr`. +impl IpAddress for Ipv4Addr { + /// Creates an `Ipv4Addr` from a standard `IpAddr`, if it's IPv4. + fn from_std(ip: IpAddr) -> Option { + match ip { + IpAddr::V4(ipv4) => Some(ipv4), + _ => None, + } + } +} + +/// Implements the `IpAddress` trait for `Ipv6Addr`. +impl IpAddress for Ipv6Addr { + /// Creates an `Ipv6Addr` from a standard `IpAddr`, if it's IPv6. + fn from_std(ip: IpAddr) -> Option { + match ip { + IpAddr::V6(ipv6) => Some(ipv6), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str::FromStr; + + // Helper function to create a SocketAddr from a string + fn socket_addr(addr: &str) -> SocketAddr { + SocketAddr::from_str(addr).unwrap() + } + + #[test] + fn test_ports_by_direction() { + let packet_send = CapturePacket { + direction: Direction::Send, + protocol: Protocol::TCP, + local_address: &socket_addr("127.0.0.1:8080"), + remote_address: &socket_addr("192.168.1.1:80"), + }; + + let packet_receive = CapturePacket { + direction: Direction::Receive, + protocol: Protocol::TCP, + local_address: &socket_addr("127.0.0.1:8080"), + remote_address: &socket_addr("192.168.1.1:80"), + }; + + assert_eq!(packet_send.ports_by_direction(), (8080, 80)); + assert_eq!(packet_receive.ports_by_direction(), (80, 8080)); + } + + #[test] + fn test_ip_addr_by_direction_ipv4() { + let packet_send = CapturePacket { + direction: Direction::Send, + protocol: Protocol::UDP, + local_address: &socket_addr("10.0.0.1:3000"), + remote_address: &socket_addr("10.0.0.2:3001"), + }; + + let packet_receive = CapturePacket { + direction: Direction::Receive, + protocol: Protocol::UDP, + local_address: &socket_addr("10.0.0.1:3000"), + remote_address: &socket_addr("10.0.0.2:3001"), + }; + + assert_eq!( + packet_send.ip_addr_by_direction(), + ( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)) + ) + ); + assert_eq!( + packet_receive.ip_addr_by_direction(), + ( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)) + ) + ); + } + + #[test] + fn test_ip_addr_by_direction_ipv6() { + let packet_send = CapturePacket { + direction: Direction::Send, + protocol: Protocol::UDP, + local_address: &socket_addr("[::1]:3000"), + remote_address: &socket_addr("[::2]:3001"), + }; + + let packet_receive = CapturePacket { + direction: Direction::Receive, + protocol: Protocol::UDP, + local_address: &socket_addr("[::1]:3000"), + remote_address: &socket_addr("[::2]:3001"), + }; + + assert_eq!( + packet_send.ip_addr_by_direction(), + ( + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2)) + ) + ); + assert_eq!( + packet_receive.ip_addr_by_direction(), + ( + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2)), + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)) + ) + ); + } + + #[test] + fn test_ip_by_direction_type_specific() { + let packet = CapturePacket { + direction: Direction::Send, + protocol: Protocol::TCP, + local_address: &socket_addr("127.0.0.1:8080"), + remote_address: &socket_addr("192.168.1.1:80"), + }; + + let ipv4_result: Result<(Ipv4Addr, Ipv4Addr), _> = + std::panic::catch_unwind(|| packet.ipvt_by_direction::()); + assert!(ipv4_result.is_ok()); + + let ipv6_result: Result<(Ipv6Addr, Ipv6Addr), _> = + std::panic::catch_unwind(|| packet.ipvt_by_direction::()); + assert!(ipv6_result.is_err()); + } + + #[test] + #[should_panic(expected = "Local and remote IP addresses must be of the same version")] + fn test_mismatched_ip_version_panic() { + let packet = CapturePacket { + direction: Direction::Send, + protocol: Protocol::UDP, + local_address: &socket_addr("127.0.0.1:8080"), // IPv4 + remote_address: &socket_addr("[::1]:80"), // IPv6 + }; + + packet.ip_addr_by_direction(); + } +} diff --git a/crates/lib/src/capture/pcap.rs b/crates/lib/src/capture/pcap.rs new file mode 100644 index 0000000..054c8e3 --- /dev/null +++ b/crates/lib/src/capture/pcap.rs @@ -0,0 +1,311 @@ +use pcap_file::pcapng::{ + blocks::enhanced_packet::{EnhancedPacketBlock, EnhancedPacketOption}, + PcapNgBlock, PcapNgWriter, +}; +use pnet_packet::{ + ethernet::{EtherType, EtherTypes, MutableEthernetPacket}, + ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}, + ipv4::MutableIpv4Packet, + ipv6::MutableIpv6Packet, + tcp::{MutableTcpPacket, TcpFlags}, + udp::MutableUdpPacket, + PacketSize, +}; +use std::{ + io::Write, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + time::Instant, +}; + +use super::packet::{ + CapturePacket, Direction, Protocol, HEADER_SIZE_ETHERNET, HEADER_SIZE_IP4, HEADER_SIZE_IP6, HEADER_SIZE_UDP, + PACKET_SIZE, +}; + +const DEFAULT_TTL: u8 = 64; +const TCP_WINDOW_SIZE: u16 = 43440; +const BUFFER_SIZE: usize = PACKET_SIZE + - (if HEADER_SIZE_IP4 > HEADER_SIZE_IP6 { + HEADER_SIZE_IP4 + } else { + HEADER_SIZE_IP6 + }) + - HEADER_SIZE_ETHERNET; + +pub(crate) struct Pcap { + writer: PcapNgWriter, + pub(crate) state: State, + buffer: Vec, +} + +pub(crate) struct State { + pub(crate) start_time: Instant, + pub(crate) send_seq: u32, + pub(crate) rec_seq: u32, + pub(crate) has_sent_handshake: bool, + pub(crate) has_sent_fin: bool, + pub(crate) stream_count: u32, +} + +impl Pcap { + pub fn new(writer: PcapNgWriter) -> Self { + Self { + writer, + state: State::default(), + buffer: vec![0; BUFFER_SIZE], + } + } + + pub fn write_transport_packet(&mut self, info: &CapturePacket, payload: &[u8]) -> Result<(), std::io::Error> { + let (source_port, dest_port) = info.ports_by_direction(); + + match info.protocol { + Protocol::TCP => self.handle_tcp(info, payload, source_port, dest_port)?, + Protocol::UDP => self.handle_udp(info, payload, source_port, dest_port)?, + } + + Ok(()) + } + + fn handle_tcp( + &mut self, + info: &CapturePacket, + payload: &[u8], + source_port: u16, + dest_port: u16, + ) -> Result<(), std::io::Error> { + let buf_size = self.setup_tcp_packet(info, payload, source_port, dest_port)?; + self.write_transport_payload( + info, + IpNextHeaderProtocols::Tcp, + &self.buffer[..buf_size + payload.len()], + vec![], + ); + + Ok(()) + } + + fn setup_tcp_packet( + &mut self, + info: &CapturePacket, + payload: &[u8], + source_port: u16, + dest_port: u16, + ) -> Result { + let mut tcp = MutableTcpPacket::new(&mut self.buffer) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Failed to create TCP packet"))?; + tcp.set_source(source_port); + tcp.set_destination(dest_port); + tcp.set_payload(payload); + tcp.set_data_offset(5); + tcp.set_window(TCP_WINDOW_SIZE); + + // Set sequence and acknowledgement numbers + match info.direction { + Direction::Send => { + tcp.set_sequence(self.state.send_seq); + tcp.set_acknowledgement(self.state.rec_seq); + self.state.send_seq = self.state.send_seq.wrapping_add(payload.len() as u32); + } + Direction::Receive => { + tcp.set_sequence(self.state.rec_seq); + tcp.set_acknowledgement(self.state.send_seq); + self.state.rec_seq = self.state.rec_seq.wrapping_add(payload.len() as u32); + } + } + tcp.set_flags(TcpFlags::PSH | TcpFlags::ACK); + + Ok(tcp.packet_size()) + } + + pub fn write_tcp_handshake(&mut self, info: &CapturePacket) { + // Initialize sequence numbers for demonstration purposes + self.state.send_seq = 500; + self.state.rec_seq = 1000; + + // Common setup for TCP handshake packets + let mut tcp_handshake_packet = + |info: &CapturePacket, direction: Direction, flags: u8| -> Result<(), std::io::Error> { + let (source_port, dest_port) = info.ports_by_direction(); + let adjusted_info = CapturePacket { + direction, + ..info.clone() + }; + self.setup_tcp_packet(&adjusted_info, &[], source_port, dest_port)?; + Ok(self.write_transport_payload( + &adjusted_info, + IpNextHeaderProtocols::Tcp, + &self.buffer, + vec![EnhancedPacketOption::Comment( + format!( + "Generated TCP {}", + match flags { + TcpFlags::SYN => "SYN", + TcpFlags::SYN | TcpFlags::ACK => "SYN-ACK", + TcpFlags::ACK => "ACK", + } + ) + .into(), + )], + )) + }; + + // Send SYN + tcp_handshake_packet(info, Direction::Send, TcpFlags::SYN); + + // Send SYN-ACK + self.state.send_seq = self.state.send_seq.wrapping_add(1); // Update sequence number after SYN + tcp_handshake_packet(info, Direction::Receive, TcpFlags::SYN | TcpFlags::ACK); + + // Send ACK + self.state.rec_seq = self.state.rec_seq.wrapping_add(1); // Update sequence number after SYN-ACK + tcp_handshake_packet(info, Direction::Send, TcpFlags::ACK); + } + + fn handle_udp( + &mut self, + info: &CapturePacket, + payload: &[u8], + source_port: u16, + dest_port: u16, + ) -> Result<(), std::io::Error> { + let mut udp = MutableUdpPacket::new(&mut self.buffer) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Failed to create UDP packet"))?; + udp.set_source(source_port); + udp.set_destination(dest_port); + udp.set_length((payload.len() + HEADER_SIZE_UDP) as u16); + udp.set_payload(payload); + + let buf_size = udp.packet_size(); + self.write_transport_payload( + info, + IpNextHeaderProtocols::Udp, + &self.buffer[..buf_size + payload.len()], + vec![], + ); + + Ok(()) + } + + fn write_transport_payload( + &mut self, + info: &CapturePacket, + protocol: IpNextHeaderProtocol, + payload: &[u8], + options: Vec, + ) { + let network_packet_size = self.encode_ip_packet(info, protocol, payload).unwrap().0; + let ethertype = self.encode_ip_packet(info, protocol, payload).unwrap().1; + let ethernet_packet_size = self.encode_ethernet_packet(info, ethertype, &self.buffer[..network_packet_size]).unwrap(); + + let enhanced_packet_block = EnhancedPacketBlock { + original_len: ethernet_packet_size as u32, + data: self.buffer[..ethernet_packet_size].to_vec().into(), + interface_id: 0, + timestamp: self.state.start_time.elapsed(), + options, + }; + + self.writer.write_block(&enhanced_packet_block.into_block()); + } + + fn encode_ethernet_packet( + &mut self, + info: &CapturePacket, + ethertype: EtherType, + payload: &[u8], + ) -> Result { + let mut ethernet_packet = MutableEthernetPacket::new(&mut self.buffer).ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to create Ethernet packet", + ) + })?; + + ethernet_packet.set_ethertype(ethertype); + ethernet_packet.set_payload(payload); + + Ok(ethernet_packet.packet_size()) + } + + fn encode_ip_packet( + &mut self, + info: &CapturePacket, + protocol: IpNextHeaderProtocol, + payload: &[u8], + ) -> Result<(usize, EtherType), std::io::Error> { + match info.ip_addr() { + (IpAddr::V4(_), IpAddr::V4(_)) => { + let (source, destination) = info.ipvt_by_direction(); + + let mut ip_packet = MutableIpv4Packet::new(&mut self.buffer) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Failed to create IPv4 packet"))?; + self.set_ipv4_packet_fields(&mut ip_packet, source, destination, payload, protocol); + ip_packet.set_checksum(pnet_packet::ipv4::checksum(&ip_packet.to_immutable())); + + Ok((ip_packet.packet_size(), EtherTypes::Ipv4)) + } + (IpAddr::V6(_), IpAddr::V6(_)) => { + let (source, destination) = info.ipvt_by_direction(); + + let mut ip_packet = MutableIpv6Packet::new(&mut self.buffer) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Failed to create IPv6 packet"))?; + self.set_ipv6_packet_fields(&mut ip_packet, source, destination, payload, protocol); + + Ok((ip_packet.packet_size(), EtherTypes::Ipv6)) + } + _ => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Unsupported or mismatched IP address types", + )), + } + } + + fn set_ipv4_packet_fields( + &mut self, + ip_packet: &mut MutableIpv4Packet, + source: Ipv4Addr, + destination: Ipv4Addr, + payload: &[u8], + protocol: IpNextHeaderProtocol, + ) { + ip_packet.set_version(4); + ip_packet.set_header_length(5); // No options + ip_packet.set_total_length((payload.len() + HEADER_SIZE_IP4) as u16); + ip_packet.set_next_level_protocol(protocol); + ip_packet.set_source(source); + ip_packet.set_destination(destination); + ip_packet.set_ttl(DEFAULT_TTL); + ip_packet.set_payload(payload); + } + + fn set_ipv6_packet_fields( + &mut self, + ip_packet: &mut MutableIpv6Packet, + source: Ipv6Addr, + destination: Ipv6Addr, + payload: &[u8], + protocol: IpNextHeaderProtocol, + ) { + ip_packet.set_version(6); + ip_packet.set_payload_length(payload.len() as u16); + ip_packet.set_next_header(protocol); + ip_packet.set_source(source); + ip_packet.set_destination(destination); + ip_packet.set_hop_limit(DEFAULT_TTL); + ip_packet.set_payload(payload); + } +} + +impl Default for State { + fn default() -> Self { + Self { + start_time: Instant::now(), + send_seq: 0, + rec_seq: 0, + has_sent_handshake: false, + has_sent_fin: false, + stream_count: 0, + } + } +} diff --git a/crates/lib/src/capture/writer.rs b/crates/lib/src/capture/writer.rs new file mode 100644 index 0000000..012b700 --- /dev/null +++ b/crates/lib/src/capture/writer.rs @@ -0,0 +1,41 @@ +use std::io::Write; + +use crate::{ + capture::packet::{CapturePacket, Protocol}, + GDResult, +}; + +use super::pcap::Pcap; + +use lazy_static::lazy_static; +use std::sync::Mutex; + +lazy_static! { + pub(crate) static ref CAPTURE_WRITER: Mutex>> = Mutex::new(None); +} + +pub trait Writer { + fn write(&mut self, packet: &CapturePacket, data: &[u8]) -> crate::GDResult<()>; + fn new_connect(&mut self, packet: &CapturePacket) -> crate::GDResult<()>; +} + +impl Writer for Pcap { + fn write(&mut self, info: &CapturePacket, data: &[u8]) -> GDResult<()> { + self.write_transport_packet(info, data); + + Ok(()) + } + + fn new_connect(&mut self, packet: &CapturePacket) -> GDResult<()> { + match packet.protocol { + Protocol::TCP => { + self.write_tcp_handshake(packet); + } + Protocol::UDP => {} + } + + self.state.stream_count = self.state.stream_count.wrapping_add(1); + + Ok(()) + } +} diff --git a/crates/lib/src/lib.rs b/crates/lib/src/lib.rs index 38ec19c..aeb3be2 100644 --- a/crates/lib/src/lib.rs +++ b/crates/lib/src/lib.rs @@ -46,7 +46,7 @@ mod socket; mod utils; #[cfg(feature = "packet_capture")] -pub mod capture; +pub(crate) mod capture; pub use errors::*; #[cfg(feature = "games")] diff --git a/crates/lib/src/socket.rs b/crates/lib/src/socket.rs index 7dfaf19..f0fe3d3 100644 --- a/crates/lib/src/socket.rs +++ b/crates/lib/src/socket.rs @@ -12,26 +12,65 @@ use std::{ const DEFAULT_PACKET_SIZE: usize = 1024; +/// A trait defining the basic functionalities of a network socket. pub trait Socket { - /// Create a new socket and connect to the remote address (if required). + /// Create a new socket and connect to the remote address. /// - /// Calls [Self::apply_timeout] with the given timeout settings. + /// # Arguments + /// * `address` - The address to connect the socket to. + /// * `timeout_settings` - Optional timeout settings for the socket. + /// + /// # Returns + /// A result containing the socket instance or an error. fn new(address: &SocketAddr, timeout_settings: &Option) -> GDResult - where Self: Sized; + where + Self: Sized; + /// Apply read and write timeouts to the socket. + /// + /// # Arguments + /// * `timeout_settings` - Optional timeout settings to apply. + /// + /// # Returns + /// A result indicating success or error in applying timeouts. fn apply_timeout(&self, timeout_settings: &Option) -> GDResult<()>; + /// Send data over the socket. + /// + /// # Arguments + /// * `data` - Data to be sent. + /// + /// # Returns + /// A result indicating success or error in sending data. fn send(&mut self, data: &[u8]) -> GDResult<()>; + + /// Receive data from the socket. + /// + /// # Arguments + /// * `size` - Optional size of data to receive. + /// + /// # Returns + /// A result containing received data or an error. fn receive(&mut self, size: Option) -> GDResult>; - /// Get the remote port + /// Get the remote port of the socket. + /// + /// # Returns + /// The port number. fn port(&self) -> u16; - /// Get the local SocketAddr + + /// Get the local SocketAddr. + /// + /// # Returns + /// The local SocketAddr. fn local_addr(&self) -> std::io::Result; } +/// Implementation of a TCP socket. pub struct TcpSocketImpl { + /// The underlying TCP socket stream. socket: net::TcpStream, + /// The address of the remote host. address: SocketAddr, } @@ -75,12 +114,19 @@ impl Socket for TcpSocketImpl { Ok(buf) } - fn port(&self) -> u16 { self.address.port() } - fn local_addr(&self) -> std::io::Result { self.socket.local_addr() } + fn port(&self) -> u16 { + self.address.port() + } + fn local_addr(&self) -> std::io::Result { + self.socket.local_addr() + } } +/// Implementation of a UDP socket. pub struct UdpSocketImpl { + /// The underlying UDP socket. socket: net::UdpSocket, + /// The address of the remote host. address: SocketAddr, } @@ -121,53 +167,78 @@ impl Socket for UdpSocketImpl { .recv_from(&mut buf) .map_err(|e| PacketReceive.context(e))?; - Ok(buf[.. number_of_bytes_received].to_vec()) + Ok(buf[..number_of_bytes_received].to_vec()) } - fn port(&self) -> u16 { self.address.port() } - fn local_addr(&self) -> std::io::Result { self.socket.local_addr() } + fn port(&self) -> u16 { + self.address.port() + } + fn local_addr(&self) -> std::io::Result { + self.socket.local_addr() + } } /// Things used for capturing packets. #[cfg(feature = "packet_capture")] pub mod capture { - use std::{marker::PhantomData, net::SocketAddr, sync::OnceLock}; + use std::{marker::PhantomData, net::SocketAddr}; use super::{Socket, TcpSocketImpl, UdpSocketImpl}; use crate::{ - capture::{CaptureWriter, PacketDirection, PacketInfo, PacketProtocol}, + capture::{ + packet::CapturePacket, + packet::{Direction, Protocol}, + writer::{Writer, CAPTURE_WRITER}, + }, protocols::types::TimeoutSettings, GDResult, }; - static mut WRITER_CELL: OnceLock> = OnceLock::new(); - - /// Set a GLOBAL capture writer that handles writing all packet data. + /// Sets a global capture writer for handling all packet data. /// - /// This is unsafe because the writer is a mutable static, the caller - /// is responsible for ensuring that the writer type is able to be stored - /// as such. - pub(crate) unsafe fn set_writer(writer: Box) { - if WRITER_CELL.set(writer).is_err() { - panic!("Should only set writer once"); + /// # Panics + /// Panics if a capture writer is already set. + /// + /// # Arguments + /// * `writer` - A boxed writer that implements the `Writer` trait. + pub(crate) fn set_writer(writer: Box) { + let mut lock = CAPTURE_WRITER.lock().unwrap(); + + if lock.is_some() { + panic!("Capture writer already set"); + } + + *lock = Some(writer); + } + + /// A trait representing a provider of a network protocol. + pub trait ProtocolProvider { + /// Returns the protocol used by the provider. + fn protocol() -> Protocol; + } + + /// Represents the TCP protocol provider. + pub struct ProtocolTCP; + impl ProtocolProvider for ProtocolTCP { + fn protocol() -> Protocol { + Protocol::TCP } } - pub trait PacketProtocolProvider { - fn protocol() -> PacketProtocol; - } - pub struct PacketProtocolTCP; - impl PacketProtocolProvider for PacketProtocolTCP { - fn protocol() -> PacketProtocol { PacketProtocol::TCP } + /// Represents the UDP protocol provider. + pub struct ProtocolUDP; + impl ProtocolProvider for ProtocolUDP { + fn protocol() -> Protocol { + Protocol::UDP + } } - pub struct PacketProtocolUDP; - impl PacketProtocolProvider for PacketProtocolUDP { - fn protocol() -> PacketProtocol { PacketProtocol::UDP } - } - - /// A socket that allows capturing + /// A socket wrapper that allows capturing packets. + /// + /// # Type parameters + /// * `I` - The inner socket type. + /// * `P` - The protocol provider. #[derive(Clone, Debug)] pub struct WrappedCaptureSocket { inner: I, @@ -175,76 +246,143 @@ pub mod capture { _protocol: PhantomData

, } - impl Socket for WrappedCaptureSocket { + impl Socket for WrappedCaptureSocket { + /// Creates a new wrapped socket for capturing packets. + /// + /// Initializes a new socket of type `I`, wrapping it to enable packet capturing. + /// Capturing is protocol-specific, as indicated by the `ProtocolProvider`. + /// + /// # Arguments + /// * `address` - The address to connect the socket to. + /// * `timeout_settings` - Optional timeout settings for the socket. + /// + /// # Returns + /// A `GDResult` containing either the wrapped socket or an error. fn new(address: &SocketAddr, timeout_settings: &Option) -> GDResult - where Self: Sized { + where + Self: Sized, + { let v = Self { inner: I::new(address, timeout_settings)?, remote_address: *address, _protocol: PhantomData, }; - let info = PacketInfo { - direction: PacketDirection::Send, + let info = CapturePacket { + direction: Direction::Send, protocol: P::protocol(), remote_address: address, local_address: &v.local_addr().unwrap(), }; - // TODO: Safety - unsafe { - if let Some(writer) = WRITER_CELL.get_mut() { - writer.new_connect(&info)?; - } + + if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() { + writer.new_connect(&info)?; } Ok(v) } + /// Sends data over the socket and captures the packet. + /// + /// The method sends data using the inner socket and captures the sent packet + /// if a capture writer is set. + /// + /// # Arguments + /// * `data` - Data to be sent. + /// + /// # Returns + /// A result indicating success or error in sending data. fn send(&mut self, data: &[u8]) -> crate::GDResult<()> { - let info = PacketInfo { - direction: PacketDirection::Send, + let info = CapturePacket { + direction: Direction::Send, protocol: P::protocol(), remote_address: &self.remote_address, local_address: &self.local_addr().unwrap(), }; - // TODO: Safety - unsafe { - if let Some(writer) = WRITER_CELL.get_mut() { - writer.write(&info, data)?; - } + + if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() { + writer.write(&info, data)?; } + self.inner.send(data) } + /// Receives data from the socket and captures the packet. + /// + /// The method receives data using the inner socket and captures the incoming packet + /// if a capture writer is set. + /// + /// # Arguments + /// * `size` - Optional size of data to receive. + /// + /// # Returns + /// A result containing received data or an error. fn receive(&mut self, size: Option) -> crate::GDResult> { let data = self.inner.receive(size)?; - let info = PacketInfo { - direction: PacketDirection::Receive, + let info = CapturePacket { + direction: Direction::Receive, protocol: P::protocol(), remote_address: &self.remote_address, local_address: &self.local_addr().unwrap(), }; - // TODO: Safety - unsafe { - if let Some(writer) = WRITER_CELL.get_mut() { - writer.write(&info, &data)?; - } + + if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() { + writer.write(&info, &data)?; } + Ok(data) } + /// Applies timeout settings to the wrapped socket. + /// + /// Delegates the operation to the inner socket implementation. + /// + /// # Arguments + /// * `timeout_settings` - Optional timeout settings to apply. + /// + /// # Returns + /// A result indicating success or error in applying timeouts. fn apply_timeout( &self, timeout_settings: &Option, ) -> crate::GDResult<()> { self.inner.apply_timeout(timeout_settings) } - fn port(&self) -> u16 { self.inner.port() } - fn local_addr(&self) -> std::io::Result { self.inner.local_addr() } + + /// Returns the remote port of the wrapped socket. + /// + /// Delegates the operation to the inner socket implementation. + /// + /// # Returns + /// The remote port number. + fn port(&self) -> u16 { + self.inner.port() + } + + /// Returns the local SocketAddr of the wrapped socket. + /// + /// Delegates the operation to the inner socket implementation. + /// + /// # Returns + /// The local SocketAddr. + fn local_addr(&self) -> std::io::Result { + self.inner.local_addr() + } } - pub type CapturedUdpSocket = WrappedCaptureSocket; - pub type CapturedTcpSocket = WrappedCaptureSocket; + /// A specialized `WrappedCaptureSocket` for UDP, using `UdpSocketImpl` as the inner socket + /// and `ProtocolUDP` as the protocol provider. + /// + /// This type captures and processes UDP packets, wrapping around standard UDP socket + /// functionalities with additional packet capture capabilities. + pub type CapturedUdpSocket = WrappedCaptureSocket; + + /// A specialized `WrappedCaptureSocket` for TCP, using `TcpSocketImpl` as the inner socket + /// and `ProtocolTCP` as the protocol provider. + /// + /// This type captures and processes TCP packets, wrapping around standard TCP socket + /// functionalities with additional packet capture capabilities. + pub type CapturedTcpSocket = WrappedCaptureSocket; } #[cfg(not(feature = "packet_capture"))]