diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index d74c1d0..dd4ed97 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -11,8 +11,9 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["json"] +default = ["json", "packet_capture"] json = ["dep:serde", "dep:serde_json", "gamedig/serde"] +packet_capture = ["gamedig/packet_capture"] [dependencies] clap = { version = "4.1.11", features = ["derive"] } diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 3628e8d..788bba7 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -1,7 +1,10 @@ use std::net::{IpAddr, ToSocketAddrs}; use clap::{Parser, ValueEnum}; -use gamedig::{games::*, protocols::types::CommonResponse, ExtraRequestSettings, TimeoutSettings}; +use gamedig::{ + games::*, + protocols::types::{CommonResponse, ExtraRequestSettings, TimeoutSettings}, +}; mod error; @@ -35,6 +38,10 @@ struct Cli { #[arg(short, long, default_value = "generic")] output_mode: OutputMode, + #[cfg(feature = "packet_capture")] + #[arg(short, long)] + capture: Option, + /// Optional timeout settings for the server query. #[command(flatten, next_help_heading = "Timeouts")] timeout_settings: Option, @@ -82,14 +89,13 @@ fn find_game(game_id: &str) -> Result<&'static Game> { /// * `Result` - On sucess returns a resolved IP address; on failure /// returns an [Error::InvalidHostname] error. fn resolve_ip_or_domain(host: &str, extra_options: &mut Option) -> Result { - host.parse().map_or_else( - |_| { - set_hostname_if_missing(host, extra_options); + if let Ok(parsed_ip) = host.parse() { + Ok(parsed_ip) + } else { + set_hostname_if_missing(host, extra_options); - resolve_domain(host) - }, - Ok, - ) + resolve_domain(host) + } } /// Resolve a domain name to one of its IP addresses (the first one returned). @@ -175,6 +181,9 @@ fn main() -> Result<()> { // Resolve the IP address let ip = resolve_ip_or_domain(&args.ip, &mut extra_options)?; + #[cfg(feature = "packet_capture")] + gamedig::capture::setup_capture(args.capture); + // Query the server using game definition, parsed IP, and user command line // flags. let result = query_with_timeout_and_extra_settings(game, &ip, args.port, args.timeout_settings, extra_options)?; diff --git a/crates/lib/Cargo.toml b/crates/lib/Cargo.toml index 563aab8..8cc7cdb 100644 --- a/crates/lib/Cargo.toml +++ b/crates/lib/Cargo.toml @@ -17,12 +17,13 @@ rust-version = "1.65.0" categories = ["parser-implementations", "parsing", "network-programming", "encoding"] [features] -default = ["games", "services", "game_defs"] +default = ["games", "services", "game_defs", "packet_capture",] games = [] services = [] game_defs = ["dep:phf", "games"] serde = ["dep:serde", "serde/derive"] clap = ["dep:clap"] +packet_capture = ["dep:pcap-file", "dep:pnet_packet", "dep:lazy_static"] [dependencies] byteorder = "1.5" @@ -37,8 +38,9 @@ phf = { version = "0.11", optional = true, features = ["macros"] } clap = { version = "4.1.11", optional = true, features = ["derive"] } -[dev-dependencies] -gamedig-id-tests = { path = "../id-tests", no-default-features = true } +pcap-file = { version = "2.0", optional = true } +pnet_packet = { version = "0.34", optional = true } +lazy_static = { version = "1.4", optional = true } # Examples [[example]] diff --git a/crates/lib/src/capture/mod.rs b/crates/lib/src/capture/mod.rs new file mode 100644 index 0000000..b659dca --- /dev/null +++ b/crates/lib/src/capture/mod.rs @@ -0,0 +1,43 @@ +pub(crate) mod packet; +mod pcap; +pub mod writer; + +use pcap_file::pcapng::PcapNgBlock; +use writer::Writer; + +use self::pcap::Pcap; + +pub fn setup_capture(file_name: Option) { + if let Some(file_name) = file_name { + let file = std::fs::OpenOptions::new() + .create_new(true) + .write(true) + .open(file_name) + .unwrap(); + + let mut pcap_writer = pcap_file::pcapng::PcapNgWriter::new(file).unwrap(); + + // Write headers + pcap_writer.write_block( + &pcap_file::pcapng::blocks::interface_description::InterfaceDescriptionBlock { + linktype: pcap_file::DataLink::ETHERNET, + snaplen: 0xFFFF, + options: vec![], + } + .into_block(), + ); + + let writer = Box::new(Pcap::new(pcap_writer)); + attach(writer) + } else { + // Do nothing + } +} + +/// Attaches a writer to the capture module. +/// +/// # Errors +/// Returns an `io::Error` if the writer is already set. +fn attach(writer: Box) { + crate::socket::capture::set_writer(writer); +} diff --git a/crates/lib/src/capture/packet.rs b/crates/lib/src/capture/packet.rs new file mode 100644 index 0000000..4f726fc --- /dev/null +++ b/crates/lib/src/capture/packet.rs @@ -0,0 +1,254 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +/// Size of a standard network packet. +pub(crate) const PACKET_SIZE: usize = 5012; +/// Size of an Ethernet header. +pub(crate) const HEADER_SIZE_ETHERNET: usize = 14; +/// Size of an IPv4 header. +pub(crate) const HEADER_SIZE_IP4: usize = 20; +/// Size of an IPv6 header. +pub(crate) const HEADER_SIZE_IP6: usize = 40; +/// Size of a UDP header. +pub(crate) const HEADER_SIZE_UDP: usize = 4; + +/// Represents the direction of a network packet. +#[derive(Clone, Copy, Debug, PartialEq)] +pub(crate) enum Direction { + /// Packet is outgoing (sent by us). + Send, + /// Packet is incoming (received by us). + Receive, +} + +/// Defines the protocol of a network packet. +#[derive(Clone, Copy, Debug, PartialEq)] +pub(crate) enum Protocol { + /// Transmission Control Protocol. + TCP, + /// User Datagram Protocol. + UDP, +} + +/// Trait for handling different types of IP addresses (IPv4, IPv6). +pub trait IpAddress: Sized { + /// Creates an instance from a standard `IpAddr`, returning `None` if the types are incompatible. + fn from_std(ip: IpAddr) -> Option; +} + +/// Represents a captured network packet with metadata. +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct CapturePacket<'a> { + /// Direction of the packet (Send/Receive). + pub(crate) direction: Direction, + /// Protocol of the packet (TCP/UDP). + pub(crate) protocol: Protocol, + /// Remote socket address. + pub(crate) remote_address: &'a SocketAddr, + /// Local socket address. + pub(crate) local_address: &'a SocketAddr, +} + +impl CapturePacket<'_> { + /// Retrieves the local and remote ports based on the packet's direction. + /// + /// Returns: + /// - (u16, u16): Tuple of (source port, destination port). + pub(crate) const fn ports_by_direction(&self) -> (u16, u16) { + let (local, remote) = (self.local_address.port(), self.remote_address.port()); + self.direction.order(local, remote) + } + + /// Retrieves the local and remote IP addresses. + /// + /// Returns: + /// - (IpAddr, IpAddr): Tuple of (local IP, remote IP). + pub(crate) fn ip_addr(&self) -> (IpAddr, IpAddr) { + let (local, remote) = (self.local_address.ip(), self.remote_address.ip()); + (local, remote) + } + + /// Retrieves IP addresses based on the packet's direction. + /// + /// Returns: + /// - (IpAddr, IpAddr): Tuple of (source IP, destination IP). + pub(crate) fn ip_addr_by_direction(&self) -> (IpAddr, IpAddr) { + let (local, remote) = self.ip_addr(); + self.direction.order(local, remote) + } + + /// Retrieves IP addresses of a specific type (IPv4 or IPv6) based on the packet's direction. + /// + /// Panics if the IP type of the addresses does not match the requested type. + /// + /// Returns: + /// - (T, T): Tuple of (source IP, destination IP) of the specified type in order. + pub(crate) fn ipvt_by_direction(&self) -> (T, T) { + let (local, remote) = ( + T::from_std(self.local_address.ip()).expect("Incorrect IP type for local address"), + T::from_std(self.remote_address.ip()).expect("Incorrect IP type for remote address"), + ); + + self.direction.order(local, remote) + } +} + +impl Direction { + /// Orders two elements (source and destination) based on the packet's direction. + /// + /// Returns: + /// - (T, T): Ordered tuple (source, destination). + pub(crate) const fn order(&self, source: T, remote: T) -> (T, T) { + match self { + Direction::Send => (source, remote), + Direction::Receive => (remote, source), + } + } +} + +/// Implements the `IpAddress` trait for `Ipv4Addr`. +impl IpAddress for Ipv4Addr { + /// Creates an `Ipv4Addr` from a standard `IpAddr`, if it's IPv4. + fn from_std(ip: IpAddr) -> Option { + match ip { + IpAddr::V4(ipv4) => Some(ipv4), + _ => None, + } + } +} + +/// Implements the `IpAddress` trait for `Ipv6Addr`. +impl IpAddress for Ipv6Addr { + /// Creates an `Ipv6Addr` from a standard `IpAddr`, if it's IPv6. + fn from_std(ip: IpAddr) -> Option { + match ip { + IpAddr::V6(ipv6) => Some(ipv6), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str::FromStr; + + // Helper function to create a SocketAddr from a string + fn socket_addr(addr: &str) -> SocketAddr { + SocketAddr::from_str(addr).unwrap() + } + + #[test] + fn test_ports_by_direction() { + let packet_send = CapturePacket { + direction: Direction::Send, + protocol: Protocol::TCP, + local_address: &socket_addr("127.0.0.1:8080"), + remote_address: &socket_addr("192.168.1.1:80"), + }; + + let packet_receive = CapturePacket { + direction: Direction::Receive, + protocol: Protocol::TCP, + local_address: &socket_addr("127.0.0.1:8080"), + remote_address: &socket_addr("192.168.1.1:80"), + }; + + assert_eq!(packet_send.ports_by_direction(), (8080, 80)); + assert_eq!(packet_receive.ports_by_direction(), (80, 8080)); + } + + #[test] + fn test_ip_addr_by_direction_ipv4() { + let packet_send = CapturePacket { + direction: Direction::Send, + protocol: Protocol::UDP, + local_address: &socket_addr("10.0.0.1:3000"), + remote_address: &socket_addr("10.0.0.2:3001"), + }; + + let packet_receive = CapturePacket { + direction: Direction::Receive, + protocol: Protocol::UDP, + local_address: &socket_addr("10.0.0.1:3000"), + remote_address: &socket_addr("10.0.0.2:3001"), + }; + + assert_eq!( + packet_send.ip_addr_by_direction(), + ( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)) + ) + ); + assert_eq!( + packet_receive.ip_addr_by_direction(), + ( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)) + ) + ); + } + + #[test] + fn test_ip_addr_by_direction_ipv6() { + let packet_send = CapturePacket { + direction: Direction::Send, + protocol: Protocol::UDP, + local_address: &socket_addr("[::1]:3000"), + remote_address: &socket_addr("[::2]:3001"), + }; + + let packet_receive = CapturePacket { + direction: Direction::Receive, + protocol: Protocol::UDP, + local_address: &socket_addr("[::1]:3000"), + remote_address: &socket_addr("[::2]:3001"), + }; + + assert_eq!( + packet_send.ip_addr_by_direction(), + ( + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2)) + ) + ); + assert_eq!( + packet_receive.ip_addr_by_direction(), + ( + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2)), + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)) + ) + ); + } + + #[test] + fn test_ip_by_direction_type_specific() { + let packet = CapturePacket { + direction: Direction::Send, + protocol: Protocol::TCP, + local_address: &socket_addr("127.0.0.1:8080"), + remote_address: &socket_addr("192.168.1.1:80"), + }; + + let ipv4_result: Result<(Ipv4Addr, Ipv4Addr), _> = + std::panic::catch_unwind(|| packet.ipvt_by_direction::()); + assert!(ipv4_result.is_ok()); + + let ipv6_result: Result<(Ipv6Addr, Ipv6Addr), _> = + std::panic::catch_unwind(|| packet.ipvt_by_direction::()); + assert!(ipv6_result.is_err()); + } + + #[test] + #[should_panic(expected = "Local and remote IP addresses must be of the same version")] + fn test_mismatched_ip_version_panic() { + let packet = CapturePacket { + direction: Direction::Send, + protocol: Protocol::UDP, + local_address: &socket_addr("127.0.0.1:8080"), // IPv4 + remote_address: &socket_addr("[::1]:80"), // IPv6 + }; + + packet.ip_addr_by_direction(); + } +} diff --git a/crates/lib/src/capture/pcap.rs b/crates/lib/src/capture/pcap.rs new file mode 100644 index 0000000..054c8e3 --- /dev/null +++ b/crates/lib/src/capture/pcap.rs @@ -0,0 +1,311 @@ +use pcap_file::pcapng::{ + blocks::enhanced_packet::{EnhancedPacketBlock, EnhancedPacketOption}, + PcapNgBlock, PcapNgWriter, +}; +use pnet_packet::{ + ethernet::{EtherType, EtherTypes, MutableEthernetPacket}, + ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}, + ipv4::MutableIpv4Packet, + ipv6::MutableIpv6Packet, + tcp::{MutableTcpPacket, TcpFlags}, + udp::MutableUdpPacket, + PacketSize, +}; +use std::{ + io::Write, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + time::Instant, +}; + +use super::packet::{ + CapturePacket, Direction, Protocol, HEADER_SIZE_ETHERNET, HEADER_SIZE_IP4, HEADER_SIZE_IP6, HEADER_SIZE_UDP, + PACKET_SIZE, +}; + +const DEFAULT_TTL: u8 = 64; +const TCP_WINDOW_SIZE: u16 = 43440; +const BUFFER_SIZE: usize = PACKET_SIZE + - (if HEADER_SIZE_IP4 > HEADER_SIZE_IP6 { + HEADER_SIZE_IP4 + } else { + HEADER_SIZE_IP6 + }) + - HEADER_SIZE_ETHERNET; + +pub(crate) struct Pcap { + writer: PcapNgWriter, + pub(crate) state: State, + buffer: Vec, +} + +pub(crate) struct State { + pub(crate) start_time: Instant, + pub(crate) send_seq: u32, + pub(crate) rec_seq: u32, + pub(crate) has_sent_handshake: bool, + pub(crate) has_sent_fin: bool, + pub(crate) stream_count: u32, +} + +impl Pcap { + pub fn new(writer: PcapNgWriter) -> Self { + Self { + writer, + state: State::default(), + buffer: vec![0; BUFFER_SIZE], + } + } + + pub fn write_transport_packet(&mut self, info: &CapturePacket, payload: &[u8]) -> Result<(), std::io::Error> { + let (source_port, dest_port) = info.ports_by_direction(); + + match info.protocol { + Protocol::TCP => self.handle_tcp(info, payload, source_port, dest_port)?, + Protocol::UDP => self.handle_udp(info, payload, source_port, dest_port)?, + } + + Ok(()) + } + + fn handle_tcp( + &mut self, + info: &CapturePacket, + payload: &[u8], + source_port: u16, + dest_port: u16, + ) -> Result<(), std::io::Error> { + let buf_size = self.setup_tcp_packet(info, payload, source_port, dest_port)?; + self.write_transport_payload( + info, + IpNextHeaderProtocols::Tcp, + &self.buffer[..buf_size + payload.len()], + vec![], + ); + + Ok(()) + } + + fn setup_tcp_packet( + &mut self, + info: &CapturePacket, + payload: &[u8], + source_port: u16, + dest_port: u16, + ) -> Result { + let mut tcp = MutableTcpPacket::new(&mut self.buffer) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Failed to create TCP packet"))?; + tcp.set_source(source_port); + tcp.set_destination(dest_port); + tcp.set_payload(payload); + tcp.set_data_offset(5); + tcp.set_window(TCP_WINDOW_SIZE); + + // Set sequence and acknowledgement numbers + match info.direction { + Direction::Send => { + tcp.set_sequence(self.state.send_seq); + tcp.set_acknowledgement(self.state.rec_seq); + self.state.send_seq = self.state.send_seq.wrapping_add(payload.len() as u32); + } + Direction::Receive => { + tcp.set_sequence(self.state.rec_seq); + tcp.set_acknowledgement(self.state.send_seq); + self.state.rec_seq = self.state.rec_seq.wrapping_add(payload.len() as u32); + } + } + tcp.set_flags(TcpFlags::PSH | TcpFlags::ACK); + + Ok(tcp.packet_size()) + } + + pub fn write_tcp_handshake(&mut self, info: &CapturePacket) { + // Initialize sequence numbers for demonstration purposes + self.state.send_seq = 500; + self.state.rec_seq = 1000; + + // Common setup for TCP handshake packets + let mut tcp_handshake_packet = + |info: &CapturePacket, direction: Direction, flags: u8| -> Result<(), std::io::Error> { + let (source_port, dest_port) = info.ports_by_direction(); + let adjusted_info = CapturePacket { + direction, + ..info.clone() + }; + self.setup_tcp_packet(&adjusted_info, &[], source_port, dest_port)?; + Ok(self.write_transport_payload( + &adjusted_info, + IpNextHeaderProtocols::Tcp, + &self.buffer, + vec![EnhancedPacketOption::Comment( + format!( + "Generated TCP {}", + match flags { + TcpFlags::SYN => "SYN", + TcpFlags::SYN | TcpFlags::ACK => "SYN-ACK", + TcpFlags::ACK => "ACK", + } + ) + .into(), + )], + )) + }; + + // Send SYN + tcp_handshake_packet(info, Direction::Send, TcpFlags::SYN); + + // Send SYN-ACK + self.state.send_seq = self.state.send_seq.wrapping_add(1); // Update sequence number after SYN + tcp_handshake_packet(info, Direction::Receive, TcpFlags::SYN | TcpFlags::ACK); + + // Send ACK + self.state.rec_seq = self.state.rec_seq.wrapping_add(1); // Update sequence number after SYN-ACK + tcp_handshake_packet(info, Direction::Send, TcpFlags::ACK); + } + + fn handle_udp( + &mut self, + info: &CapturePacket, + payload: &[u8], + source_port: u16, + dest_port: u16, + ) -> Result<(), std::io::Error> { + let mut udp = MutableUdpPacket::new(&mut self.buffer) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Failed to create UDP packet"))?; + udp.set_source(source_port); + udp.set_destination(dest_port); + udp.set_length((payload.len() + HEADER_SIZE_UDP) as u16); + udp.set_payload(payload); + + let buf_size = udp.packet_size(); + self.write_transport_payload( + info, + IpNextHeaderProtocols::Udp, + &self.buffer[..buf_size + payload.len()], + vec![], + ); + + Ok(()) + } + + fn write_transport_payload( + &mut self, + info: &CapturePacket, + protocol: IpNextHeaderProtocol, + payload: &[u8], + options: Vec, + ) { + let network_packet_size = self.encode_ip_packet(info, protocol, payload).unwrap().0; + let ethertype = self.encode_ip_packet(info, protocol, payload).unwrap().1; + let ethernet_packet_size = self.encode_ethernet_packet(info, ethertype, &self.buffer[..network_packet_size]).unwrap(); + + let enhanced_packet_block = EnhancedPacketBlock { + original_len: ethernet_packet_size as u32, + data: self.buffer[..ethernet_packet_size].to_vec().into(), + interface_id: 0, + timestamp: self.state.start_time.elapsed(), + options, + }; + + self.writer.write_block(&enhanced_packet_block.into_block()); + } + + fn encode_ethernet_packet( + &mut self, + info: &CapturePacket, + ethertype: EtherType, + payload: &[u8], + ) -> Result { + let mut ethernet_packet = MutableEthernetPacket::new(&mut self.buffer).ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to create Ethernet packet", + ) + })?; + + ethernet_packet.set_ethertype(ethertype); + ethernet_packet.set_payload(payload); + + Ok(ethernet_packet.packet_size()) + } + + fn encode_ip_packet( + &mut self, + info: &CapturePacket, + protocol: IpNextHeaderProtocol, + payload: &[u8], + ) -> Result<(usize, EtherType), std::io::Error> { + match info.ip_addr() { + (IpAddr::V4(_), IpAddr::V4(_)) => { + let (source, destination) = info.ipvt_by_direction(); + + let mut ip_packet = MutableIpv4Packet::new(&mut self.buffer) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Failed to create IPv4 packet"))?; + self.set_ipv4_packet_fields(&mut ip_packet, source, destination, payload, protocol); + ip_packet.set_checksum(pnet_packet::ipv4::checksum(&ip_packet.to_immutable())); + + Ok((ip_packet.packet_size(), EtherTypes::Ipv4)) + } + (IpAddr::V6(_), IpAddr::V6(_)) => { + let (source, destination) = info.ipvt_by_direction(); + + let mut ip_packet = MutableIpv6Packet::new(&mut self.buffer) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Failed to create IPv6 packet"))?; + self.set_ipv6_packet_fields(&mut ip_packet, source, destination, payload, protocol); + + Ok((ip_packet.packet_size(), EtherTypes::Ipv6)) + } + _ => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Unsupported or mismatched IP address types", + )), + } + } + + fn set_ipv4_packet_fields( + &mut self, + ip_packet: &mut MutableIpv4Packet, + source: Ipv4Addr, + destination: Ipv4Addr, + payload: &[u8], + protocol: IpNextHeaderProtocol, + ) { + ip_packet.set_version(4); + ip_packet.set_header_length(5); // No options + ip_packet.set_total_length((payload.len() + HEADER_SIZE_IP4) as u16); + ip_packet.set_next_level_protocol(protocol); + ip_packet.set_source(source); + ip_packet.set_destination(destination); + ip_packet.set_ttl(DEFAULT_TTL); + ip_packet.set_payload(payload); + } + + fn set_ipv6_packet_fields( + &mut self, + ip_packet: &mut MutableIpv6Packet, + source: Ipv6Addr, + destination: Ipv6Addr, + payload: &[u8], + protocol: IpNextHeaderProtocol, + ) { + ip_packet.set_version(6); + ip_packet.set_payload_length(payload.len() as u16); + ip_packet.set_next_header(protocol); + ip_packet.set_source(source); + ip_packet.set_destination(destination); + ip_packet.set_hop_limit(DEFAULT_TTL); + ip_packet.set_payload(payload); + } +} + +impl Default for State { + fn default() -> Self { + Self { + start_time: Instant::now(), + send_seq: 0, + rec_seq: 0, + has_sent_handshake: false, + has_sent_fin: false, + stream_count: 0, + } + } +} diff --git a/crates/lib/src/capture/writer.rs b/crates/lib/src/capture/writer.rs new file mode 100644 index 0000000..012b700 --- /dev/null +++ b/crates/lib/src/capture/writer.rs @@ -0,0 +1,41 @@ +use std::io::Write; + +use crate::{ + capture::packet::{CapturePacket, Protocol}, + GDResult, +}; + +use super::pcap::Pcap; + +use lazy_static::lazy_static; +use std::sync::Mutex; + +lazy_static! { + pub(crate) static ref CAPTURE_WRITER: Mutex>> = Mutex::new(None); +} + +pub trait Writer { + fn write(&mut self, packet: &CapturePacket, data: &[u8]) -> crate::GDResult<()>; + fn new_connect(&mut self, packet: &CapturePacket) -> crate::GDResult<()>; +} + +impl Writer for Pcap { + fn write(&mut self, info: &CapturePacket, data: &[u8]) -> GDResult<()> { + self.write_transport_packet(info, data); + + Ok(()) + } + + fn new_connect(&mut self, packet: &CapturePacket) -> GDResult<()> { + match packet.protocol { + Protocol::TCP => { + self.write_tcp_handshake(packet); + } + Protocol::UDP => {} + } + + self.state.stream_count = self.state.stream_count.wrapping_add(1); + + Ok(()) + } +} diff --git a/crates/lib/src/lib.rs b/crates/lib/src/lib.rs index b17777c..678b0a7 100644 --- a/crates/lib/src/lib.rs +++ b/crates/lib/src/lib.rs @@ -45,13 +45,11 @@ mod buffer; mod socket; mod utils; +#[cfg(feature = "packet_capture")] +pub(crate) mod capture; + pub use errors::*; #[cfg(feature = "games")] pub use games::*; -#[cfg(feature = "games")] -pub use query::*; #[cfg(feature = "services")] pub use services::*; - -// Re-export types needed to call games::query::query in the root -pub use protocols::types::{ExtraRequestSettings, TimeoutSettings}; diff --git a/crates/lib/src/socket.rs b/crates/lib/src/socket.rs index 223c94c..f0fe3d3 100644 --- a/crates/lib/src/socket.rs +++ b/crates/lib/src/socket.rs @@ -12,32 +12,75 @@ use std::{ const DEFAULT_PACKET_SIZE: usize = 1024; +/// A trait defining the basic functionalities of a network socket. pub trait Socket { - /// Create a new socket and connect to the remote address (if required). + /// Create a new socket and connect to the remote address. /// - /// Calls [Self::apply_timeout] with the given timeout settings. + /// # Arguments + /// * `address` - The address to connect the socket to. + /// * `timeout_settings` - Optional timeout settings for the socket. + /// + /// # Returns + /// A result containing the socket instance or an error. fn new(address: &SocketAddr, timeout_settings: &Option) -> GDResult - where Self: Sized; + where + Self: Sized; + /// Apply read and write timeouts to the socket. + /// + /// # Arguments + /// * `timeout_settings` - Optional timeout settings to apply. + /// + /// # Returns + /// A result indicating success or error in applying timeouts. fn apply_timeout(&self, timeout_settings: &Option) -> GDResult<()>; + /// Send data over the socket. + /// + /// # Arguments + /// * `data` - Data to be sent. + /// + /// # Returns + /// A result indicating success or error in sending data. fn send(&mut self, data: &[u8]) -> GDResult<()>; + + /// Receive data from the socket. + /// + /// # Arguments + /// * `size` - Optional size of data to receive. + /// + /// # Returns + /// A result containing received data or an error. fn receive(&mut self, size: Option) -> GDResult>; + /// Get the remote port of the socket. + /// + /// # Returns + /// The port number. fn port(&self) -> u16; + + /// Get the local SocketAddr. + /// + /// # Returns + /// The local SocketAddr. + fn local_addr(&self) -> std::io::Result; } -pub struct TcpSocket { +/// Implementation of a TCP socket. +pub struct TcpSocketImpl { + /// The underlying TCP socket stream. socket: net::TcpStream, + /// The address of the remote host. address: SocketAddr, } -impl Socket for TcpSocket { +impl Socket for TcpSocketImpl { fn new(address: &SocketAddr, timeout_settings: &Option) -> GDResult { - let socket = TimeoutSettings::get_connect_or_default(timeout_settings).map_or_else( - || net::TcpStream::connect(address), - |timeout| net::TcpStream::connect_timeout(address, timeout), - ); + let socket = if let Some(timeout) = TimeoutSettings::get_connect_or_default(timeout_settings) { + net::TcpStream::connect_timeout(address, timeout) + } else { + net::TcpStream::connect(address) + }; let socket = Self { socket: socket.map_err(|e| SocketConnect.context(e))?, @@ -71,15 +114,23 @@ impl Socket for TcpSocket { Ok(buf) } - fn port(&self) -> u16 { self.address.port() } + fn port(&self) -> u16 { + self.address.port() + } + fn local_addr(&self) -> std::io::Result { + self.socket.local_addr() + } } -pub struct UdpSocket { +/// Implementation of a UDP socket. +pub struct UdpSocketImpl { + /// The underlying UDP socket. socket: net::UdpSocket, + /// The address of the remote host. address: SocketAddr, } -impl Socket for UdpSocket { +impl Socket for UdpSocketImpl { fn new(address: &SocketAddr, timeout_settings: &Option) -> GDResult { let socket = net::UdpSocket::bind("0.0.0.0:0").map_err(|e| SocketBind.context(e))?; @@ -116,12 +167,234 @@ impl Socket for UdpSocket { .recv_from(&mut buf) .map_err(|e| PacketReceive.context(e))?; - Ok(buf[.. number_of_bytes_received].to_vec()) + Ok(buf[..number_of_bytes_received].to_vec()) } - fn port(&self) -> u16 { self.address.port() } + fn port(&self) -> u16 { + self.address.port() + } + fn local_addr(&self) -> std::io::Result { + self.socket.local_addr() + } } +/// Things used for capturing packets. +#[cfg(feature = "packet_capture")] +pub mod capture { + use std::{marker::PhantomData, net::SocketAddr}; + + use super::{Socket, TcpSocketImpl, UdpSocketImpl}; + + use crate::{ + capture::{ + packet::CapturePacket, + packet::{Direction, Protocol}, + writer::{Writer, CAPTURE_WRITER}, + }, + protocols::types::TimeoutSettings, + GDResult, + }; + + /// Sets a global capture writer for handling all packet data. + /// + /// # Panics + /// Panics if a capture writer is already set. + /// + /// # Arguments + /// * `writer` - A boxed writer that implements the `Writer` trait. + pub(crate) fn set_writer(writer: Box) { + let mut lock = CAPTURE_WRITER.lock().unwrap(); + + if lock.is_some() { + panic!("Capture writer already set"); + } + + *lock = Some(writer); + } + + /// A trait representing a provider of a network protocol. + pub trait ProtocolProvider { + /// Returns the protocol used by the provider. + fn protocol() -> Protocol; + } + + /// Represents the TCP protocol provider. + pub struct ProtocolTCP; + impl ProtocolProvider for ProtocolTCP { + fn protocol() -> Protocol { + Protocol::TCP + } + } + + /// Represents the UDP protocol provider. + pub struct ProtocolUDP; + impl ProtocolProvider for ProtocolUDP { + fn protocol() -> Protocol { + Protocol::UDP + } + } + + /// A socket wrapper that allows capturing packets. + /// + /// # Type parameters + /// * `I` - The inner socket type. + /// * `P` - The protocol provider. + #[derive(Clone, Debug)] + pub struct WrappedCaptureSocket { + inner: I, + remote_address: SocketAddr, + _protocol: PhantomData

, + } + + impl Socket for WrappedCaptureSocket { + /// Creates a new wrapped socket for capturing packets. + /// + /// Initializes a new socket of type `I`, wrapping it to enable packet capturing. + /// Capturing is protocol-specific, as indicated by the `ProtocolProvider`. + /// + /// # Arguments + /// * `address` - The address to connect the socket to. + /// * `timeout_settings` - Optional timeout settings for the socket. + /// + /// # Returns + /// A `GDResult` containing either the wrapped socket or an error. + fn new(address: &SocketAddr, timeout_settings: &Option) -> GDResult + where + Self: Sized, + { + let v = Self { + inner: I::new(address, timeout_settings)?, + remote_address: *address, + _protocol: PhantomData, + }; + + let info = CapturePacket { + direction: Direction::Send, + protocol: P::protocol(), + remote_address: address, + local_address: &v.local_addr().unwrap(), + }; + + if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() { + writer.new_connect(&info)?; + } + + Ok(v) + } + + /// Sends data over the socket and captures the packet. + /// + /// The method sends data using the inner socket and captures the sent packet + /// if a capture writer is set. + /// + /// # Arguments + /// * `data` - Data to be sent. + /// + /// # Returns + /// A result indicating success or error in sending data. + fn send(&mut self, data: &[u8]) -> crate::GDResult<()> { + let info = CapturePacket { + direction: Direction::Send, + protocol: P::protocol(), + remote_address: &self.remote_address, + local_address: &self.local_addr().unwrap(), + }; + + if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() { + writer.write(&info, data)?; + } + + self.inner.send(data) + } + + /// Receives data from the socket and captures the packet. + /// + /// The method receives data using the inner socket and captures the incoming packet + /// if a capture writer is set. + /// + /// # Arguments + /// * `size` - Optional size of data to receive. + /// + /// # Returns + /// A result containing received data or an error. + fn receive(&mut self, size: Option) -> crate::GDResult> { + let data = self.inner.receive(size)?; + let info = CapturePacket { + direction: Direction::Receive, + protocol: P::protocol(), + remote_address: &self.remote_address, + local_address: &self.local_addr().unwrap(), + }; + + if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() { + writer.write(&info, &data)?; + } + + Ok(data) + } + + /// Applies timeout settings to the wrapped socket. + /// + /// Delegates the operation to the inner socket implementation. + /// + /// # Arguments + /// * `timeout_settings` - Optional timeout settings to apply. + /// + /// # Returns + /// A result indicating success or error in applying timeouts. + fn apply_timeout( + &self, + timeout_settings: &Option, + ) -> crate::GDResult<()> { + self.inner.apply_timeout(timeout_settings) + } + + /// Returns the remote port of the wrapped socket. + /// + /// Delegates the operation to the inner socket implementation. + /// + /// # Returns + /// The remote port number. + fn port(&self) -> u16 { + self.inner.port() + } + + /// Returns the local SocketAddr of the wrapped socket. + /// + /// Delegates the operation to the inner socket implementation. + /// + /// # Returns + /// The local SocketAddr. + fn local_addr(&self) -> std::io::Result { + self.inner.local_addr() + } + } + + /// A specialized `WrappedCaptureSocket` for UDP, using `UdpSocketImpl` as the inner socket + /// and `ProtocolUDP` as the protocol provider. + /// + /// This type captures and processes UDP packets, wrapping around standard UDP socket + /// functionalities with additional packet capture capabilities. + pub type CapturedUdpSocket = WrappedCaptureSocket; + + /// A specialized `WrappedCaptureSocket` for TCP, using `TcpSocketImpl` as the inner socket + /// and `ProtocolTCP` as the protocol provider. + /// + /// This type captures and processes TCP packets, wrapping around standard TCP socket + /// functionalities with additional packet capture capabilities. + pub type CapturedTcpSocket = WrappedCaptureSocket; +} + +#[cfg(not(feature = "packet_capture"))] +pub type UdpSocket = UdpSocketImpl; +#[cfg(not(feature = "packet_capture"))] +pub type TcpSocket = TcpSocketImpl; + +#[cfg(feature = "packet_capture")] +pub type UdpSocket = capture::CapturedUdpSocket; +#[cfg(feature = "packet_capture")] +pub type TcpSocket = capture::CapturedTcpSocket; + #[cfg(test)] mod tests { use std::thread;