feat: Add packet capture functionality and many more CLI improvements (#182)

This commit is contained in:
Cain 2024-02-07 22:31:31 +00:00 committed by GitHub
commit e86e80522b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1445 additions and 82 deletions

View file

@ -0,0 +1,39 @@
pub(crate) mod packet;
mod pcap;
pub(crate) mod socket;
pub(crate) mod writer;
use self::{pcap::Pcap, writer::Writer};
use pcap_file::pcapng::{blocks::interface_description::InterfaceDescriptionBlock, PcapNgBlock, PcapNgWriter};
use std::path::PathBuf;
pub fn setup_capture(file_path: Option<PathBuf>) {
if let Some(file_path) = file_path {
let file = std::fs::OpenOptions::new()
.create_new(true)
.write(true)
.open(file_path.with_extension("pcap"))
.unwrap();
let mut pcap_writer = PcapNgWriter::new(file).unwrap();
// Write headers
let _ = pcap_writer.write_block(
&InterfaceDescriptionBlock {
linktype: pcap_file::DataLink::ETHERNET,
snaplen: 0xFFFF,
options: vec![],
}
.into_block(),
);
let writer = Box::new(Pcap::new(pcap_writer));
attach(writer)
}
}
/// Attaches a writer to the capture module.
///
/// # Errors
/// Returns an Error if the writer is already set.
fn attach(writer: Box<dyn Writer + Send + Sync>) { crate::capture::socket::set_writer(writer); }

View file

@ -0,0 +1,203 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
/// Size of a standard network packet.
pub(crate) const PACKET_SIZE: usize = 5012;
/// Size of an Ethernet header.
pub(crate) const HEADER_SIZE_ETHERNET: usize = 14;
/// Size of an IPv4 header.
pub(crate) const HEADER_SIZE_IP4: usize = 20;
/// Size of an IPv6 header.
pub(crate) const HEADER_SIZE_IP6: usize = 40;
/// Size of a UDP header.
pub(crate) const HEADER_SIZE_UDP: usize = 4;
/// Represents the direction of a network packet.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum Direction {
/// Packet is outgoing (sent by us).
Send,
/// Packet is incoming (received by us).
Receive,
}
/// Defines the protocol of a network packet.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum Protocol {
/// Transmission Control Protocol.
Tcp,
/// User Datagram Protocol.
Udp,
}
/// Trait for handling different types of IP addresses (IPv4, IPv6).
pub(crate) trait IpAddress: Sized {
/// Creates an instance from a standard `IpAddr`, returning `None` if the
/// types are incompatible.
fn from_std(ip: IpAddr) -> Option<Self>;
}
/// Represents a captured network packet with metadata.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct CapturePacket<'a> {
/// Direction of the packet (Send/Receive).
pub(crate) direction: Direction,
/// Protocol of the packet (Tcp/UDP).
pub(crate) protocol: Protocol,
/// Remote socket address.
pub(crate) remote_address: &'a SocketAddr,
/// Local socket address.
pub(crate) local_address: &'a SocketAddr,
}
impl CapturePacket<'_> {
/// Retrieves the local and remote ports based on the packet's direction.
///
/// Returns:
/// - (u16, u16): Tuple of (source port, destination port).
pub(super) fn ports_by_direction(&self) -> (u16, u16) {
let (local, remote) = (self.local_address.port(), self.remote_address.port());
self.direction.order(local, remote)
}
/// Retrieves the local and remote IP addresses.
///
/// Returns:
/// - (IpAddr, IpAddr): Tuple of (local IP, remote IP).
pub(super) fn ip_addr(&self) -> (IpAddr, IpAddr) {
let (local, remote) = (self.local_address.ip(), self.remote_address.ip());
(local, remote)
}
/// Retrieves IP addresses of a specific type (IPv4 or IPv6) based on the
/// packet's direction.
///
/// Panics if the IP type of the addresses does not match the requested
/// type.
///
/// Returns:
/// - (T, T): Tuple of (source IP, destination IP) of the specified type in
/// order.
pub(super) fn ipvt_by_direction<T: IpAddress>(&self) -> (T, T) {
let (local, remote) = (
T::from_std(self.local_address.ip()).expect("Incorrect IP type for local address"),
T::from_std(self.remote_address.ip()).expect("Incorrect IP type for remote address"),
);
self.direction.order(local, remote)
}
}
impl Direction {
/// Orders two elements (source and destination) based on the packet's
/// direction.
///
/// Returns:
/// - (T, T): Ordered tuple (source, destination).
pub(self) const fn order<T>(&self, source: T, remote: T) -> (T, T) {
match self {
Direction::Send => (source, remote),
Direction::Receive => (remote, source),
}
}
}
/// Implements the `IpAddress` trait for `Ipv4Addr`.
impl IpAddress for Ipv4Addr {
/// Creates an `Ipv4Addr` from a standard `IpAddr`, if it's IPv4.
fn from_std(ip: IpAddr) -> Option<Self> {
match ip {
IpAddr::V4(ipv4) => Some(ipv4),
_ => None,
}
}
}
/// Implements the `IpAddress` trait for `Ipv6Addr`.
impl IpAddress for Ipv6Addr {
/// Creates an `Ipv6Addr` from a standard `IpAddr`, if it's IPv6.
fn from_std(ip: IpAddr) -> Option<Self> {
match ip {
IpAddr::V6(ipv6) => Some(ipv6),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
// Helper function to create a SocketAddr from a string
fn socket_addr(addr: &str) -> SocketAddr { SocketAddr::from_str(addr).unwrap() }
#[test]
fn test_ports_by_direction() {
let packet_send = CapturePacket {
direction: Direction::Send,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
let packet_receive = CapturePacket {
direction: Direction::Receive,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
assert_eq!(packet_send.ports_by_direction(), (8080, 80));
assert_eq!(packet_receive.ports_by_direction(), (80, 8080));
}
#[test]
fn test_ip_addr() {
let packet_send = CapturePacket {
direction: Direction::Send,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
let packet_receive = CapturePacket {
direction: Direction::Receive,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
assert_eq!(
packet_send.ip_addr(),
(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
)
);
assert_eq!(
packet_receive.ip_addr(),
(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
)
);
}
#[test]
fn test_ip_by_direction_type_specific() {
let packet = CapturePacket {
direction: Direction::Send,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
let ipv4_result: Result<(Ipv4Addr, Ipv4Addr), _> =
std::panic::catch_unwind(|| packet.ipvt_by_direction::<Ipv4Addr>());
assert!(ipv4_result.is_ok());
let ipv6_result: Result<(Ipv6Addr, Ipv6Addr), _> =
std::panic::catch_unwind(|| packet.ipvt_by_direction::<Ipv6Addr>());
assert!(ipv6_result.is_err());
}
}

View file

@ -0,0 +1,383 @@
use pcap_file::pcapng::{blocks::enhanced_packet::EnhancedPacketOption, PcapNgBlock, PcapNgWriter};
use pnet_packet::{
ethernet::{EtherType, MutableEthernetPacket},
ip::{IpNextHeaderProtocol, IpNextHeaderProtocols},
ipv4::MutableIpv4Packet,
ipv6::MutableIpv6Packet,
tcp::{MutableTcpPacket, TcpFlags},
udp::MutableUdpPacket,
PacketSize,
};
use std::{io::Write, net::IpAddr, time::Instant};
use super::packet::{
CapturePacket,
Direction,
Protocol,
HEADER_SIZE_ETHERNET,
HEADER_SIZE_IP4,
HEADER_SIZE_IP6,
HEADER_SIZE_UDP,
PACKET_SIZE,
};
const BUFFER_SIZE: usize = PACKET_SIZE - HEADER_SIZE_IP6 - HEADER_SIZE_ETHERNET;
pub(crate) struct Pcap<W: Write> {
writer: PcapNgWriter<W>,
pub(crate) state: State,
}
pub(crate) struct State {
pub(crate) start_time: Instant,
pub(crate) send_seq: u32,
pub(crate) rec_seq: u32,
pub(crate) has_sent_handshake: bool,
pub(crate) stream_count: u32,
}
impl<W: Write> Pcap<W> {
pub(crate) fn new(writer: PcapNgWriter<W>) -> Self {
Self {
writer,
state: State::default(),
}
}
pub(crate) fn write_transport_packet(&mut self, info: &CapturePacket, payload: &[u8]) {
let mut buffer_array: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE];
let buf: &mut [u8] = &mut buffer_array[..];
let (source_port, dest_port) = info.ports_by_direction();
match info.protocol {
Protocol::Tcp => {
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
tcp.set_source(source_port);
tcp.set_destination(dest_port);
tcp.set_payload(payload);
tcp.set_data_offset(5);
tcp.set_window(43440);
match info.direction {
Direction::Send => {
tcp.set_sequence(self.state.send_seq);
tcp.set_acknowledgement(self.state.rec_seq);
self.state.send_seq = self.state.send_seq.wrapping_add(payload.len() as u32);
}
Direction::Receive => {
tcp.set_sequence(self.state.rec_seq);
tcp.set_acknowledgement(self.state.send_seq);
self.state.rec_seq = self.state.rec_seq.wrapping_add(payload.len() as u32);
}
}
tcp.set_flags(TcpFlags::PSH | TcpFlags::ACK);
tcp.packet_size()
};
self.write_transport_payload(
info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size + payload.len()],
vec![],
);
let mut info = info.clone();
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
tcp.set_source(dest_port);
tcp.set_destination(source_port);
tcp.set_data_offset(5);
tcp.set_window(43440);
match &info.direction {
Direction::Send => {
tcp.set_sequence(self.state.rec_seq);
tcp.set_acknowledgement(self.state.send_seq);
info.direction = Direction::Receive;
}
Direction::Receive => {
tcp.set_sequence(self.state.send_seq);
tcp.set_acknowledgement(self.state.rec_seq);
info.direction = Direction::Send;
}
}
tcp.set_flags(TcpFlags::ACK);
tcp.packet_size()
};
self.write_transport_payload(
&info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
vec![EnhancedPacketOption::Comment("Generated TCP ACK".into())],
);
}
Protocol::Udp => {
let buf_size = {
let mut udp = MutableUdpPacket::new(buf).unwrap();
udp.set_source(source_port);
udp.set_destination(dest_port);
udp.set_length((payload.len() + HEADER_SIZE_UDP) as u16);
udp.set_payload(payload);
udp.packet_size()
};
self.write_transport_payload(
info,
IpNextHeaderProtocols::Udp,
&buf[.. buf_size + payload.len()],
vec![],
);
}
}
}
/// Encode a network layer (IP) packet with a payload.
fn encode_ip_packet(
&self,
buf: &mut [u8],
info: &CapturePacket,
protocol: IpNextHeaderProtocol,
payload: &[u8],
) -> (usize, EtherType) {
match info.ip_addr() {
(IpAddr::V4(_), IpAddr::V4(_)) => {
let (source, destination) = info.ipvt_by_direction();
let header_size = HEADER_SIZE_IP4 + (32 / 8);
let mut ip = MutableIpv4Packet::new(buf).unwrap();
ip.set_version(4);
ip.set_total_length((payload.len() + header_size) as u16);
ip.set_next_level_protocol(protocol);
// https://en.wikipedia.org/wiki/Internet_Protocol_version_4#Total_Length
ip.set_header_length((header_size / 4) as u8);
ip.set_source(source);
ip.set_destination(destination);
ip.set_payload(payload);
ip.set_ttl(64);
ip.set_flags(pnet_packet::ipv4::Ipv4Flags::DontFragment);
let mut options_writer =
pnet_packet::ipv4::MutableIpv4OptionPacket::new(ip.get_options_raw_mut()).unwrap();
options_writer.set_copied(1);
options_writer.set_class(0);
options_writer.set_number(pnet_packet::ipv4::Ipv4OptionNumbers::SID);
options_writer.set_length(&[4]);
options_writer.set_data(&(self.state.stream_count as u16).to_be_bytes());
ip.set_checksum(pnet_packet::ipv4::checksum(&ip.to_immutable()));
(ip.packet_size(), pnet_packet::ethernet::EtherTypes::Ipv4)
}
(IpAddr::V6(_), IpAddr::V6(_)) => {
let (source, destination) = info.ipvt_by_direction();
let mut ip = MutableIpv6Packet::new(buf).unwrap();
ip.set_version(6);
ip.set_payload_length(payload.len() as u16);
ip.set_next_header(protocol);
ip.set_source(source);
ip.set_destination(destination);
ip.set_hop_limit(64);
ip.set_payload(payload);
ip.set_flow_label(self.state.stream_count);
(ip.packet_size(), pnet_packet::ethernet::EtherTypes::Ipv6)
}
_ => unreachable!(),
}
}
/// Encode a physical layer (ethernet) packet with a payload.
fn encode_ethernet_packet(
&self,
buf: &mut [u8],
ethertype: pnet_packet::ethernet::EtherType,
payload: &[u8],
) -> usize {
let mut ethernet = MutableEthernetPacket::new(buf).unwrap();
ethernet.set_ethertype(ethertype);
ethernet.set_payload(payload);
ethernet.packet_size()
}
/// Write a TCP handshake.
pub(crate) fn write_tcp_handshake(&mut self, info: &CapturePacket) {
let (source_port, dest_port) = (info.local_address.port(), info.remote_address.port());
let mut info = info.clone();
info.direction = Direction::Send;
let mut buffer_array: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE];
let buf: &mut [u8] = &mut buffer_array[..];
// Add a generated comment to all packets
let options = vec![
pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketOption::Comment("Generated TCP handshake".into()),
];
// SYN
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
self.state.send_seq = 500;
tcp.set_sequence(self.state.send_seq);
tcp.set_flags(TcpFlags::SYN);
tcp.set_source(source_port);
tcp.set_destination(dest_port);
tcp.set_window(43440);
tcp.set_data_offset(5);
tcp.packet_size()
};
self.write_transport_payload(
&info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
options.clone(),
);
// SYN + ACK
info.direction = Direction::Receive;
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
self.state.send_seq = self.state.send_seq.wrapping_add(1);
tcp.set_acknowledgement(self.state.send_seq);
self.state.rec_seq = 1000;
tcp.set_sequence(self.state.rec_seq);
tcp.set_flags(TcpFlags::SYN | TcpFlags::ACK);
tcp.set_source(dest_port);
tcp.set_destination(source_port);
tcp.set_window(43440);
tcp.set_data_offset(5);
tcp.packet_size()
};
self.write_transport_payload(
&info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
options.clone(),
);
// ACK
info.direction = Direction::Send;
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
tcp.set_sequence(self.state.send_seq);
self.state.rec_seq = self.state.rec_seq.wrapping_add(1);
tcp.set_acknowledgement(self.state.rec_seq);
tcp.set_flags(TcpFlags::ACK);
tcp.set_source(source_port);
tcp.set_destination(dest_port);
tcp.set_window(43440);
tcp.set_data_offset(5);
tcp.packet_size()
};
self.write_transport_payload(
&info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
options,
);
self.state.has_sent_handshake = true;
}
pub(crate) fn send_tcp_fin(&mut self, info: &CapturePacket) {
let mut buffer_array: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE];
let buf: &mut [u8] = &mut buffer_array[..];
let (source_port, dest_port) = info.ports_by_direction();
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
tcp.set_source(source_port);
tcp.set_destination(dest_port);
tcp.set_data_offset(5);
tcp.set_window(43440);
match info.direction {
Direction::Send => {
tcp.set_sequence(self.state.send_seq);
tcp.set_acknowledgement(self.state.rec_seq);
}
Direction::Receive => {
tcp.set_sequence(self.state.rec_seq);
tcp.set_acknowledgement(self.state.send_seq);
}
}
tcp.set_flags(TcpFlags::FIN | TcpFlags::ACK);
tcp.packet_size()
};
self.write_transport_payload(
info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
vec![EnhancedPacketOption::Comment("Generated TCP FIN".into())],
);
// Update sequence number
match info.direction {
Direction::Send => {
self.state.send_seq = self.state.send_seq.wrapping_add(1);
}
Direction::Receive => {
self.state.rec_seq = self.state.rec_seq.wrapping_add(1);
}
}
}
fn write_transport_payload(
&mut self,
info: &CapturePacket,
protocol: IpNextHeaderProtocol,
payload: &[u8],
options: Vec<pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketOption>,
) {
let mut network_packet = vec![0; PACKET_SIZE - HEADER_SIZE_ETHERNET];
let (network_size, ethertype) = self.encode_ip_packet(&mut network_packet, info, protocol, payload);
let network_size = network_size + payload.len();
network_packet.truncate(network_size);
let mut physical_packet = vec![0; PACKET_SIZE];
let physical_size =
self.encode_ethernet_packet(&mut physical_packet, ethertype, &network_packet) + network_size;
physical_packet.truncate(physical_size);
self.writer
.write_block(
&pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketBlock {
original_len: physical_size as u32,
data: physical_packet.into(),
interface_id: 0,
timestamp: self.state.start_time.elapsed(),
options,
}
.into_block(),
)
.unwrap();
}
}
impl Default for State {
fn default() -> Self {
Self {
start_time: Instant::now(),
send_seq: 0,
rec_seq: 0,
has_sent_handshake: false,
stream_count: 0,
}
}
}

