Better protocol parameters and utils tests

This commit is contained in:
CosminPerRam 2022-10-17 11:11:40 +03:00
parent 3a83588802
commit 544ce897c5
4 changed files with 110 additions and 44 deletions

View file

@ -3,13 +3,13 @@ use std::fmt::Formatter;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum GDError { pub enum GDError {
IDK(String) PacketOverflow
} }
impl fmt::Display for GDError { impl fmt::Display for GDError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self { match self {
GDError::IDK(details) => write!(f, "IDK: {details}") GDError::PacketOverflow => write!(f, "Packet overflow!")
} }
} }
} }

View file

@ -1,13 +1,13 @@
use crate::errors::GDError; use crate::errors::GDError;
use crate::protocols::valve::{Response, ValveProtocol}; use crate::valve::{Response, ValveProtocol, App};
pub struct TF2; pub struct TF2;
impl TF2 { impl TF2 {
pub fn query(address: &str, port: Option<u16>) -> Result<Response, GDError> { pub fn query(address: &str, port: Option<u16>) -> Result<Response, GDError> {
ValveProtocol::query(address, match port { ValveProtocol::query(App::TF2, address, match port {
None => 27015, None => 27015,
Some(port) => port Some(port) => port
}, false) })
} }
} }

View file

@ -1,6 +1,6 @@
use std::net::UdpSocket; use std::net::UdpSocket;
use crate::errors::GDError; use crate::errors::GDError;
use crate::utils::{buffer, complete_address, concat_u8}; use crate::utils::{buffer, complete_address, concat_u8_arrays};
#[derive(Debug)] #[derive(Debug)]
pub enum Server { pub enum Server {
@ -57,6 +57,12 @@ pub enum Request {
A2sInfo(Option<[u8; 4]>) A2sInfo(Option<[u8; 4]>)
} }
#[derive(PartialEq)]
pub enum App {
TF2 = 440,
TheShip = 2400
}
pub struct ValveProtocol { pub struct ValveProtocol {
socket: UdpSocket, socket: UdpSocket,
complete_address: String complete_address: String
@ -79,7 +85,7 @@ impl ValveProtocol {
let request_kind_packet = match kind { let request_kind_packet = match kind {
Request::A2sInfo(challenge) => match challenge { Request::A2sInfo(challenge) => match challenge {
None => default, None => default,
Some(value) => concat_u8(&default, &value) Some(value) => concat_u8_arrays(&default, &value)
} }
}; };
@ -103,7 +109,7 @@ impl ValveProtocol {
} }
impl ValveProtocol { impl ValveProtocol {
pub(crate) fn query(address: &str, port: u16, has_the_ship: bool) -> Result<Response, GDError> { pub(crate) fn query(app: App, address: &str, port: u16) -> Result<Response, GDError> {
let client = ValveProtocol::new(address, port); let client = ValveProtocol::new(address, port);
client.do_request(Request::A2sInfo(None), None); client.do_request(Request::A2sInfo(None), None);
@ -117,11 +123,11 @@ impl ValveProtocol {
Ok(Response { Ok(Response {
protocol: buffer::get_u8(&buf, &mut pos)?, protocol: buffer::get_u8(&buf, &mut pos)?,
name: buffer::get_string(&buf, &mut pos), name: buffer::get_string(&buf, &mut pos)?,
map: buffer::get_string(&buf, &mut pos), map: buffer::get_string(&buf, &mut pos)?,
folder: buffer::get_string(&buf, &mut pos), folder: buffer::get_string(&buf, &mut pos)?,
game: buffer::get_string(&buf, &mut pos), game: buffer::get_string(&buf, &mut pos)?,
id: buffer::get_u16(&buf, &mut pos), id: buffer::get_u16_le(&buf, &mut pos)?,
players: buffer::get_u8(&buf, &mut pos)?, players: buffer::get_u8(&buf, &mut pos)?,
max_players: buffer::get_u8(&buf, &mut pos)?, max_players: buffer::get_u8(&buf, &mut pos)?,
bots: buffer::get_u8(&buf, &mut pos)?, bots: buffer::get_u8(&buf, &mut pos)?,
@ -137,7 +143,7 @@ impl ValveProtocol {
}, },
has_password: buffer::get_u8(&buf, &mut pos)? == 1, has_password: buffer::get_u8(&buf, &mut pos)? == 1,
vac_secured: buffer::get_u8(&buf, &mut pos)? == 1, vac_secured: buffer::get_u8(&buf, &mut pos)? == 1,
the_ship: match has_the_ship { the_ship: match app == App::TheShip {
false => None, false => None,
true => Some(TheShip { true => Some(TheShip {
mode: buffer::get_u8(&buf, &mut pos)?, mode: buffer::get_u8(&buf, &mut pos)?,
@ -145,33 +151,33 @@ impl ValveProtocol {
duration: buffer::get_u8(&buf, &mut pos)? duration: buffer::get_u8(&buf, &mut pos)?
}) })
}, },
version: buffer::get_string(&buf, &mut pos), version: buffer::get_string(&buf, &mut pos)?,
extra_data: match buffer::get_u8(&buf, &mut pos) { extra_data: match buffer::get_u8(&buf, &mut pos) {
Err(_) => None, Err(_) => None,
Ok(value) => Some(ExtraData { Ok(value) => Some(ExtraData {
port: match (value & 0x80) > 0 { port: match (value & 0x80) > 0 {
false => None, false => None,
true => Some(buffer::get_u16(&buf, &mut pos)) true => Some(buffer::get_u16_le(&buf, &mut pos)?)
}, },
steam_id: match (value & 0x10) > 0 { steam_id: match (value & 0x10) > 0 {
false => None, false => None,
true => Some(buffer::get_u64(&buf, &mut pos)) true => Some(buffer::get_u64_le(&buf, &mut pos)?)
}, },
tv_port: match (value & 0x40) > 0 { tv_port: match (value & 0x40) > 0 {
false => None, false => None,
true => Some(buffer::get_u16(&buf, &mut pos)) true => Some(buffer::get_u16_le(&buf, &mut pos)?)
}, },
tv_name: match (value & 0x40) > 0 { tv_name: match (value & 0x40) > 0 {
false => None, false => None,
true => Some(buffer::get_string(&buf, &mut pos)) true => Some(buffer::get_string(&buf, &mut pos)?)
}, },
keywords: match (value & 0x20) > 0 { keywords: match (value & 0x20) > 0 {
false => None, false => None,
true => Some(buffer::get_string(&buf, &mut pos)) true => Some(buffer::get_string(&buf, &mut pos)?)
}, },
game_id: match (value & 0x01) > 0 { game_id: match (value & 0x01) > 0 {
false => None, false => None,
true => Some(buffer::get_u64(&buf, &mut pos)) true => Some(buffer::get_u64_le(&buf, &mut pos)?)
} }
}) })
} }

View file

@ -1,23 +1,19 @@
use std::ops::Add; use std::ops::Add;
use crate::GDError; use crate::GDError;
pub fn concat_u8(first: &[u8], second: &[u8]) -> Vec<u8> { pub fn concat_u8_arrays(first: &[u8], second: &[u8]) -> Vec<u8> {
[first, second].concat() [first, second].concat()
} }
pub fn find_first_string(arr: &[u8]) -> String {
std::str::from_utf8(&arr[..arr.iter().position(|&x| x == 0).unwrap()]).unwrap().to_string()
}
pub fn complete_address(address: &str, port: u16) -> String { pub fn complete_address(address: &str, port: u16) -> String {
String::from(address.to_owned() + ":").add(&*port.to_string()) String::from(address.to_owned() + ":").add(&*port.to_string())
} }
pub fn combine_two_u8(high: u8, low: u8) -> u16 { pub fn combine_two_u8_le(high: u8, low: u8) -> u16 {
u16::from_be_bytes([high, low]) u16::from_be_bytes([high, low])
} }
pub fn combine_eight_u8(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u64 { pub fn combine_eight_u8_le(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u64 {
u64::from_be_bytes([a, b, c, d, e, f, g, h]) u64::from_be_bytes([a, b, c, d, e, f, g, h])
} }
@ -25,43 +21,57 @@ pub mod buffer {
use super::*; use super::*;
pub fn get_u8(buf: &[u8], pos: &mut usize) -> Result<u8, GDError> { pub fn get_u8(buf: &[u8], pos: &mut usize) -> Result<u8, GDError> {
if buf.len() <= *pos {
return Err(GDError::PacketOverflow);
}
let value = buf[*pos]; let value = buf[*pos];
*pos += 1; *pos += 1;
Ok(value) Ok(value)
} }
pub fn get_u16(buf: &[u8], pos: &mut usize) -> u16 { pub fn get_u16_le(buf: &[u8], pos: &mut usize) -> Result<u16, GDError> {
let value = combine_two_u8(buf[*pos + 1], buf[*pos]); if buf.len() <= *pos + 1 {
return Err(GDError::PacketOverflow);
}
let value = combine_two_u8_le(buf[*pos + 1], buf[*pos]);
*pos += 2; *pos += 2;
value Ok(value)
} }
pub fn get_u64(buf: &[u8], pos: &mut usize) -> u64 { pub fn get_u64_le(buf: &[u8], pos: &mut usize) -> Result<u64, GDError> {
let value = combine_eight_u8(buf[*pos + 7], buf[*pos + 6], buf[*pos + 5], buf[*pos + 4], buf[*pos + 3], buf[*pos + 2], buf[*pos + 1], buf[*pos]); if buf.len() <= *pos + 7 {
return Err(GDError::PacketOverflow);
}
let value = combine_eight_u8_le(buf[*pos + 7], buf[*pos + 6], buf[*pos + 5], buf[*pos + 4], buf[*pos + 3], buf[*pos + 2], buf[*pos + 1], buf[*pos]);
*pos += 8; *pos += 8;
value Ok(value)
} }
pub fn get_string(buf: &[u8], pos: &mut usize) -> String { pub fn get_string(buf: &[u8], pos: &mut usize) -> Result<String, GDError> {
let value = find_first_string(&buf[*pos..]); let sub_buf = &buf[*pos..];
let first_null_position = sub_buf.iter().position(|&x| x == 0).ok_or(GDError::PacketOverflow)?;
let value = std::str::from_utf8(&sub_buf[..first_null_position]).unwrap().to_string();
*pos += value.len() + 1; *pos += value.len() + 1;
value Ok(value)
} }
} }
#[cfg(test)] #[cfg(test)]
mod utils { mod tests {
use super::*; use super::*;
#[test] #[test]
fn concat_u8_test() { fn concat_u8_arrays_test() {
let a: [u8; 2] = [1, 2]; let a: [u8; 2] = [1, 2];
let b: [u8; 2] = [3, 4]; let b: [u8; 2] = [3, 4];
let combined = concat_u8(&a, &b); let combined = concat_u8_arrays(&a, &b);
assert_eq!(a[0], combined[0]); assert_eq!(combined[0], a[0]);
assert_eq!(a[1], combined[1]); assert_eq!(combined[1], a[1]);
assert_eq!(b[0], combined[2]); assert_eq!(combined[2], b[0]);
assert_eq!(b[1], combined[3]); assert_eq!(combined[3], b[1]);
} }
#[test] #[test]
@ -70,4 +80,54 @@ mod utils {
let port = 27015; let port = 27015;
assert_eq!(complete_address(address, port), "192.168.0.1:27015"); assert_eq!(complete_address(address, port), "192.168.0.1:27015");
} }
#[test]
fn combine_two_u8_le_test() {
assert_eq!(combine_two_u8_le(49, 123), 12667);
}
#[test]
fn combine_eight_u8_le_test() {
assert_eq!(combine_eight_u8_le(49, 123, 44, 52, 3, 250, 22, 0), 3565492131910522368);
}
#[test]
fn get_u8_test() {
let data = [72];
let mut pos = 0;
assert_eq!(buffer::get_u8(&data, &mut pos).unwrap(), 72);
assert_eq!(pos, 1);
assert!(buffer::get_u8(&data, &mut pos).is_err());
assert_eq!(pos, 1);
}
#[test]
fn get_u16_le_test() {
let data = [72, 29];
let mut pos = 0;
assert_eq!(buffer::get_u16_le(&data, &mut pos).unwrap(), 7496);
assert_eq!(pos, 2);
assert!(buffer::get_u16_le(&data, &mut pos).is_err());
assert_eq!(pos, 2);
}
#[test]
fn get_u64_le_test() {
let data = [72, 29, 128, 99, 69, 4, 2, 0];
let mut pos = 0;
assert_eq!(buffer::get_u64_le(&data, &mut pos).unwrap(), 567646022016328);
assert_eq!(pos, 8);
assert!(buffer::get_u64_le(&data, &mut pos).is_err());
assert_eq!(pos, 8);
}
#[test]
fn get_string_test() {
let data = [72, 101, 108, 108, 111, 0, 72];
let mut pos = 0;
assert_eq!(buffer::get_string(&data, &mut pos).unwrap(), "Hello");
assert_eq!(pos, 6);
assert!(buffer::get_string(&data, &mut pos).is_err());
assert_eq!(pos, 6);
}
} }