feat: Add packet capture functionality and many more CLI improvements (#182)

This commit is contained in:
Cain 2024-02-07 22:31:31 +00:00 committed by GitHub
commit e86e80522b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1445 additions and 82 deletions

31
.github/workflows/release.yml vendored Normal file
View file

@ -0,0 +1,31 @@
name: Release
on:
release:
types: [created]
jobs:
release:
name: Release ${{ matrix.target }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- target: x86_64-pc-windows-gnu
archive: zip
- target: x86_64-unknown-linux-musl
archive: tar.gz tar.xz tar.zst
- target: x86_64-apple-darwin
archive: zip
- target: wasm32-wasi
archive: zip tar.gz
steps:
- uses: actions/checkout@v4
- name: Compile and release
uses: rust-build/rust-build.action@v1.4.4
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
RUSTTARGET: ${{ matrix.target }}
ARCHIVE_TYPES: ${{ matrix.archive }}

View file

@ -8,7 +8,8 @@ resolver = "2"
opt-level = 3
debug = false
rpath = true
lto = true
lto = 'fat'
codegen-units = 1
[profile.release.package."*"]
opt-level = 3

View file

@ -12,14 +12,44 @@ default-run = "gamedig-cli"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = ["json"]
default = ["json", "bson", "xml", "browser"]
# Tools
packet_capture = ["gamedig/packet_capture"]
# Output formats
bson = ["dep:serde", "dep:bson", "dep:hex", "dep:base64", "gamedig/serde"]
json = ["dep:serde", "dep:serde_json", "gamedig/serde"]
xml = ["dep:serde", "dep:serde_json", "dep:quick-xml", "gamedig/serde"]
# Misc
browser = ["dep:webbrowser"]
[dependencies]
clap = { version = "4.1.11", features = ["derive"] }
gamedig = { version = "*", path = "../lib", features = ["clap"] }
# Core Dependencies
thiserror = "1.0.43"
clap = { version = "4.1.11", default-features = false, features = ["derive"] }
gamedig = { version = "*", path = "../lib", default-features = false, features = [
"clap",
"games",
"game_defs",
] }
# Feature Dependencies
# Serialization / Deserialization
serde = { version = "1", optional = true, default-features = false }
# BSON
bson = { version = "2.8.1", optional = true, default-features = false }
base64 = { version = "0.21.7", optional = true, default-features = false, features = ["std"]}
hex = { version = "0.4.3", optional = true, default-features = false }
# JSON
serde_json = { version = "1", optional = true, default-features = false }
# XML
quick-xml = { version = "0.31.0", optional = true, default-features = false }
# Browser
webbrowser = { version = "0.8.12", optional = true, default-features = false }
# JSON dependencies
serde = { version = "1", optional = true }
serde_json = { version = "1", optional = true }

View file

@ -11,6 +11,18 @@ pub enum Error {
#[error("Gamedig Error: {0}")]
Gamedig(#[from] gamedig::errors::GDError),
#[cfg(any(feature = "json", feature = "xml"))]
#[error("Serde Error: {0}")]
Serde(#[from] serde_json::Error),
#[cfg(feature = "bson")]
#[error("Bson Error: {0}")]
Bson(#[from] bson::ser::Error),
#[cfg(feature = "xml")]
#[error("Xml Error: {0}")]
Xml(#[from] quick_xml::Error),
#[error("Unknown Game: {0}")]
UnknownGame(String),

View file

@ -1,47 +1,91 @@
use std::net::{IpAddr, ToSocketAddrs};
use clap::{Parser, ValueEnum};
use gamedig::{games::*, protocols::types::CommonResponse, ExtraRequestSettings, TimeoutSettings};
use clap::{Parser, Subcommand, ValueEnum};
use gamedig::{
games::*,
protocols::types::{CommonResponse, ExtraRequestSettings, TimeoutSettings},
};
mod error;
use self::error::{Error, Result};
const GAMEDIG_HEADER: &str = r"
_____ _____ _ _____ _ _____
/ ____| | __ \(_) / ____| | |_ _|
| | __ __ _ _ __ ___ ___| | | |_ __ _ | | | | | |
| | |_ |/ _` | '_ ` _ \ / _ \ | | | |/ _` | | | | | | |
| |__| | (_| | | | | | | __/ |__| | | (_| | | |____| |____ _| |_
\_____|\__,_|_| |_| |_|\___|_____/|_|\__, | \_____|______|_____|
__/ |
|___/
A command line interface for querying game servers.
Copyright (C) 2022 - 2024 GameDig Organization & Contributors
Licensed under the MIT license
";
// NOTE: For some reason without setting long_about here the doc comment for
// ExtraRequestSettings gets set as the about for the CLI.
#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None)]
#[command(author, version, about = GAMEDIG_HEADER, long_about = None)]
struct Cli {
/// Unique identifier of the game for which server information is being
/// queried.
#[arg(short, long)]
game: String,
#[command(subcommand)]
action: Action,
}
/// Hostname or IP address of the server.
#[arg(short, long)]
ip: String,
#[derive(Subcommand, Debug)]
enum Action {
/// Query game server information
Query {
/// Unique identifier of the game for which server information is being
/// queried.
#[arg(short, long)]
game: String,
/// Optional query port number for the server. If not provided the default
/// port for the game is used.
#[arg(short, long)]
port: Option<u16>,
/// Hostname or IP address of the server.
#[arg(short, long)]
ip: String,
/// Flag indicating if the output should be in JSON format.
#[cfg(feature = "json")]
#[arg(short, long)]
json: bool,
/// Optional query port number for the server. If not provided the
/// default port for the game is used.
#[arg(short, long)]
port: Option<u16>,
/// Which response variant to use when outputting.
#[arg(short, long, default_value = "generic")]
output_mode: OutputMode,
/// Specifies the output format
#[arg(short, long, default_value = "debug", value_enum)]
format: OutputFormat,
/// Optional timeout settings for the server query.
#[command(flatten, next_help_heading = "Timeouts")]
timeout_settings: Option<TimeoutSettings>,
/// Which response variant to use when outputting
#[arg(short, long, default_value = "generic")]
output_mode: OutputMode,
/// Optional extra settings for the server query.
#[command(flatten, next_help_heading = "Query options")]
extra_options: Option<ExtraRequestSettings>,
/// Optional file path for packet capture file writer
///
/// When set a PCAP file will be written to the location. This file can
/// be read with a tool like wireshark. The PCAP contains a log of the
/// TCP and UDP data sent/recieved by the gamedig library, it does not
/// contain an accurate representation of the real packets sent on the
/// wire as some information has to be hallucinated in order for it to
/// display nicely.
#[cfg(feature = "packet_capture")]
#[arg(short, long)]
capture: Option<std::path::PathBuf>,
/// Optional timeout settings for the server query
#[command(flatten, next_help_heading = "Timeouts")]
timeout_settings: Option<TimeoutSettings>,
/// Optional extra settings for the server query
#[command(flatten, next_help_heading = "Query options")]
extra_options: Option<ExtraRequestSettings>,
},
/// Check out the source code
Source,
/// Display the MIT License information
License,
}
#[derive(Clone, Debug, PartialEq, Eq, ValueEnum)]
@ -54,6 +98,27 @@ enum OutputMode {
ProtocolSpecific,
}
#[derive(Clone, Debug, PartialEq, Eq, ValueEnum)]
enum OutputFormat {
/// Human readable structured output
Debug,
/// RFC 8259
#[cfg(feature = "json")]
JsonPretty,
/// RFC 8259
#[cfg(feature = "json")]
Json,
/// Parser tries to be mostly XML 1.1 (RFC 7303) compliant
#[cfg(feature = "xml")]
Xml,
/// RFC 4648 section 8
#[cfg(feature = "bson")]
BsonHex,
/// RFC 4648 section 4
#[cfg(feature = "bson")]
BsonBase64,
}
/// Attempt to find a game from the [library game definitions](GAMES) based on
/// its unique identifier.
///
@ -81,15 +146,14 @@ fn find_game(game_id: &str) -> Result<&'static Game> {
/// # Returns
/// * `Result<IpAddr>` - On sucess returns a resolved IP address; on failure
/// returns an [Error::InvalidHostname] error.
fn resolve_ip_or_domain(host: &str, extra_options: &mut Option<ExtraRequestSettings>) -> Result<IpAddr> {
host.parse().map_or_else(
|_| {
set_hostname_if_missing(host, extra_options);
resolve_domain(host)
},
Ok,
)
fn resolve_ip_or_domain<T: AsRef<str>>(host: T, extra_options: &mut Option<ExtraRequestSettings>) -> Result<IpAddr> {
let host_str = host.as_ref();
if let Ok(parsed_ip) = host_str.parse() {
Ok(parsed_ip)
} else {
set_hostname_if_missing(host_str, extra_options);
resolve_domain(host_str)
}
}
/// Resolve a domain name to one of its IP addresses (the first one returned).
@ -133,15 +197,49 @@ fn set_hostname_if_missing(host: &str, extra_options: &mut Option<ExtraRequestSe
/// # Arguments
/// * `args` - A reference to the command line options.
/// * `result` - A reference to the result of the query.
fn output_result(args: &Cli, result: &dyn CommonResponse) {
match args.output_mode {
fn output_result<T: CommonResponse + ?Sized>(output_mode: OutputMode, format: OutputFormat, result: &T) {
match format {
OutputFormat::Debug => {
match output_mode {
OutputMode::Generic => output_result_debug(result.as_json()),
OutputMode::ProtocolSpecific => output_result_debug(result.as_original()),
};
}
#[cfg(feature = "json")]
OutputMode::Generic if args.json => output_result_json(result.as_json()),
OutputFormat::JsonPretty => {
let _ = match output_mode {
OutputMode::Generic => output_result_json_pretty(result.as_json()),
OutputMode::ProtocolSpecific => output_result_json_pretty(result.as_original()),
};
}
#[cfg(feature = "json")]
OutputMode::ProtocolSpecific if args.json => output_result_json(result.as_original()),
OutputMode::Generic => output_result_debug(result.as_json()),
OutputMode::ProtocolSpecific => output_result_debug(result.as_original()),
OutputFormat::Json => {
let _ = match output_mode {
OutputMode::Generic => output_result_json(result.as_json()),
OutputMode::ProtocolSpecific => output_result_json(result.as_original()),
};
}
#[cfg(feature = "xml")]
OutputFormat::Xml => {
let _ = match output_mode {
OutputMode::Generic => output_result_xml(result.as_json()),
OutputMode::ProtocolSpecific => output_result_xml(result.as_original()),
};
}
#[cfg(feature = "bson")]
OutputFormat::BsonHex => {
let _ = match output_mode {
OutputMode::Generic => output_result_bson_hex(result.as_json()),
OutputMode::ProtocolSpecific => output_result_bson_hex(result.as_original()),
};
}
#[cfg(feature = "bson")]
OutputFormat::BsonBase64 => {
let _ = match output_mode {
OutputMode::Generic => output_result_bson_base64(result.as_json()),
OutputMode::ProtocolSpecific => output_result_bson_base64(result.as_original()),
};
}
}
}
@ -158,29 +256,232 @@ fn output_result_debug<R: std::fmt::Debug>(result: R) {
/// # Arguments
/// * `result` - A serde serializable result.
#[cfg(feature = "json")]
fn output_result_json<R: serde::Serialize>(result: R) {
serde_json::to_writer_pretty(std::io::stdout(), &result).unwrap();
}
fn main() -> Result<()> {
// Parse the command line arguments
let args = Cli::parse();
// Retrieve the game based on the provided ID
let game = find_game(&args.game)?;
// Extract extra options for use in setup
let mut extra_options = args.extra_options.clone();
// Resolve the IP address
let ip = resolve_ip_or_domain(&args.ip, &mut extra_options)?;
// Query the server using game definition, parsed IP, and user command line
// flags.
let result = query_with_timeout_and_extra_settings(game, &ip, args.port, args.timeout_settings, extra_options)?;
// Output the result in the specified format
output_result(&args, result.as_ref());
fn output_result_json<T: serde::Serialize>(result: T) -> Result<()> {
println!("{}", serde_json::to_string(&result)?);
Ok(())
}
/// Output the result as a pretty printed JSON object.
///
/// # Arguments
/// * `result` - A serde serializable result.
#[cfg(feature = "json")]
fn output_result_json_pretty<T: serde::Serialize>(result: T) -> Result<()> {
println!("{}", serde_json::to_string_pretty(&result)?);
Ok(())
}
/// Output the result as an XML object.
/// # Arguments
/// * `result` - A serde serializable result.
#[cfg(feature = "xml")]
fn output_result_xml<T: serde::Serialize>(result: T) -> Result<()> {
use quick_xml::{
events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event},
Writer,
};
use serde_json::Value;
// Serialize the input `result` of generic type `T` into a JSON value.
// This step converts the Rust data structure into a JSON format,
// which will then be used to generate the corresponding XML.
let json = serde_json::to_value(result)?;
// Initialize the XML writer with a new, empty vector to store the XML data.
let mut writer = Writer::new(Vec::new());
// Write the XML 1.1 declaration
writer.write_event(Event::Decl(BytesDecl::new("1.1", Some("utf-8"), None)))?;
// Define a recursive function `json_to_xml` to convert the JSON value into XML
// format. The function takes a mutable reference to the XML writer, an
// optional key as a string slice, and a reference to the JSON value to be
// converted.
fn json_to_xml<W: std::io::Write>(writer: &mut Writer<W>, key: Option<&str>, value: &Value) -> Result<()> {
match value {
// If the JSON value is an object, iterate through its properties,
// creating XML elements with corresponding keys and values.
Value::Object(obj) => {
if let Some(key) = key {
// Start an XML element for the object.
writer.write_event(Event::Start(BytesStart::new(key)))?;
}
for (k, v) in obj {
// Recursively process each property of the object.
json_to_xml(writer, Some(k), v)?;
}
if let Some(key) = key {
// Close the XML element for the object.
writer.write_event(Event::End(BytesEnd::new(key)))?;
}
}
// If the JSON value is an array, iterate through its elements,
// creating XML elements for each item.
Value::Array(arr) => {
for v in arr {
// Use "item" as the default key for array elements without keys.
json_to_xml(writer, key.or(Some("item")), v)?;
}
}
// If the JSON value is null, create an empty XML element.
Value::Null => {
if let Some(key) = key {
writer.write_event(Event::Empty(BytesStart::new(key)))?;
}
}
// For all other JSON value types (String, Number, Bool),
// convert the value to a string and create an XML element with the text content.
// Note: We handle null strings here as well, as they are treated as a string type.
_ => {
if let Some(key) = key {
// Start the XML element with the given key.
writer.write_event(Event::Start(BytesStart::new(key)))?;
}
// Convert the JSON value to a string, trimming quotes for non-string values.
let text_string = match value {
Value::String(s) => s.to_string(),
_ => value.to_string().trim_matches('"').to_string(),
};
// Create a text node with the converted string value.
writer.write_event(Event::Text(BytesText::new(&text_string)))?;
if let Some(key) = key {
// Close the XML element.
writer.write_event(Event::End(BytesEnd::new(key)))?;
}
}
}
Ok(())
}
// Start the root XML element named "data".
writer.write_event(Event::Start(BytesStart::new("data")))?;
// Convert the top-level JSON value to XML.
json_to_xml(&mut writer, None, &json)?;
// Close the root XML element.
writer.write_event(Event::End(BytesEnd::new("data")))?;
// Convert the XML data stored in the writer to a UTF-8 string.
let xml_bytes = writer.into_inner();
let xml_string = String::from_utf8(xml_bytes).expect("Failed to convert XML bytes to UTF-8 string");
println!("{}", xml_string);
Ok(())
}
/// Output the result as a BSON object encoded as a hex string.
///
/// # Arguments
/// * `result` - A serde serializable result.
#[cfg(feature = "bson")]
fn output_result_bson_hex<T: serde::Serialize>(result: T) -> Result<()> {
let bson = bson::to_bson(&result)?;
if let bson::Bson::Document(document) = bson {
let bytes = bson::to_vec(&document)?;
println!("{}", hex::encode(bytes));
Ok(())
} else {
panic!("Failed to convert result to BSON Hex (BSON_DOCUMENT_UNAVAILABLE)");
}
}
/// Output the result as a BSON object encoded as a base64 string.
///
/// # Arguments
/// * `result` - A serde serializable result.
#[cfg(feature = "bson")]
fn output_result_bson_base64<T: serde::Serialize>(result: T) -> Result<()> {
use base64::Engine;
let bson = bson::to_bson(&result)?;
if let bson::Bson::Document(document) = bson {
let bytes = bson::to_vec(&document)?;
println!("{}", base64::prelude::BASE64_STANDARD.encode(bytes));
Ok(())
} else {
panic!("Failed to convert result to BSON Base64 (BSON_DOCUMENT_UNAVAILABLE)");
}
}
fn main() -> Result<()> {
let args = Cli::parse();
match args.action {
Action::Query {
game,
ip,
port,
format,
output_mode,
#[cfg(feature = "packet_capture")]
capture,
timeout_settings,
extra_options,
} => {
// Process the query command
let game = find_game(&game)?;
let mut extra_options = extra_options;
let ip = resolve_ip_or_domain(&ip, &mut extra_options)?;
#[cfg(feature = "packet_capture")]
gamedig::capture::setup_capture(capture);
let result = query_with_timeout_and_extra_settings(game, &ip, port, timeout_settings, extra_options)?;
output_result(output_mode, format, result.as_ref());
}
Action::Source => {
println!("{}", GAMEDIG_HEADER);
#[cfg(feature = "browser")]
{
// Directly offering to open the URL
println!("\nWould you like to open the GitHub repository in your default browser? [Y/n]");
let mut choice = String::new();
std::io::stdin().read_line(&mut choice).unwrap();
if choice.trim().eq_ignore_ascii_case("Y") {
if webbrowser::open("https://github.com/gamedig/rust-gamedig").is_ok() {
println!("Opening GitHub repository in default browser...");
} else {
println!("Failed to open GitHub repository in default browser.");
println!("Please use the following URL: https://github.com/gamedig/rust-gamedig");
}
} else {
println!("Not to worry, you can always open the repository manually");
println!("by visiting the following URL: https://github.com/gamedig/rust-gamedig");
}
}
#[cfg(not(feature = "browser"))]
{
println!("\nYou can find the source code for this project at the following URL:");
println!("https://github.com/gamedig/rust-gamedig");
}
println!("\nBe sure to leave a star if you like the project :)");
}
Action::License => {
// Bake the license into the binary
// so we don't have to ship it separately
println!("{}", include_str!("../../../LICENSE.md"));
}
}
Ok(())
}

View file

@ -23,6 +23,7 @@ services = []
game_defs = ["dep:phf", "games"]
serde = ["dep:serde", "serde/derive"]
clap = ["dep:clap"]
packet_capture = ["dep:pcap-file", "dep:pnet_packet", "dep:lazy_static"]
[dependencies]
byteorder = "1.5"
@ -37,6 +38,10 @@ phf = { version = "0.11", optional = true, features = ["macros"] }
clap = { version = "4.1.11", optional = true, features = ["derive"] }
pcap-file = { version = "2.0", optional = true }
pnet_packet = { version = "0.34", optional = true }
lazy_static = { version = "1.4", optional = true }
[dev-dependencies]
gamedig-id-tests = { path = "../id-tests", default-features = false }

View file

@ -0,0 +1,39 @@
pub(crate) mod packet;
mod pcap;
pub(crate) mod socket;
pub(crate) mod writer;
use self::{pcap::Pcap, writer::Writer};
use pcap_file::pcapng::{blocks::interface_description::InterfaceDescriptionBlock, PcapNgBlock, PcapNgWriter};
use std::path::PathBuf;
pub fn setup_capture(file_path: Option<PathBuf>) {
if let Some(file_path) = file_path {
let file = std::fs::OpenOptions::new()
.create_new(true)
.write(true)
.open(file_path.with_extension("pcap"))
.unwrap();
let mut pcap_writer = PcapNgWriter::new(file).unwrap();
// Write headers
let _ = pcap_writer.write_block(
&InterfaceDescriptionBlock {
linktype: pcap_file::DataLink::ETHERNET,
snaplen: 0xFFFF,
options: vec![],
}
.into_block(),
);
let writer = Box::new(Pcap::new(pcap_writer));
attach(writer)
}
}
/// Attaches a writer to the capture module.
///
/// # Errors
/// Returns an Error if the writer is already set.
fn attach(writer: Box<dyn Writer + Send + Sync>) { crate::capture::socket::set_writer(writer); }

View file

@ -0,0 +1,203 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
/// Size of a standard network packet.
pub(crate) const PACKET_SIZE: usize = 5012;
/// Size of an Ethernet header.
pub(crate) const HEADER_SIZE_ETHERNET: usize = 14;
/// Size of an IPv4 header.
pub(crate) const HEADER_SIZE_IP4: usize = 20;
/// Size of an IPv6 header.
pub(crate) const HEADER_SIZE_IP6: usize = 40;
/// Size of a UDP header.
pub(crate) const HEADER_SIZE_UDP: usize = 4;
/// Represents the direction of a network packet.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum Direction {
/// Packet is outgoing (sent by us).
Send,
/// Packet is incoming (received by us).
Receive,
}
/// Defines the protocol of a network packet.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum Protocol {
/// Transmission Control Protocol.
Tcp,
/// User Datagram Protocol.
Udp,
}
/// Trait for handling different types of IP addresses (IPv4, IPv6).
pub(crate) trait IpAddress: Sized {
/// Creates an instance from a standard `IpAddr`, returning `None` if the
/// types are incompatible.
fn from_std(ip: IpAddr) -> Option<Self>;
}
/// Represents a captured network packet with metadata.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct CapturePacket<'a> {
/// Direction of the packet (Send/Receive).
pub(crate) direction: Direction,
/// Protocol of the packet (Tcp/UDP).
pub(crate) protocol: Protocol,
/// Remote socket address.
pub(crate) remote_address: &'a SocketAddr,
/// Local socket address.
pub(crate) local_address: &'a SocketAddr,
}
impl CapturePacket<'_> {
/// Retrieves the local and remote ports based on the packet's direction.
///
/// Returns:
/// - (u16, u16): Tuple of (source port, destination port).
pub(super) fn ports_by_direction(&self) -> (u16, u16) {
let (local, remote) = (self.local_address.port(), self.remote_address.port());
self.direction.order(local, remote)
}
/// Retrieves the local and remote IP addresses.
///
/// Returns:
/// - (IpAddr, IpAddr): Tuple of (local IP, remote IP).
pub(super) fn ip_addr(&self) -> (IpAddr, IpAddr) {
let (local, remote) = (self.local_address.ip(), self.remote_address.ip());
(local, remote)
}
/// Retrieves IP addresses of a specific type (IPv4 or IPv6) based on the
/// packet's direction.
///
/// Panics if the IP type of the addresses does not match the requested
/// type.
///
/// Returns:
/// - (T, T): Tuple of (source IP, destination IP) of the specified type in
/// order.
pub(super) fn ipvt_by_direction<T: IpAddress>(&self) -> (T, T) {
let (local, remote) = (
T::from_std(self.local_address.ip()).expect("Incorrect IP type for local address"),
T::from_std(self.remote_address.ip()).expect("Incorrect IP type for remote address"),
);
self.direction.order(local, remote)
}
}
impl Direction {
/// Orders two elements (source and destination) based on the packet's
/// direction.
///
/// Returns:
/// - (T, T): Ordered tuple (source, destination).
pub(self) const fn order<T>(&self, source: T, remote: T) -> (T, T) {
match self {
Direction::Send => (source, remote),
Direction::Receive => (remote, source),
}
}
}
/// Implements the `IpAddress` trait for `Ipv4Addr`.
impl IpAddress for Ipv4Addr {
/// Creates an `Ipv4Addr` from a standard `IpAddr`, if it's IPv4.
fn from_std(ip: IpAddr) -> Option<Self> {
match ip {
IpAddr::V4(ipv4) => Some(ipv4),
_ => None,
}
}
}
/// Implements the `IpAddress` trait for `Ipv6Addr`.
impl IpAddress for Ipv6Addr {
/// Creates an `Ipv6Addr` from a standard `IpAddr`, if it's IPv6.
fn from_std(ip: IpAddr) -> Option<Self> {
match ip {
IpAddr::V6(ipv6) => Some(ipv6),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
// Helper function to create a SocketAddr from a string
fn socket_addr(addr: &str) -> SocketAddr { SocketAddr::from_str(addr).unwrap() }
#[test]
fn test_ports_by_direction() {
let packet_send = CapturePacket {
direction: Direction::Send,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
let packet_receive = CapturePacket {
direction: Direction::Receive,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
assert_eq!(packet_send.ports_by_direction(), (8080, 80));
assert_eq!(packet_receive.ports_by_direction(), (80, 8080));
}
#[test]
fn test_ip_addr() {
let packet_send = CapturePacket {
direction: Direction::Send,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
let packet_receive = CapturePacket {
direction: Direction::Receive,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
assert_eq!(
packet_send.ip_addr(),
(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
)
);
assert_eq!(
packet_receive.ip_addr(),
(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
)
);
}
#[test]
fn test_ip_by_direction_type_specific() {
let packet = CapturePacket {
direction: Direction::Send,
protocol: Protocol::Tcp,
local_address: &socket_addr("127.0.0.1:8080"),
remote_address: &socket_addr("192.168.1.1:80"),
};
let ipv4_result: Result<(Ipv4Addr, Ipv4Addr), _> =
std::panic::catch_unwind(|| packet.ipvt_by_direction::<Ipv4Addr>());
assert!(ipv4_result.is_ok());
let ipv6_result: Result<(Ipv6Addr, Ipv6Addr), _> =
std::panic::catch_unwind(|| packet.ipvt_by_direction::<Ipv6Addr>());
assert!(ipv6_result.is_err());
}
}

View file

@ -0,0 +1,383 @@
use pcap_file::pcapng::{blocks::enhanced_packet::EnhancedPacketOption, PcapNgBlock, PcapNgWriter};
use pnet_packet::{
ethernet::{EtherType, MutableEthernetPacket},
ip::{IpNextHeaderProtocol, IpNextHeaderProtocols},
ipv4::MutableIpv4Packet,
ipv6::MutableIpv6Packet,
tcp::{MutableTcpPacket, TcpFlags},
udp::MutableUdpPacket,
PacketSize,
};
use std::{io::Write, net::IpAddr, time::Instant};
use super::packet::{
CapturePacket,
Direction,
Protocol,
HEADER_SIZE_ETHERNET,
HEADER_SIZE_IP4,
HEADER_SIZE_IP6,
HEADER_SIZE_UDP,
PACKET_SIZE,
};
const BUFFER_SIZE: usize = PACKET_SIZE - HEADER_SIZE_IP6 - HEADER_SIZE_ETHERNET;
pub(crate) struct Pcap<W: Write> {
writer: PcapNgWriter<W>,
pub(crate) state: State,
}
pub(crate) struct State {
pub(crate) start_time: Instant,
pub(crate) send_seq: u32,
pub(crate) rec_seq: u32,
pub(crate) has_sent_handshake: bool,
pub(crate) stream_count: u32,
}
impl<W: Write> Pcap<W> {
pub(crate) fn new(writer: PcapNgWriter<W>) -> Self {
Self {
writer,
state: State::default(),
}
}
pub(crate) fn write_transport_packet(&mut self, info: &CapturePacket, payload: &[u8]) {
let mut buffer_array: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE];
let buf: &mut [u8] = &mut buffer_array[..];
let (source_port, dest_port) = info.ports_by_direction();
match info.protocol {
Protocol::Tcp => {
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
tcp.set_source(source_port);
tcp.set_destination(dest_port);
tcp.set_payload(payload);
tcp.set_data_offset(5);
tcp.set_window(43440);
match info.direction {
Direction::Send => {
tcp.set_sequence(self.state.send_seq);
tcp.set_acknowledgement(self.state.rec_seq);
self.state.send_seq = self.state.send_seq.wrapping_add(payload.len() as u32);
}
Direction::Receive => {
tcp.set_sequence(self.state.rec_seq);
tcp.set_acknowledgement(self.state.send_seq);
self.state.rec_seq = self.state.rec_seq.wrapping_add(payload.len() as u32);
}
}
tcp.set_flags(TcpFlags::PSH | TcpFlags::ACK);
tcp.packet_size()
};
self.write_transport_payload(
info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size + payload.len()],
vec![],
);
let mut info = info.clone();
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
tcp.set_source(dest_port);
tcp.set_destination(source_port);
tcp.set_data_offset(5);
tcp.set_window(43440);
match &info.direction {
Direction::Send => {
tcp.set_sequence(self.state.rec_seq);
tcp.set_acknowledgement(self.state.send_seq);
info.direction = Direction::Receive;
}
Direction::Receive => {
tcp.set_sequence(self.state.send_seq);
tcp.set_acknowledgement(self.state.rec_seq);
info.direction = Direction::Send;
}
}
tcp.set_flags(TcpFlags::ACK);
tcp.packet_size()
};
self.write_transport_payload(
&info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
vec![EnhancedPacketOption::Comment("Generated TCP ACK".into())],
);
}
Protocol::Udp => {
let buf_size = {
let mut udp = MutableUdpPacket::new(buf).unwrap();
udp.set_source(source_port);
udp.set_destination(dest_port);
udp.set_length((payload.len() + HEADER_SIZE_UDP) as u16);
udp.set_payload(payload);
udp.packet_size()
};
self.write_transport_payload(
info,
IpNextHeaderProtocols::Udp,
&buf[.. buf_size + payload.len()],
vec![],
);
}
}
}
/// Encode a network layer (IP) packet with a payload.
fn encode_ip_packet(
&self,
buf: &mut [u8],
info: &CapturePacket,
protocol: IpNextHeaderProtocol,
payload: &[u8],
) -> (usize, EtherType) {
match info.ip_addr() {
(IpAddr::V4(_), IpAddr::V4(_)) => {
let (source, destination) = info.ipvt_by_direction();
let header_size = HEADER_SIZE_IP4 + (32 / 8);
let mut ip = MutableIpv4Packet::new(buf).unwrap();
ip.set_version(4);
ip.set_total_length((payload.len() + header_size) as u16);
ip.set_next_level_protocol(protocol);
// https://en.wikipedia.org/wiki/Internet_Protocol_version_4#Total_Length
ip.set_header_length((header_size / 4) as u8);
ip.set_source(source);
ip.set_destination(destination);
ip.set_payload(payload);
ip.set_ttl(64);
ip.set_flags(pnet_packet::ipv4::Ipv4Flags::DontFragment);
let mut options_writer =
pnet_packet::ipv4::MutableIpv4OptionPacket::new(ip.get_options_raw_mut()).unwrap();
options_writer.set_copied(1);
options_writer.set_class(0);
options_writer.set_number(pnet_packet::ipv4::Ipv4OptionNumbers::SID);
options_writer.set_length(&[4]);
options_writer.set_data(&(self.state.stream_count as u16).to_be_bytes());
ip.set_checksum(pnet_packet::ipv4::checksum(&ip.to_immutable()));
(ip.packet_size(), pnet_packet::ethernet::EtherTypes::Ipv4)
}
(IpAddr::V6(_), IpAddr::V6(_)) => {
let (source, destination) = info.ipvt_by_direction();
let mut ip = MutableIpv6Packet::new(buf).unwrap();
ip.set_version(6);
ip.set_payload_length(payload.len() as u16);
ip.set_next_header(protocol);
ip.set_source(source);
ip.set_destination(destination);
ip.set_hop_limit(64);
ip.set_payload(payload);
ip.set_flow_label(self.state.stream_count);
(ip.packet_size(), pnet_packet::ethernet::EtherTypes::Ipv6)
}
_ => unreachable!(),
}
}
/// Encode a physical layer (ethernet) packet with a payload.
fn encode_ethernet_packet(
&self,
buf: &mut [u8],
ethertype: pnet_packet::ethernet::EtherType,
payload: &[u8],
) -> usize {
let mut ethernet = MutableEthernetPacket::new(buf).unwrap();
ethernet.set_ethertype(ethertype);
ethernet.set_payload(payload);
ethernet.packet_size()
}
/// Write a TCP handshake.
pub(crate) fn write_tcp_handshake(&mut self, info: &CapturePacket) {
let (source_port, dest_port) = (info.local_address.port(), info.remote_address.port());
let mut info = info.clone();
info.direction = Direction::Send;
let mut buffer_array: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE];
let buf: &mut [u8] = &mut buffer_array[..];
// Add a generated comment to all packets
let options = vec![
pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketOption::Comment("Generated TCP handshake".into()),
];
// SYN
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
self.state.send_seq = 500;
tcp.set_sequence(self.state.send_seq);
tcp.set_flags(TcpFlags::SYN);
tcp.set_source(source_port);
tcp.set_destination(dest_port);
tcp.set_window(43440);
tcp.set_data_offset(5);
tcp.packet_size()
};
self.write_transport_payload(
&info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
options.clone(),
);
// SYN + ACK
info.direction = Direction::Receive;
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
self.state.send_seq = self.state.send_seq.wrapping_add(1);
tcp.set_acknowledgement(self.state.send_seq);
self.state.rec_seq = 1000;
tcp.set_sequence(self.state.rec_seq);
tcp.set_flags(TcpFlags::SYN | TcpFlags::ACK);
tcp.set_source(dest_port);
tcp.set_destination(source_port);
tcp.set_window(43440);
tcp.set_data_offset(5);
tcp.packet_size()
};
self.write_transport_payload(
&info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
options.clone(),
);
// ACK
info.direction = Direction::Send;
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
tcp.set_sequence(self.state.send_seq);
self.state.rec_seq = self.state.rec_seq.wrapping_add(1);
tcp.set_acknowledgement(self.state.rec_seq);
tcp.set_flags(TcpFlags::ACK);
tcp.set_source(source_port);
tcp.set_destination(dest_port);
tcp.set_window(43440);
tcp.set_data_offset(5);
tcp.packet_size()
};
self.write_transport_payload(
&info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
options,
);
self.state.has_sent_handshake = true;
}
pub(crate) fn send_tcp_fin(&mut self, info: &CapturePacket) {
let mut buffer_array: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE];
let buf: &mut [u8] = &mut buffer_array[..];
let (source_port, dest_port) = info.ports_by_direction();
let buf_size = {
let mut tcp = MutableTcpPacket::new(buf).unwrap();
tcp.set_source(source_port);
tcp.set_destination(dest_port);
tcp.set_data_offset(5);
tcp.set_window(43440);
match info.direction {
Direction::Send => {
tcp.set_sequence(self.state.send_seq);
tcp.set_acknowledgement(self.state.rec_seq);
}
Direction::Receive => {
tcp.set_sequence(self.state.rec_seq);
tcp.set_acknowledgement(self.state.send_seq);
}
}
tcp.set_flags(TcpFlags::FIN | TcpFlags::ACK);
tcp.packet_size()
};
self.write_transport_payload(
info,
IpNextHeaderProtocols::Tcp,
&buf[.. buf_size],
vec![EnhancedPacketOption::Comment("Generated TCP FIN".into())],
);
// Update sequence number
match info.direction {
Direction::Send => {
self.state.send_seq = self.state.send_seq.wrapping_add(1);
}
Direction::Receive => {
self.state.rec_seq = self.state.rec_seq.wrapping_add(1);
}
}
}
fn write_transport_payload(
&mut self,
info: &CapturePacket,
protocol: IpNextHeaderProtocol,
payload: &[u8],
options: Vec<pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketOption>,
) {
let mut network_packet = vec![0; PACKET_SIZE - HEADER_SIZE_ETHERNET];
let (network_size, ethertype) = self.encode_ip_packet(&mut network_packet, info, protocol, payload);
let network_size = network_size + payload.len();
network_packet.truncate(network_size);
let mut physical_packet = vec![0; PACKET_SIZE];
let physical_size =
self.encode_ethernet_packet(&mut physical_packet, ethertype, &network_packet) + network_size;
physical_packet.truncate(physical_size);
self.writer
.write_block(
&pcap_file::pcapng::blocks::enhanced_packet::EnhancedPacketBlock {
original_len: physical_size as u32,
data: physical_packet.into(),
interface_id: 0,
timestamp: self.state.start_time.elapsed(),
options,
}
.into_block(),
)
.unwrap();
}
}
impl Default for State {
fn default() -> Self {
Self {
start_time: Instant::now(),
send_seq: 0,
rec_seq: 0,
has_sent_handshake: false,
stream_count: 0,
}
}
}

View file

@ -0,0 +1,214 @@
use std::{marker::PhantomData, net::SocketAddr};
use crate::{
capture::{
packet::CapturePacket,
packet::{Direction, Protocol},
writer::{Writer, CAPTURE_WRITER},
},
protocols::types::TimeoutSettings,
socket::{Socket, TcpSocketImpl, UdpSocketImpl},
GDResult,
};
/// Sets a global capture writer for handling all packet data.
///
/// # Panics
/// Panics if a capture writer is already set.
///
/// # Arguments
/// * `writer` - A boxed writer that implements the `Writer` trait.
pub(crate) fn set_writer(writer: Box<dyn Writer + Send + Sync>) {
let mut lock = CAPTURE_WRITER.lock().unwrap();
if lock.is_some() {
panic!("Capture writer already set");
}
*lock = Some(writer);
}
/// A trait representing a provider of a network protocol.
pub(crate) trait ProtocolProvider {
/// Returns the protocol used by the provider.
fn protocol() -> Protocol;
}
/// Represents the TCP protocol provider.
pub(crate) struct ProtocolTCP;
impl ProtocolProvider for ProtocolTCP {
fn protocol() -> Protocol { Protocol::Tcp }
}
/// Represents the UDP protocol provider.
pub(crate) struct ProtocolUDP;
impl ProtocolProvider for ProtocolUDP {
fn protocol() -> Protocol { Protocol::Udp }
}
/// A socket wrapper that allows capturing packets.
///
/// # Type parameters
/// * `I` - The inner socket type.
/// * `P` - The protocol provider.
#[derive(Clone, Debug)]
pub(crate) struct WrappedCaptureSocket<I: Socket, P: ProtocolProvider> {
inner: I,
remote_address: SocketAddr,
_protocol: PhantomData<P>,
}
impl<I: Socket, P: ProtocolProvider> Socket for WrappedCaptureSocket<I, P> {
/// Creates a new wrapped socket for capturing packets.
///
/// Initializes a new socket of type `I`, wrapping it to enable packet
/// capturing. Capturing is protocol-specific, as indicated by
/// the `ProtocolProvider`.
///
/// # Arguments
/// * `address` - The address to connect the socket to.
/// * `timeout_settings` - Optional timeout settings for the socket.
///
/// # Returns
/// A `GDResult` containing either the wrapped socket or an error.
fn new(address: &SocketAddr, timeout_settings: &Option<TimeoutSettings>) -> GDResult<Self>
where Self: Sized {
let v = Self {
inner: I::new(address, timeout_settings)?,
remote_address: *address,
_protocol: PhantomData,
};
let info = CapturePacket {
direction: Direction::Send,
protocol: P::protocol(),
remote_address: address,
local_address: &v.local_addr().unwrap(),
};
if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() {
writer.new_connect(&info)?;
}
Ok(v)
}
/// Sends data over the socket and captures the packet.
///
/// The method sends data using the inner socket and captures the sent
/// packet if a capture writer is set.
///
/// # Arguments
/// * `data` - Data to be sent.
///
/// # Returns
/// A result indicating success or error in sending data.
fn send(&mut self, data: &[u8]) -> GDResult<()> {
let info = CapturePacket {
direction: Direction::Send,
protocol: P::protocol(),
remote_address: &self.remote_address,
local_address: &self.local_addr().unwrap(),
};
if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() {
writer.write(&info, data)?;
}
self.inner.send(data)
}
/// Receives data from the socket and captures the packet.
///
/// The method receives data using the inner socket and captures the
/// incoming packet if a capture writer is set.
///
/// # Arguments
/// * `size` - Optional size of data to receive.
///
/// # Returns
/// A result containing received data or an error.
fn receive(&mut self, size: Option<usize>) -> crate::GDResult<Vec<u8>> {
let data = self.inner.receive(size)?;
let info = CapturePacket {
direction: Direction::Receive,
protocol: P::protocol(),
remote_address: &self.remote_address,
local_address: &self.local_addr().unwrap(),
};
if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() {
writer.write(&info, &data)?;
}
Ok(data)
}
/// Applies timeout settings to the wrapped socket.
///
/// Delegates the operation to the inner socket implementation.
///
/// # Arguments
/// * `timeout_settings` - Optional timeout settings to apply.
///
/// # Returns
/// A result indicating success or error in applying timeouts.
fn apply_timeout(
&self,
timeout_settings: &Option<crate::protocols::types::TimeoutSettings>,
) -> crate::GDResult<()> {
self.inner.apply_timeout(timeout_settings)
}
/// Returns the remote port of the wrapped socket.
///
/// Delegates the operation to the inner socket implementation.
///
/// # Returns
/// The remote port number.
fn port(&self) -> u16 { self.inner.port() }
/// Returns the local SocketAddr of the wrapped socket.
///
/// Delegates the operation to the inner socket implementation.
///
/// # Returns
/// The local SocketAddr.
fn local_addr(&self) -> std::io::Result<SocketAddr> { self.inner.local_addr() }
}
// this seems a bad way to do this, but its safe
impl<I: Socket, P: ProtocolProvider> Drop for WrappedCaptureSocket<I, P> {
fn drop(&mut self) {
// Construct the CapturePacket info
let info = CapturePacket {
direction: Direction::Send,
protocol: P::protocol(),
remote_address: &self.remote_address,
local_address: &self
.local_addr()
.unwrap_or_else(|_| SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)),
};
// If a capture writer is set, close the connection and capture the packet.
if let Some(writer) = CAPTURE_WRITER.lock().unwrap().as_mut() {
let _ = writer.close_connection(&info);
}
}
}
/// A specialized `WrappedCaptureSocket` for UDP, using `UdpSocketImpl` as
/// the inner socket and `ProtocolUDP` as the protocol provider.
///
/// This type captures and processes UDP packets, wrapping around standard
/// UDP socket functionalities with additional packet capture
/// capabilities.
pub(crate) type CapturedUdpSocket = WrappedCaptureSocket<UdpSocketImpl, ProtocolUDP>;
/// A specialized `WrappedCaptureSocket` for TCP, using `TcpSocketImpl` as
/// the inner socket and `ProtocolTCP` as the protocol provider.
///
/// This type captures and processes TCP packets, wrapping around standard
/// TCP socket functionalities with additional packet capture
/// capabilities.
pub(crate) type CapturedTcpSocket = WrappedCaptureSocket<TcpSocketImpl, ProtocolTCP>;

View file

@ -0,0 +1,86 @@
use std::{io::Write, sync::Mutex};
use super::{
packet::{CapturePacket, Protocol},
pcap::Pcap,
};
use crate::GDResult;
use lazy_static::lazy_static;
lazy_static! {
/// A globally accessible, lazily-initialized static writer instance.
/// This writer is intended for capturing and recording network packets.
/// The writer is wrapped in a Mutex to ensure thread-safe access and modification.
pub(crate) static ref CAPTURE_WRITER: Mutex<Option<Box<dyn Writer + Send + Sync>>> = Mutex::new(None);
}
/// Trait defining the functionality for a writer that handles network packet
/// captures. This trait includes methods for writing packet data, handling new
/// connections, and closing connections.
pub(crate) trait Writer {
/// Writes a given packet's data to an underlying storage or stream.
///
/// # Arguments
/// * `packet` - Reference to the packet being captured.
/// * `data` - The raw byte data associated with the packet.
///
/// # Returns
/// A `GDResult` indicating the success or failure of the write operation.
fn write(&mut self, packet: &CapturePacket, data: &[u8]) -> GDResult<()>;
/// Handles the creation of a new connection, potentially logging or
/// initializing resources.
///
/// # Arguments
/// * `packet` - Reference to the packet indicating a new connection.
///
/// # Returns
/// A `GDResult` indicating the success or failure of handling the new
/// connection.
fn new_connect(&mut self, packet: &CapturePacket) -> GDResult<()>;
/// Closes a connection, handling any necessary cleanup or finalization.
///
/// # Arguments
/// * `packet` - Reference to the packet indicating the closure of a
/// connection.
///
/// # Returns
/// A `GDResult` indicating the success or failure of the connection closure
/// operation.
fn close_connection(&mut self, packet: &CapturePacket) -> GDResult<()>;
}
/// Implementation of the `Writer` trait for the `Pcap` struct.
/// This implementation enables writing, connection handling, and closure
/// specific to PCAP (Packet Capture) format.
impl<W: Write> Writer for Pcap<W> {
fn write(&mut self, info: &CapturePacket, data: &[u8]) -> GDResult<()> {
self.write_transport_packet(info, data);
Ok(())
}
fn new_connect(&mut self, packet: &CapturePacket) -> GDResult<()> {
match packet.protocol {
Protocol::Tcp => {
self.write_tcp_handshake(packet);
}
Protocol::Udp => {}
}
self.state.stream_count = self.state.stream_count.wrapping_add(1);
Ok(())
}
fn close_connection(&mut self, packet: &CapturePacket) -> GDResult<()> {
match packet.protocol {
Protocol::Tcp => {
self.send_tcp_fin(packet);
}
Protocol::Udp => {}
}
Ok(())
}
}

View file

@ -16,7 +16,7 @@ pub const MAX_BUFFER_SIZE: usize = 500;
/// Send a ping packet.
///
/// [Reference](https://github.com/Anuken/Mindustry/blob/a2e5fbdedb2fc1c8d3c157bf344d10ad6d321442/core/src/mindustry/net/ArcNetProvider.java#L248)
pub fn send_ping(socket: &mut UdpSocket) -> GDResult<()> { socket.send(&[-2i8 as u8, 1i8 as u8]) }
pub(crate) fn send_ping(socket: &mut UdpSocket) -> GDResult<()> { socket.send(&[-2i8 as u8, 1i8 as u8]) }
/// Parse server data.
///

View file

@ -45,6 +45,9 @@ mod buffer;
mod socket;
mod utils;
#[cfg(feature = "packet_capture")]
pub mod capture;
pub use errors::*;
#[cfg(feature = "games")]
pub use games::*;

View file

@ -4,35 +4,75 @@ use crate::{
GDResult,
};
use std::net::SocketAddr;
use std::{
io::{Read, Write},
net,
net::{self, SocketAddr},
};
const DEFAULT_PACKET_SIZE: usize = 1024;
/// A trait defining the basic functionalities of a network socket.
pub trait Socket {
/// Create a new socket and connect to the remote address (if required).
/// Create a new socket and connect to the remote address.
///
/// Calls [Self::apply_timeout] with the given timeout settings.
/// # Arguments
/// * `address` - The address to connect the socket to.
/// * `timeout_settings` - Optional timeout settings for the socket.
///
/// # Returns
/// A result containing the socket instance or an error.
fn new(address: &SocketAddr, timeout_settings: &Option<TimeoutSettings>) -> GDResult<Self>
where Self: Sized;
/// Apply read and write timeouts to the socket.
///
/// # Arguments
/// * `timeout_settings` - Optional timeout settings to apply.
///
/// # Returns
/// A result indicating success or error in applying timeouts.
fn apply_timeout(&self, timeout_settings: &Option<TimeoutSettings>) -> GDResult<()>;
/// Send data over the socket.
///
/// # Arguments
/// * `data` - Data to be sent.
///
/// # Returns
/// A result indicating success or error in sending data.
fn send(&mut self, data: &[u8]) -> GDResult<()>;
/// Receive data from the socket.
///
/// # Arguments
/// * `size` - Optional size of data to receive.
///
/// # Returns
/// A result containing received data or an error.
fn receive(&mut self, size: Option<usize>) -> GDResult<Vec<u8>>;
/// Get the remote port of the socket.
///
/// # Returns
/// The port number.
fn port(&self) -> u16;
/// Get the local SocketAddr.
///
/// # Returns
/// The local SocketAddr.
fn local_addr(&self) -> std::io::Result<SocketAddr>;
}
pub struct TcpSocket {
/// Implementation of a TCP socket.
pub struct TcpSocketImpl {
/// The underlying TCP socket stream.
socket: net::TcpStream,
/// The address of the remote host.
address: SocketAddr,
}
impl Socket for TcpSocket {
impl Socket for TcpSocketImpl {
fn new(address: &SocketAddr, timeout_settings: &Option<TimeoutSettings>) -> GDResult<Self> {
let socket = TimeoutSettings::get_connect_or_default(timeout_settings).map_or_else(
|| net::TcpStream::connect(address),
@ -72,14 +112,18 @@ impl Socket for TcpSocket {
}
fn port(&self) -> u16 { self.address.port() }
fn local_addr(&self) -> std::io::Result<SocketAddr> { self.socket.local_addr() }
}
pub struct UdpSocket {
/// Implementation of a UDP socket.
pub struct UdpSocketImpl {
/// The underlying UDP socket.
socket: net::UdpSocket,
/// The address of the remote host.
address: SocketAddr,
}
impl Socket for UdpSocket {
impl Socket for UdpSocketImpl {
fn new(address: &SocketAddr, timeout_settings: &Option<TimeoutSettings>) -> GDResult<Self> {
let socket = net::UdpSocket::bind("0.0.0.0:0").map_err(|e| SocketBind.context(e))?;
@ -120,8 +164,19 @@ impl Socket for UdpSocket {
}
fn port(&self) -> u16 { self.address.port() }
fn local_addr(&self) -> std::io::Result<SocketAddr> { self.socket.local_addr() }
}
#[cfg(not(feature = "packet_capture"))]
pub type UdpSocket = UdpSocketImpl;
#[cfg(not(feature = "packet_capture"))]
pub type TcpSocket = TcpSocketImpl;
#[cfg(feature = "packet_capture")]
pub(crate) type UdpSocket = crate::capture::socket::CapturedUdpSocket;
#[cfg(feature = "packet_capture")]
pub(crate) type TcpSocket = crate::capture::socket::CapturedTcpSocket;
#[cfg(test)]
mod tests {
use std::thread;