View file

@ -0,0 +1,214 @@
use std::{marker::PhantomData, net::SocketAddr};
use crate::{
capture::{
packet::CapturePacket,
packet::{Direction, Protocol},
writer::{Writer, CAPTURE_WRITER},
},
protocols::types::TimeoutSettings,
socket::{Socket, TcpSocketImpl, UdpSocketImpl},
GDResult,
};
/// Sets a global capture writer for handling all packet data.
///
/// # Panics
/// Panics if a capture writer is already set.
///
/// # Arguments
/// * `writer` - A boxed writer that implements the `Writer` trait.
pub(crate) fn set_writer(writer: Box<dyn Writer + Send + Sync>) {
let mut lock = CAPTURE_WRITER.lock().unwrap();
if lock.is_some() {
panic!("Capture writer already set");
}
*lock = Some(writer);
}
/// A trait representing a provider of a network protocol.
pub(crate) trait ProtocolProvider {
/// Returns the protocol used by the provider.
fn protocol() -> Protocol;
}
/// Represents the TCP protocol provider.
pub(crate) struct ProtocolTCP;
impl ProtocolProvider for ProtocolTCP {
fn protocol() -> Protocol { Protocol::Tcp }
}
/// Represents the UDP protocol provider.
pub(crate) struct ProtocolUDP;
impl ProtocolProvider for ProtocolUDP {
fn protocol() -> Protocol { Protocol::Udp }
}
/// A socket wrapper that allows capturing packets.
///
/// # Type parameters
/// * `I` - The inner socket type.
/// * `P` - The protocol provider.
#[derive(Clone, Debug)]
pub(crate) struct WrappedCaptureSocket<I: Socket, P: ProtocolProvider> {
inner: I,
remote_address: SocketAddr,
_protocol: PhantomData<P>,
}
impl<I: Socket, P: ProtocolProvider> Socket for WrappedCaptureSocket<I, P> {
/// Creates a new wrapped socket for capturing packets.
///
/// Initializes a new socket of type `I`, wrapping it to enable packet
/// capturing. Capturing is protocol-specific, as indicated by
/// the `ProtocolProvider`.
///
/// # Arguments
/// * `address` - The address to connect the socket to.
/// * `timeout_settings` - Optional timeout settings for the socket.
///
/// # Returns
/// A `GDResult` containing either the wrapped socket or an error.
fn new(address: &SocketAddr, timeout_settings: &Option<TimeoutSettings>) -> GDResult<Self>
where Self: Sized {
let v = Self {
inner: I::new(address, timeout_settings)?,
remote_address: *address,
_protocol: PhantomData,
};
let info = CapturePacket {
direction: Direction::Send,
protocol: P::protocol(),
remote_address: address,
local_address: &v.local_addr().unwrap(),
};
if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() {
writer.new_connect(&info)?;
}
Ok(v)
}
/// Sends data over the socket and captures the packet.
///
/// The method sends data using the inner socket and captures the sent
/// packet if a capture writer is set.
///
/// # Arguments
/// * `data` - Data to be sent.
///
/// # Returns
/// A result indicating success or error in sending data.
fn send(&mut self, data: &[u8]) -> GDResult<()> {
let info = CapturePacket {
direction: Direction::Send,
protocol: P::protocol(),
remote_address: &self.remote_address,
local_address: &self.local_addr().unwrap(),
};
if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() {
writer.write(&info, data)?;
}
self.inner.send(data)
}
/// Receives data from the socket and captures the packet.
///
/// The method receives data using the inner socket and captures the
/// incoming packet if a capture writer is set.
///
/// # Arguments
/// * `size` - Optional size of data to receive.
///
/// # Returns
/// A result containing received data or an error.
fn receive(&mut self, size: Option<usize>) -> crate::GDResult<Vec<u8>> {
let data = self.inner.receive(size)?;
let info = CapturePacket {
direction: Direction::Receive,
protocol: P::protocol(),
remote_address: &self.remote_address,
local_address: &self.local_addr().unwrap(),
};
if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() {
writer.write(&info, &data)?;
}
Ok(data)
}
/// Applies timeout settings to the wrapped socket.
///
/// Delegates the operation to the inner socket implementation.
///
/// # Arguments
/// * `timeout_settings` - Optional timeout settings to apply.
///
/// # Returns
/// A result indicating success or error in applying timeouts.
fn apply_timeout(
&self,
timeout_settings: &Option<crate::protocols::types::TimeoutSettings>,
) -> crate::GDResult<()> {
self.inner.apply_timeout(timeout_settings)
}
/// Returns the remote port of the wrapped socket.
///
/// Delegates the operation to the inner socket implementation.
///
/// # Returns
/// The remote port number.
fn port(&self) -> u16 { self.inner.port() }
/// Returns the local SocketAddr of the wrapped socket.
///
/// Delegates the operation to the inner socket implementation.
///
/// # Returns
/// The local SocketAddr.
fn local_addr(&self) -> std::io::Result<SocketAddr> { self.inner.local_addr() }
}
// this seems a bad way to do this, but its safe
impl<I: Socket, P: ProtocolProvider> Drop for WrappedCaptureSocket<I, P> {
fn drop(&mut self) {
// Construct the CapturePacket info
let info = CapturePacket {
direction: Direction::Send,
protocol: P::protocol(),
remote_address: &self.remote_address,
local_address: &self
.local_addr()
.unwrap_or_else(|_| SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)),
};
// If a capture writer is set, close the connection and capture the packet.
if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() {
let _ = writer.close_connection(&info);
}
}
}
/// A specialized `WrappedCaptureSocket` for UDP, using `UdpSocketImpl` as
/// the inner socket and `ProtocolUDP` as the protocol provider.
///
/// This type captures and processes UDP packets, wrapping around standard
/// UDP socket functionalities with additional packet capture
/// capabilities.
pub(crate) type CapturedUdpSocket = WrappedCaptureSocket<UdpSocketImpl, ProtocolUDP>;
/// A specialized `WrappedCaptureSocket` for TCP, using `TcpSocketImpl` as
/// the inner socket and `ProtocolTCP` as the protocol provider.
///
/// This type captures and processes TCP packets, wrapping around standard
/// TCP socket functionalities with additional packet capture
/// capabilities.
pub(crate) type CapturedTcpSocket = WrappedCaptureSocket<TcpSocketImpl, ProtocolTCP>;

