rust-gamedig/src/socket.rs
Tom c3281be419
[Protocol] Retry failed requests (#95)
* Add retry count to TimeoutSettings

This can be used to specify how many times to re-send requests that
fail. The default value is "1", so if the first request fails, 1 more
attempt will be made.

* Add retries to valve queries

* [Protocol] &Option<TimeoutSettings>: add get_retries_or_default

Allow fetching the number of retries or the default retries value from a
borrowed optional TimeoutSettings.

* [Protocol] Add retries to minecraft protocol

* [Protocol] Add retries to quake

* [Protocol] Add retries to gamespy

* [Protocol] Update TimeoutSettings docs, and change default retries to 0

* Remove logging from retry_on_timeout

* [Protocol] TimeoutSettings make retries non-optional

* [Protocol] Move retry logic into lower level query functions

Retries are now implemented as wrappers on the single function that
would need to be retried on timeout.

To avoid cloning TimeoutSettings, Socket::apply_timeout() was changed
to accept a borrowed TimeoutSettings, and extra helpers were added to
the TimeoutSettings impl to reduce repetition.

* [Examples] Add retries to the generic example

* Also retry on PacketSend error

Sending packets can also time out, and until `error_generic_member_access`
is stabilized we have no way of determining the type of the underlying
`std::error::Error`.

* Add retry unit tests

* [Docs] Update changelog
2023-09-25 22:12:54 +03:00

177 lines
5.4 KiB
Rust

use crate::{
protocols::types::TimeoutSettings,
GDErrorKind::{PacketReceive, PacketSend, SocketBind, SocketConnect},
GDResult,
};
use std::net::SocketAddr;
use std::{
io::{Read, Write},
net,
};
/// Buffer size used by the `receive` implementations in this module when the
/// caller does not supply an expected size.
const DEFAULT_PACKET_SIZE: usize = 1024;

/// Transport abstraction shared by the TCP and UDP sockets used for queries.
pub trait Socket {
    /// Creates a socket whose target is `address`.
    fn new(address: &SocketAddr) -> GDResult<Self>
    where
        Self: Sized;

    /// Applies the read/write timeouts described by `timeout_settings`.
    ///
    /// The implementations in this module resolve the concrete values through
    /// `TimeoutSettings::get_read_and_write_or_defaults`.
    fn apply_timeout(&self, timeout_settings: &Option<TimeoutSettings>) -> GDResult<()>;

    /// Sends `data` to the target address.
    fn send(&mut self, data: &[u8]) -> GDResult<()>;

    /// Receives data from the socket; `size` sizes the receive buffer, and
    /// `None` means `DEFAULT_PACKET_SIZE`.
    fn receive(&mut self, size: Option<usize>) -> GDResult<Vec<u8>>;

    /// Returns the port of the target address.
    fn port(&self) -> u16;
}
/// A TCP stream connected to the target address.
pub struct TcpSocket {
    socket: net::TcpStream,
    address: SocketAddr,
}

impl Socket for TcpSocket {
    /// Opens a TCP connection to `address`.
    ///
    /// # Errors
    /// Returns a `SocketConnect` error if the connection cannot be established.
    fn new(address: &SocketAddr) -> GDResult<Self> {
        Ok(Self {
            socket: net::TcpStream::connect(address).map_err(|e| SocketConnect.context(e))?,
            address: *address,
        })
    }

    /// Sets the stream's read and write timeouts.
    ///
    /// # Panics
    /// Panics if the OS rejects a timeout value. `set_read_timeout` and
    /// `set_write_timeout` reject a zero `Duration`, which
    /// `TimeoutSettings::new` is expected to rule out.
    fn apply_timeout(&self, timeout_settings: &Option<TimeoutSettings>) -> GDResult<()> {
        let (read, write) = TimeoutSettings::get_read_and_write_or_defaults(timeout_settings);
        self.socket
            .set_read_timeout(read)
            .expect("read timeout rejected; TimeoutSettings::new should rule out zero durations");
        self.socket
            .set_write_timeout(write)
            .expect("write timeout rejected; TimeoutSettings::new should rule out zero durations");
        Ok(())
    }

    /// Writes the whole of `data` to the stream.
    ///
    /// # Errors
    /// Returns a `PacketSend` error if writing fails (including a write timeout).
    fn send(&mut self, data: &[u8]) -> GDResult<()> {
        // `write` may perform a partial write and still return `Ok`; `write_all`
        // keeps writing until every byte is sent or an error occurs.
        self.socket
            .write_all(data)
            .map_err(|e| PacketSend.context(e))?;
        Ok(())
    }

    /// Reads from the stream until the peer closes it (EOF).
    ///
    /// `size` only pre-sizes the buffer; it does not limit how much is read.
    ///
    /// # Errors
    /// Returns a `PacketReceive` error if a read fails (including a read timeout).
    fn receive(&mut self, size: Option<usize>) -> GDResult<Vec<u8>> {
        let mut buf = Vec::with_capacity(size.unwrap_or(DEFAULT_PACKET_SIZE));
        self.socket
            .read_to_end(&mut buf)
            .map_err(|e| PacketReceive.context(e))?;
        Ok(buf)
    }

    fn port(&self) -> u16 { self.address.port() }
}
/// An unconnected UDP socket that exchanges datagrams with the target address.
pub struct UdpSocket {
    socket: net::UdpSocket,
    address: SocketAddr,
}

impl Socket for UdpSocket {
    /// Binds a UDP socket to an OS-assigned local port.
    ///
    /// The local address is the unspecified address of the same family as
    /// `address`. An IPv4-bound socket cannot send to an IPv6 target, and
    /// vice versa.
    ///
    /// # Errors
    /// Returns a `SocketBind` error if the local socket cannot be bound.
    fn new(address: &SocketAddr) -> GDResult<Self> {
        let unspecified: net::IpAddr = match address {
            SocketAddr::V4(_) => net::Ipv4Addr::UNSPECIFIED.into(),
            SocketAddr::V6(_) => net::Ipv6Addr::UNSPECIFIED.into(),
        };
        let socket = net::UdpSocket::bind(SocketAddr::new(unspecified, 0))
            .map_err(|e| SocketBind.context(e))?;
        Ok(Self {
            socket,
            address: *address,
        })
    }

    /// Sets the socket's read and write timeouts.
    ///
    /// # Panics
    /// Panics if the OS rejects a timeout value. `set_read_timeout` and
    /// `set_write_timeout` reject a zero `Duration`, which
    /// `TimeoutSettings::new` is expected to rule out.
    fn apply_timeout(&self, timeout_settings: &Option<TimeoutSettings>) -> GDResult<()> {
        let (read, write) = TimeoutSettings::get_read_and_write_or_defaults(timeout_settings);
        self.socket
            .set_read_timeout(read)
            .expect("read timeout rejected; TimeoutSettings::new should rule out zero durations");
        self.socket
            .set_write_timeout(write)
            .expect("write timeout rejected; TimeoutSettings::new should rule out zero durations");
        Ok(())
    }

    /// Sends `data` as one datagram to the target address.
    ///
    /// # Errors
    /// Returns a `PacketSend` error if sending fails.
    fn send(&mut self, data: &[u8]) -> GDResult<()> {
        self.socket
            .send_to(data, self.address)
            .map_err(|e| PacketSend.context(e))?;
        Ok(())
    }

    /// Receives one datagram into a buffer of `size` bytes
    /// (`DEFAULT_PACKET_SIZE` when `None`).
    ///
    /// The result is trimmed to the number of bytes actually received. Bytes
    /// of a datagram that do not fit in the buffer may be discarded.
    ///
    /// NOTE(review): the sender's address is not checked against the target,
    /// so a datagram from any source is accepted.
    ///
    /// # Errors
    /// Returns a `PacketReceive` error if receiving fails (including a read timeout).
    fn receive(&mut self, size: Option<usize>) -> GDResult<Vec<u8>> {
        let mut buf: Vec<u8> = vec![0; size.unwrap_or(DEFAULT_PACKET_SIZE)];
        let (number_of_bytes_received, _) = self
            .socket
            .recv_from(&mut buf)
            .map_err(|e| PacketReceive.context(e))?;
        // Reuse the allocation instead of copying the received prefix into a new Vec.
        buf.truncate(number_of_bytes_received);
        Ok(buf)
    }

    fn port(&self) -> u16 { self.address.port() }
}
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    #[test]
    fn test_tcp_socket_send_and_receive() {
        let message = b"hello, world!";

        // Echo server: read exactly the message, write back exactly those bytes,
        // then drop the stream so the client's `read_to_end` sees EOF.
        let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
        let bound_address = listener.local_addr().unwrap();
        let server_thread = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = vec![0; message.len()];
            stream.read_exact(&mut buf).unwrap();
            stream.write_all(&buf).unwrap();
        });

        // Create a TCP socket and send the message to the server.
        let mut socket = TcpSocket::new(&bound_address).unwrap();
        assert_eq!(socket.port(), bound_address.port());
        socket.send(message).unwrap();

        // The echoed payload must match byte-for-byte; no padding to strip.
        let received_message = socket.receive(None).unwrap();

        server_thread.join().expect("server thread panicked");
        assert_eq!(message, &received_message[..]);
    }

    #[test]
    fn test_udp_socket_send_and_receive() {
        // Echo server: send back only the bytes of the datagram it received.
        let server = net::UdpSocket::bind("127.0.0.1:0").unwrap();
        let bound_address = server.local_addr().unwrap();
        let server_thread = thread::spawn(move || {
            let mut buf = [0; 1024];
            let (len, src_addr) = server.recv_from(&mut buf).unwrap();
            server.send_to(&buf[.. len], src_addr).unwrap();
        });

        // Create a UDP socket and send a message to the server.
        let mut socket = UdpSocket::new(&bound_address).unwrap();
        assert_eq!(socket.port(), bound_address.port());
        let message = b"hello, world!";
        socket.send(message).unwrap();

        // `receive` trims to the datagram length, so the payload must match exactly.
        let received_message = socket.receive(None).unwrap();

        server_thread.join().expect("server thread panicked");
        assert_eq!(message, &received_message[..]);
    }
}