From 544ce897c545e29bd994c2d443d2f163baeca5bb Mon Sep 17 00:00:00 2001 From: CosminPerRam Date: Mon, 17 Oct 2022 11:11:40 +0300 Subject: [PATCH] Better protocol parameters and utils tests --- src/errors.rs | 4 +- src/games/tf2.rs | 6 +-- src/protocols/valve.rs | 38 ++++++++------- src/utils.rs | 106 ++++++++++++++++++++++++++++++++--------- 4 files changed, 110 insertions(+), 44 deletions(-) diff --git a/src/errors.rs b/src/errors.rs index 1b0623e..85784b2 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -3,13 +3,13 @@ use std::fmt::Formatter; #[derive(Debug, Clone)] pub enum GDError { - IDK(String) + PacketOverflow } impl fmt::Display for GDError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - GDError::IDK(details) => write!(f, "IDK: {details}") + GDError::PacketOverflow => write!(f, "Packet overflow!") } } } diff --git a/src/games/tf2.rs b/src/games/tf2.rs index d349f0f..bf2461a 100644 --- a/src/games/tf2.rs +++ b/src/games/tf2.rs @@ -1,13 +1,13 @@ use crate::errors::GDError; -use crate::protocols::valve::{Response, ValveProtocol}; +use crate::valve::{Response, ValveProtocol, App}; pub struct TF2; impl TF2 { pub fn query(address: &str, port: Option) -> Result { - ValveProtocol::query(address, match port { + ValveProtocol::query(App::TF2, address, match port { None => 27015, Some(port) => port - }, false) + }) } } diff --git a/src/protocols/valve.rs b/src/protocols/valve.rs index e1e2486..cee509a 100644 --- a/src/protocols/valve.rs +++ b/src/protocols/valve.rs @@ -1,6 +1,6 @@ use std::net::UdpSocket; use crate::errors::GDError; -use crate::utils::{buffer, complete_address, concat_u8}; +use crate::utils::{buffer, complete_address, concat_u8_arrays}; #[derive(Debug)] pub enum Server { @@ -57,6 +57,12 @@ pub enum Request { A2sInfo(Option<[u8; 4]>) } +#[derive(PartialEq)] +pub enum App { + TF2 = 440, + TheShip = 2400 +} + pub struct ValveProtocol { socket: UdpSocket, complete_address: String @@ -79,7 +85,7 @@ impl ValveProtocol { let request_kind_packet = match kind { Request::A2sInfo(challenge) => match challenge { None => default, - Some(value) => concat_u8(&default, &value) + Some(value) => concat_u8_arrays(&default, &value) } }; @@ -103,7 +109,7 @@ impl ValveProtocol { } impl ValveProtocol { - pub(crate) fn query(address: &str, port: u16, has_the_ship: bool) -> Result { + pub(crate) fn query(app: App, address: &str, port: u16) -> Result { let client = ValveProtocol::new(address, port); client.do_request(Request::A2sInfo(None), None); @@ -117,11 +123,11 @@ impl ValveProtocol { Ok(Response { protocol: buffer::get_u8(&buf, &mut pos)?, - name: buffer::get_string(&buf, &mut pos), - map: buffer::get_string(&buf, &mut pos), - folder: buffer::get_string(&buf, &mut pos), - game: buffer::get_string(&buf, &mut pos), - id: buffer::get_u16(&buf, &mut pos), + name: buffer::get_string(&buf, &mut pos)?, + map: buffer::get_string(&buf, &mut pos)?, + folder: buffer::get_string(&buf, &mut pos)?, + game: buffer::get_string(&buf, &mut pos)?, + id: buffer::get_u16_le(&buf, &mut pos)?, players: buffer::get_u8(&buf, &mut pos)?, max_players: buffer::get_u8(&buf, &mut pos)?, bots: buffer::get_u8(&buf, &mut pos)?, @@ -137,7 +143,7 @@ impl ValveProtocol { }, has_password: buffer::get_u8(&buf, &mut pos)? == 1, vac_secured: buffer::get_u8(&buf, &mut pos)? == 1, - the_ship: match has_the_ship { + the_ship: match app == App::TheShip { false => None, true => Some(TheShip { mode: buffer::get_u8(&buf, &mut pos)?, @@ -145,33 +151,33 @@ impl ValveProtocol { duration: buffer::get_u8(&buf, &mut pos)? }) }, - version: buffer::get_string(&buf, &mut pos), + version: buffer::get_string(&buf, &mut pos)?, extra_data: match buffer::get_u8(&buf, &mut pos) { Err(_) => None, Ok(value) => Some(ExtraData { port: match (value & 0x80) > 0 { false => None, - true => Some(buffer::get_u16(&buf, &mut pos)) + true => Some(buffer::get_u16_le(&buf, &mut pos)?) }, steam_id: match (value & 0x10) > 0 { false => None, - true => Some(buffer::get_u64(&buf, &mut pos)) + true => Some(buffer::get_u64_le(&buf, &mut pos)?) }, tv_port: match (value & 0x40) > 0 { false => None, - true => Some(buffer::get_u16(&buf, &mut pos)) + true => Some(buffer::get_u16_le(&buf, &mut pos)?) }, tv_name: match (value & 0x40) > 0 { false => None, - true => Some(buffer::get_string(&buf, &mut pos)) + true => Some(buffer::get_string(&buf, &mut pos)?) }, keywords: match (value & 0x20) > 0 { false => None, - true => Some(buffer::get_string(&buf, &mut pos)) + true => Some(buffer::get_string(&buf, &mut pos)?) }, game_id: match (value & 0x01) > 0 { false => None, - true => Some(buffer::get_u64(&buf, &mut pos)) + true => Some(buffer::get_u64_le(&buf, &mut pos)?) } }) } diff --git a/src/utils.rs b/src/utils.rs index 6e5ee2c..2ce307f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,23 +1,19 @@ use std::ops::Add; use crate::GDError; -pub fn concat_u8(first: &[u8], second: &[u8]) -> Vec { +pub fn concat_u8_arrays(first: &[u8], second: &[u8]) -> Vec { [first, second].concat() } -pub fn find_first_string(arr: &[u8]) -> String { - std::str::from_utf8(&arr[..arr.iter().position(|&x| x == 0).unwrap()]).unwrap().to_string() -} - pub fn complete_address(address: &str, port: u16) -> String { String::from(address.to_owned() + ":").add(&*port.to_string()) } -pub fn combine_two_u8(high: u8, low: u8) -> u16 { +pub fn combine_two_u8_le(high: u8, low: u8) -> u16 { u16::from_be_bytes([high, low]) } -pub fn combine_eight_u8(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u64 { +pub fn combine_eight_u8_le(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u64 { u64::from_be_bytes([a, b, c, d, e, f, g, h]) } @@ -25,43 +21,57 @@ pub mod buffer { use super::*; pub fn get_u8(buf: &[u8], pos: &mut usize) -> Result { + if buf.len() <= *pos { + return Err(GDError::PacketOverflow); + } + let value = buf[*pos]; *pos += 1; Ok(value) } - pub fn get_u16(buf: &[u8], pos: &mut usize) -> u16 { - let value = combine_two_u8(buf[*pos + 1], buf[*pos]); + pub fn get_u16_le(buf: &[u8], pos: &mut usize) -> Result { + if buf.len() <= *pos + 1 { + return Err(GDError::PacketOverflow); + } + + let value = combine_two_u8_le(buf[*pos + 1], buf[*pos]); *pos += 2; - value + Ok(value) } - pub fn get_u64(buf: &[u8], pos: &mut usize) -> u64 { - let value = combine_eight_u8(buf[*pos + 7], buf[*pos + 6], buf[*pos + 5], buf[*pos + 4], buf[*pos + 3], buf[*pos + 2], buf[*pos + 1], buf[*pos]); + pub fn get_u64_le(buf: &[u8], pos: &mut usize) -> Result { + if buf.len() <= *pos + 7 { + return Err(GDError::PacketOverflow); + } + + let value = combine_eight_u8_le(buf[*pos + 7], buf[*pos + 6], buf[*pos + 5], buf[*pos + 4], buf[*pos + 3], buf[*pos + 2], buf[*pos + 1], buf[*pos]); *pos += 8; - value + Ok(value) } - pub fn get_string(buf: &[u8], pos: &mut usize) -> String { - let value = find_first_string(&buf[*pos..]); + pub fn get_string(buf: &[u8], pos: &mut usize) -> Result { + let sub_buf = &buf[*pos..]; + let first_null_position = sub_buf.iter().position(|&x| x == 0).ok_or(GDError::PacketOverflow)?; + let value = std::str::from_utf8(&sub_buf[..first_null_position]).unwrap().to_string(); *pos += value.len() + 1; - value + Ok(value) } } #[cfg(test)] -mod utils { +mod tests { use super::*; #[test] - fn concat_u8_test() { + fn concat_u8_arrays_test() { let a: [u8; 2] = [1, 2]; let b: [u8; 2] = [3, 4]; - let combined = concat_u8(&a, &b); - assert_eq!(a[0], combined[0]); - assert_eq!(a[1], combined[1]); - assert_eq!(b[0], combined[2]); - assert_eq!(b[1], combined[3]); + let combined = concat_u8_arrays(&a, &b); + assert_eq!(combined[0], a[0]); + assert_eq!(combined[1], a[1]); + assert_eq!(combined[2], b[0]); + assert_eq!(combined[3], b[1]); } #[test] @@ -70,4 +80,54 @@ mod utils { let port = 27015; assert_eq!(complete_address(address, port), "192.168.0.1:27015"); } + + #[test] + fn combine_two_u8_le_test() { + assert_eq!(combine_two_u8_le(49, 123), 12667); + } + + #[test] + fn combine_eight_u8_le_test() { + assert_eq!(combine_eight_u8_le(49, 123, 44, 52, 3, 250, 22, 0), 3565492131910522368); + } + + #[test] + fn get_u8_test() { + let data = [72]; + let mut pos = 0; + assert_eq!(buffer::get_u8(&data, &mut pos).unwrap(), 72); + assert_eq!(pos, 1); + assert!(buffer::get_u8(&data, &mut pos).is_err()); + assert_eq!(pos, 1); + } + + #[test] + fn get_u16_le_test() { + let data = [72, 29]; + let mut pos = 0; + assert_eq!(buffer::get_u16_le(&data, &mut pos).unwrap(), 7496); + assert_eq!(pos, 2); + assert!(buffer::get_u16_le(&data, &mut pos).is_err()); + assert_eq!(pos, 2); + } + + #[test] + fn get_u64_le_test() { + let data = [72, 29, 128, 99, 69, 4, 2, 0]; + let mut pos = 0; + assert_eq!(buffer::get_u64_le(&data, &mut pos).unwrap(), 567646022016328); + assert_eq!(pos, 8); + assert!(buffer::get_u64_le(&data, &mut pos).is_err()); + assert_eq!(pos, 8); + } + + #[test] + fn get_string_test() { + let data = [72, 101, 108, 108, 111, 0, 72]; + let mut pos = 0; + assert_eq!(buffer::get_string(&data, &mut pos).unwrap(), "Hello"); + assert_eq!(pos, 6); + assert!(buffer::get_string(&data, &mut pos).is_err()); + assert_eq!(pos, 6); + } }