From 2e4b4eb2a62a1bfde1601f005da5f218158ffc30 Mon Sep 17 00:00:00 2001 From: jabu Date: Mon, 11 Dec 2023 01:36:48 -0600 Subject: [PATCH] add test and some api improvements --- src/event.rs | 2 + src/host.rs | 13 +- src/lib.rs | 31 ++--- src/packet.rs | 16 ++- src/test.rs | 56 ++++++++ src/test/network.rs | 310 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 408 insertions(+), 20 deletions(-) create mode 100644 src/test.rs create mode 100644 src/test/network.rs diff --git a/src/event.rs b/src/event.rs index 7a0184c..1db18ec 100644 --- a/src/event.rs +++ b/src/event.rs @@ -54,6 +54,8 @@ impl<'a, S: Socket> Event<'a, S> { } /// An ENet event, like [`Event`], but without peer references. +/// +/// Acquired with [`Event::no_ref`]. #[derive(Debug, Clone)] pub enum EventNoRef { /// A new peer has connected. diff --git a/src/host.rs b/src/host.rs index ad1a1b7..4f82478 100644 --- a/src/host.rs +++ b/src/host.rs @@ -33,6 +33,8 @@ pub struct HostSettings { /// A custom time function to use, or [`None`] to use the default one. Should return an /// an accurate, incrementally increasing [`Duration`]. pub time: Option Duration>>, + /// Seed the host with a specific random seed, or set to [`None`] to use a random seed. + pub seed: Option, } impl Default for HostSettings { @@ -45,6 +47,7 @@ impl Default for HostSettings { compressor: None, checksum: None, time: None, + seed: None, } } } @@ -85,6 +88,13 @@ impl Host { settings.channel_limit, settings.incoming_bandwidth_limit.unwrap_or(0), settings.outgoing_bandwidth_limit.unwrap_or(0), + settings.time.unwrap_or(Box::new(|| { + use wasm_timer::{SystemTime, UNIX_EPOCH}; + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + })), + settings.seed, ); let mut peers = vec![]; for peer_index in 0..(*host).peerCount { @@ -96,9 +106,6 @@ impl Host { if let Some(checksum) = settings.checksum { *(*host).checksum.assume_init_mut() = Some(checksum); } - if let Some(time) = settings.time { - *(*host).time.assume_init_mut() = Some(time); - } if !host.is_null() { Ok(Self { host, peers }) } else { diff --git a/src/lib.rs b/src/lib.rs index b4673a6..f760869 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,6 +167,9 @@ pub use version::*; pub mod consts; use consts::*; +#[cfg(test)] +mod test; + /// A [`Result`](`core::result::Result`) type alias with this crate's [`Error`] type. pub type Result = core::result::Result; @@ -390,7 +393,7 @@ pub(crate) struct _ENetHost { pub(crate) buffers: [ENetBuffer; 65], pub(crate) bufferCount: size_t, pub(crate) checksum: MaybeUninit) -> u32>>>, - pub(crate) time: MaybeUninit Duration>>>, + pub(crate) time: MaybeUninit Duration>>, pub(crate) compressor: MaybeUninit>>, pub(crate) packetData: [[enet_uint8; 4096]; 2], pub(crate) receivedAddress: MaybeUninit>, @@ -1756,6 +1759,8 @@ pub(crate) unsafe fn enet_host_create( mut channelLimit: size_t, mut incomingBandwidth: enet_uint32, mut outgoingBandwidth: enet_uint32, + time: Box Duration>, + seed: Option, ) -> *mut ENetHost { let mut host: *mut ENetHost = 0 as *mut ENetHost; let mut currentPeer: *mut ENetPeer = 0 as *mut ENetPeer; @@ -1793,10 +1798,15 @@ pub(crate) unsafe fn enet_host_create( } else if channelLimit < ENET_PROTOCOL_MINIMUM_CHANNEL_COUNT as c_int as size_t { channelLimit = ENET_PROTOCOL_MINIMUM_CHANNEL_COUNT as c_int as size_t; } - (*host).randomSeed = host as size_t as enet_uint32; - (*host).randomSeed = ((*host).randomSeed as c_uint).wrapping_add(enet_time_get(host)) - as enet_uint32 as enet_uint32; - (*host).randomSeed = (*host).randomSeed << 16 as c_int | (*host).randomSeed >> 16 as c_int; + (*host).time.write(time); + if let Some(seed) = seed { + (*host).randomSeed = seed; + } else { + (*host).randomSeed = host as size_t as enet_uint32; + (*host).randomSeed = ((*host).randomSeed as c_uint).wrapping_add(enet_time_get(host)) + as enet_uint32 as enet_uint32; + (*host).randomSeed = (*host).randomSeed << 16 as c_int | (*host).randomSeed >> 16 as c_int; + } (*host).channelLimit = channelLimit; (*host).incomingBandwidth = incomingBandwidth; (*host).outgoingBandwidth = outgoingBandwidth; @@ -1807,7 +1817,6 @@ pub(crate) unsafe fn enet_host_create( (*host).commandCount = 0 as c_int as size_t; (*host).bufferCount = 0 as c_int as size_t; (*host).checksum.write(None); - (*host).time.write(None); (*host).receivedAddress.write(None); (*host).receivedData = 0 as *mut enet_uint8; (*host).receivedDataLength = 0 as c_int as size_t; @@ -6739,13 +6748,5 @@ pub(crate) unsafe fn enet_host_random_seed(host: *mut ENetHost) -> enet_time_get(host) } pub(crate) unsafe fn enet_time_get(host: *mut ENetHost) -> enet_uint32 { - let duration = if let Some(time) = (*host).time.assume_init_ref() { - time() - } else { - use wasm_timer::{SystemTime, UNIX_EPOCH}; - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - }; - (duration.as_millis() % u32::MAX as u128) as enet_uint32 + ((*host).time.assume_init_ref()().as_millis() % u32::MAX as u128) as enet_uint32 } diff --git a/src/packet.rs b/src/packet.rs index f59e384..e5aa148 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -1,4 +1,4 @@ -use core::slice; +use std::{fmt::Debug, slice}; use crate::{ c_void, enet_packet_create, enet_packet_destroy, size_t, ENetPacket, ENET_PACKET_FLAG_RELIABLE, @@ -37,7 +37,6 @@ pub enum PacketKind { /// See [`Fragmentation and Reassembly`](`crate#fragmentation-and-reassembly`). /// /// For more information on the kinds of ENet packets, see [`PacketKind`]. -#[derive(Debug)] pub struct Packet { pub(crate) packet: *mut ENetPacket, } @@ -154,3 +153,16 @@ impl Drop for Packet { } } } + +impl Debug for Packet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let packet = unsafe { &(*self.packet) }; + f.debug_struct("Packet") + .field("data", &packet.data) + .field("dataLength", &packet.dataLength) + .field("userData", &packet.userData) + .field("flags", &packet.flags) + .field("kind", &self.kind()) + .finish() + } +} diff --git a/src/test.rs b/src/test.rs new file mode 100644 index 0000000..4b1fd4b --- /dev/null +++ b/src/test.rs @@ -0,0 +1,56 @@ +use std::time::Duration; + +use crate as enet; + +mod network; +use network::*; + +#[test] +fn events() { + let mut network = Network::new(); + let mut host1 = network.create_host(enet::HostSettings { + peer_limit: 1, + ..Default::default() + }); + let mut host2 = network.create_host(enet::HostSettings { + peer_limit: 1, + ..Default::default() + }); + + network.connect(host1, host2, 255, 5); + network.update(Duration::from_millis(10)); + let events = network.update(Duration::from_millis(10)); + assert_eq!(events.len(), 2); + assert!(events[0].is_connect_and(|event| event.to == host1 + && event.from == host2 + && event.peer == enet::PeerID(0) + && event.data == 0)); + assert!(events[1].is_connect_and(|event| event.to == host2 + && event.from == host1 + && event.peer == enet::PeerID(0) + && event.data == 5)); + + network.send( + host1, + host2, + 0, + enet::Packet::reliable("hello world".as_bytes()), + ); + let events = network.update(Duration::from_millis(10)); + assert_eq!(events.len(), 1); + assert!(events[0].is_receive_and(|event| event.from == host1 + && event.to == host2 + && event.channel_id == 0 + && event.packet.data().len() == 11 + && event.packet.kind() == enet::PacketKind::Reliable)); + + network.disconnect(host1, host2, 10); + let events = network.update(Duration::from_millis(10)); + assert_eq!(events.len(), 1); + assert!(events[0] + .is_disconnect_and(|event| event.from == host1 && event.to == host2 && event.data == 10)); + let events = network.update(Duration::from_millis(10)); + assert_eq!(events.len(), 1); + assert!(events[0] + .is_disconnect_and(|event| event.from == host2 && event.to == host1 && event.data == 0)); +} diff --git a/src/test/network.rs b/src/test/network.rs new file mode 100644 index 0000000..39524ce --- /dev/null +++ b/src/test/network.rs @@ -0,0 +1,310 @@ +use std::{ + collections::HashMap, + ops::{Deref, DerefMut}, + sync::{mpsc, Arc, RwLock}, + time::Duration, +}; + +use crate as enet; + +pub struct Socket { + sender: mpsc::Sender<(usize, Vec)>, + receiver: mpsc::Receiver<(usize, Vec)>, +} + +impl Socket { + fn connect() -> (Socket, Socket) { + let (sender1, receiver2) = mpsc::channel(); + let (sender2, receiver1) = mpsc::channel(); + ( + Socket { + sender: sender1, + receiver: receiver1, + }, + Socket { + sender: sender2, + receiver: receiver2, + }, + ) + } + + fn send(&mut self, address: usize, data: &[u8]) { + self.sender.send((address, data.to_vec())).unwrap(); + } + + fn receive(&mut self) -> Option<(usize, Vec)> { + match self.receiver.recv_timeout(Duration::ZERO) { + Ok((address, data)) => Some((address, data)), + Err(mpsc::RecvTimeoutError::Timeout) => None, + Err(mpsc::RecvTimeoutError::Disconnected) => unreachable!(), + } + } +} + +impl enet::Socket for Socket { + type PeerAddress = usize; + type Error = enet::Error; + + fn init(&mut self, _socket_options: enet::SocketOptions) -> Result<(), Self::Error> { + Ok(()) + } + + fn send(&mut self, address: Self::PeerAddress, buffer: &[u8]) -> Result { + Socket::send(self, address, buffer); + Ok(buffer.len()) + } + + fn receive( + &mut self, + _mtu: usize, + ) -> Result, Self::Error> { + if let Some((address, data)) = Socket::receive(self) { + Ok(Some((address, enet::PacketReceived::Complete(data)))) + } else { + Ok(None) + } + } +} + +impl enet::Address for usize { + fn same_host(&self, other: &usize) -> bool { + *self == *other + } + + fn same(&self, other: &usize) -> bool { + *self == *other + } + + fn is_broadcast(&self) -> bool { + false + } +} + +#[derive(Default)] +pub struct Network { + sockets: Vec, + hosts: Vec>, + connections: HashMap<(usize, usize), enet::PeerID>, + time: Arc>, +} + +impl Network { + pub fn new() -> Self { + Self::default() + } + + fn send_and_receive(&mut self) { + let mut events = vec![]; + for (from, socket) in self.sockets.iter_mut().enumerate() { + while let Some(event) = socket.receive() { + events.push((from, event.0, event.1)); + } + } + for (from, to, data) in events { + self.sockets[to].send(from, &data); + } + } + + pub fn update(&mut self, time: Duration) -> Vec { + macro_rules! send_and_receive { + () => { + let mut events = vec![]; + for (from, socket) in self.sockets.iter_mut().enumerate() { + while let Some(event) = socket.receive() { + events.push((from, event.0, event.1)); + } + } + for (from, to, data) in events { + self.sockets[to].send(from, &data); + } + }; + } + let mut events = vec![]; + for (host_index, host) in self.hosts.iter_mut().enumerate() { + send_and_receive!(); + while let Some(event) = host.service().unwrap() { + let peer_index: usize; + match &event { + enet::Event::Connect { peer, .. } => { + peer_index = peer.address().unwrap(); + self.connections.insert((host_index, peer_index), peer.id()); + } + enet::Event::Disconnect { peer, .. } => { + peer_index = peer.address().unwrap(); + self.connections.remove(&(host_index, peer_index)); + } + enet::Event::Receive { peer, .. } => { + peer_index = peer.address().unwrap(); + } + } + events.push(Event { + from: peer_index, + to: host_index, + event: event.no_ref(), + }); + send_and_receive!(); + } + } + *self.time.write().unwrap() += (time.as_millis() % u32::MAX as u128) as u32; + events + } + + pub fn create_host(&mut self, mut settings: enet::HostSettings) -> usize { + let index = self.hosts.len(); + let time = self.time.clone(); + settings.time = Some(Box::new(move || { + Duration::from_millis(*time.read().unwrap() as u64) + })); + settings.seed = Some(0); + let (network_socket, host_socket) = Socket::connect(); + self.sockets.push(network_socket); + self.hosts + .push(enet::Host::create(host_socket, settings).unwrap()); + index + } + + pub fn resolve_peer(&self, from: usize, to: usize) -> enet::PeerID { + self.connections[&(from, to)] + } + + pub fn connect(&mut self, from: usize, to: usize, channel_count: usize, data: u32) { + self.hosts[from].connect(to, channel_count, data).unwrap(); + } + + pub fn disconnect(&mut self, from: usize, to: usize, data: u32) { + let peer = self.resolve_peer(from, to); + self.hosts[from].peer_mut(peer).disconnect(data) + } + + pub fn send(&mut self, from: usize, to: usize, channel_id: u8, packet: enet::Packet) { + let peer = self.resolve_peer(from, to); + self.hosts[from] + .peer_mut(peer) + .send(channel_id, packet) + .unwrap(); + } +} + +pub struct Host { + host: enet::Host, + address: usize, +} + +impl Host { + pub fn address(&self) -> usize { + self.address + } +} + +impl Deref for Host { + type Target = enet::Host; + + fn deref(&self) -> &Self::Target { + &self.host + } +} + +impl DerefMut for Host { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.host + } +} + +#[derive(Clone)] +pub struct Event { + from: usize, + to: usize, + event: enet::EventNoRef, +} + +#[derive(Debug, Clone)] +pub struct EventConnect { + pub from: usize, + pub to: usize, + pub peer: enet::PeerID, + pub data: u32, +} + +#[derive(Debug, Clone)] +pub struct EventDisconnect { + pub from: usize, + pub to: usize, + pub peer: enet::PeerID, + pub data: u32, +} + +#[derive(Debug, Clone)] +pub struct EventReceive { + pub from: usize, + pub to: usize, + pub peer: enet::PeerID, + pub channel_id: u8, + pub packet: enet::Packet, +} + +impl Event { + pub fn from(&self) -> usize { + self.from + } + + pub fn to(&self) -> usize { + self.to + } + + pub fn is_connect(&self) -> bool { + matches!(&self.event, enet::EventNoRef::Connect { .. }) + } + + pub fn is_connect_and(&self, and: impl Fn(EventConnect) -> bool) -> bool { + if let enet::EventNoRef::Connect { peer, data } = &self.event { + and(EventConnect { + from: self.from, + to: self.to, + peer: *peer, + data: *data, + }) + } else { + false + } + } + + pub fn is_disconnect(&self) -> bool { + matches!(&self.event, enet::EventNoRef::Disconnect { .. }) + } + + pub fn is_disconnect_and(&self, and: impl Fn(EventDisconnect) -> bool) -> bool { + if let enet::EventNoRef::Disconnect { peer, data } = &self.event { + and(EventDisconnect { + from: self.from, + to: self.to, + peer: *peer, + data: *data, + }) + } else { + false + } + } + + pub fn is_receive(&self) -> bool { + matches!(&self.event, enet::EventNoRef::Receive { .. }) + } + + pub fn is_receive_and(&self, and: impl Fn(EventReceive) -> bool) -> bool { + if let enet::EventNoRef::Receive { + peer, + channel_id, + packet, + } = &self.event + { + and(EventReceive { + from: self.from, + to: self.to, + peer: *peer, + channel_id: *channel_id, + packet: packet.clone(), + }) + } else { + false + } + } +}