View file

@ -0,0 +1,86 @@
use std::{io::Write, sync::Mutex};
use super::{
packet::{CapturePacket, Protocol},
pcap::Pcap,
};
use crate::GDResult;
use lazy_static::lazy_static;
lazy_static! {
/// A globally accessible, lazily-initialized static writer instance.
/// This writer is intended for capturing and recording network packets.
/// The writer is wrapped in a Mutex to ensure thread-safe access and modification.
pub(crate) static ref CAPTURE_WRITER: Mutex<Option<Box<dyn Writer + Send + Sync>>> = Mutex::new(None);
}
/// Trait defining the functionality for a writer that handles network packet
/// captures. This trait includes methods for writing packet data, handling new
/// connections, and closing connections.
pub(crate) trait Writer {
/// Writes a given packet's data to an underlying storage or stream.
///
/// # Arguments
/// * `packet` - Reference to the packet being captured.
/// * `data` - The raw byte data associated with the packet.
///
/// # Returns
/// A `GDResult` indicating the success or failure of the write operation.
fn write(&mut self, packet: &CapturePacket, data: &[u8]) -> GDResult<()>;
/// Handles the creation of a new connection, potentially logging or
/// initializing resources.
///
/// # Arguments
/// * `packet` - Reference to the packet indicating a new connection.
///
/// # Returns
/// A `GDResult` indicating the success or failure of handling the new
/// connection.
fn new_connect(&mut self, packet: &CapturePacket) -> GDResult<()>;
/// Closes a connection, handling any necessary cleanup or finalization.
///
/// # Arguments
/// * `packet` - Reference to the packet indicating the closure of a
/// connection.
///
/// # Returns
/// A `GDResult` indicating the success or failure of the connection closure
/// operation.
fn close_connection(&mut self, packet: &CapturePacket) -> GDResult<()>;
}
/// Implementation of the `Writer` trait for the `Pcap` struct.
/// This implementation enables writing, connection handling, and closure
/// specific to PCAP (Packet Capture) format.
impl<W: Write> Writer for Pcap<W> {
fn write(&mut self, info: &CapturePacket, data: &[u8]) -> GDResult<()> {
self.write_transport_packet(info, data);
Ok(())
}
fn new_connect(&mut self, packet: &CapturePacket) -> GDResult<()> {
match packet.protocol {
Protocol::Tcp => {
self.write_tcp_handshake(packet);
}
Protocol::Udp => {}
}
self.state.stream_count = self.state.stream_count.wrapping_add(1);
Ok(())
}
fn close_connection(&mut self, packet: &CapturePacket) -> GDResult<()> {
match packet.protocol {
Protocol::Tcp => {
self.send_tcp_fin(packet);
}
Protocol::Udp => {}
}
Ok(())
}
}

View file

@ -16,7 +16,7 @@ pub const MAX_BUFFER_SIZE: usize = 500;
/// Send a ping packet.
///
/// [Reference](https://github.com/Anuken/Mindustry/blob/a2e5fbdedb2fc1c8d3c157bf344d10ad6d321442/core/src/mindustry/net/ArcNetProvider.java#L248)
pub fn send_ping(socket: &mut UdpSocket) -> GDResult<()> { socket.send(&[-2i8 as u8, 1i8 as u8]) }
pub(crate) fn send_ping(socket: &mut UdpSocket) -> GDResult<()> { socket.send(&[-2i8 as u8, 1i8 as u8]) }
/// Parse server data.
///

View file

@ -45,6 +45,9 @@ mod buffer;
mod socket;
mod utils;
#[cfg(feature = "packet_capture")]
pub mod capture;
pub use errors::*;
#[cfg(feature = "games")]
pub use games::*;

View file

@ -4,35 +4,75 @@ use crate::{
GDResult,
};
use std::net::SocketAddr;
use std::{
io::{Read, Write},
net,
net::{self, SocketAddr},
};
const DEFAULT_PACKET_SIZE: usize = 1024;
/// A trait defining the basic functionalities of a network socket.
pub trait Socket {
/// Create a new socket and connect to the remote address (if required).
/// Create a new socket and connect to the remote address.
///
/// Calls [Self::apply_timeout] with the given timeout settings.
/// # Arguments
/// * `address` - The address to connect the socket to.
/// * `timeout_settings` - Optional timeout settings for the socket.
///
/// # Returns
/// A result containing the socket instance or an error.
fn new(address: &SocketAddr, timeout_settings: &Option<TimeoutSettings>) -> GDResult<Self>
where Self: Sized;
/// Apply read and write timeouts to the socket.
///
/// # Arguments
/// * `timeout_settings` - Optional timeout settings to apply.
///
/// # Returns
/// A result indicating success or error in applying timeouts.
fn apply_timeout(&self, timeout_settings: &Option<TimeoutSettings>) -> GDResult<()>;
/// Send data over the socket.
///
/// # Arguments
/// * `data` - Data to be sent.
///
/// # Returns
/// A result indicating success or error in sending data.
fn send(&mut self, data: &[u8]) -> GDResult<()>;
/// Receive data from the socket.
///
/// # Arguments
/// * `size` - Optional size of data to receive.
///
/// # Returns
/// A result containing received data or an error.
fn receive(&mut self, size: Option<usize>) -> GDResult<Vec<u8>>;
/// Get the remote port of the socket.
///
/// # Returns
/// The port number.
fn port(&self) -> u16;
/// Get the local SocketAddr.
///
/// # Returns
/// The local SocketAddr.
fn local_addr(&self) -> std::io::Result<SocketAddr>;
}
pub struct TcpSocket {
/// Implementation of a TCP socket.
pub struct TcpSocketImpl {
/// The underlying TCP socket stream.
socket: net::TcpStream,
/// The address of the remote host.
address: SocketAddr,
}
impl Socket for TcpSocket {
impl Socket for TcpSocketImpl {
fn new(address: &SocketAddr, timeout_settings: &Option<TimeoutSettings>) -> GDResult<Self> {
let socket = TimeoutSettings::get_connect_or_default(timeout_settings).map_or_else(
|| net::TcpStream::connect(address),
@ -72,14 +112,18 @@ impl Socket for TcpSocket {
}
fn port(&self) -> u16 { self.address.port() }
fn local_addr(&self) -> std::io::Result<SocketAddr> { self.socket.local_addr() }
}
pub struct UdpSocket {
/// Implementation of a UDP socket.
pub struct UdpSocketImpl {
/// The underlying UDP socket.
socket: net::UdpSocket,
/// The address of the remote host.
address: SocketAddr,
}
impl Socket for UdpSocket {
impl Socket for UdpSocketImpl {
fn new(address: &SocketAddr, timeout_settings: &Option<TimeoutSettings>) -> GDResult<Self> {
let socket = net::UdpSocket::bind("0.0.0.0:0").map_err(|e| SocketBind.context(e))?;
@ -120,8 +164,19 @@ impl Socket for UdpSocket {
}
fn port(&self) -> u16 { self.address.port() }
fn local_addr(&self) -> std::io::Result<SocketAddr> { self.socket.local_addr() }
}
#[cfg(not(feature = "packet_capture"))]
pub type UdpSocket = UdpSocketImpl;
#[cfg(not(feature = "packet_capture"))]
pub type TcpSocket = TcpSocketImpl;
#[cfg(feature = "packet_capture")]
pub(crate) type UdpSocket = crate::capture::socket::CapturedUdpSocket;
#[cfg(feature = "packet_capture")]
pub(crate) type TcpSocket = crate::capture::socket::CapturedTcpSocket;
#[cfg(test)]
mod tests {
use std::thread;